import SimpleITK
import config
import os
import re
import numpy
import sklearn.cluster
import fitData
import getData
import geometry


def loadTime(r,xsetup):
   """Load frame times and frame durations for the series selected by r/xsetup.

   Reads <tempDir>/<code>/<code>_Dummy.csv, falling back to the .mcsv variant
   when the .csv file does not exist. Blank lines and lines whose first
   non-blank character is '#' are ignored. Every other line is split on
   commas; its first column is the time and its second the duration.

   Returns:
      (t, dt): float numpy arrays, converted from milliseconds to seconds.
   """
   tempDir=config.getTempDir(xsetup)
   code=config.getCode(r,xsetup)

   timeFile=os.path.join(tempDir,code,f'{code}_Dummy.csv')
   if not os.path.isfile(timeFile):
      timeFile=os.path.join(tempDir,code,f'{code}_Dummy.mcsv')
   with open(timeFile,'r') as f:
      lines=[x.strip() for x in f]
   #skip blank lines (e.g. a trailing empty line, which previously raised
   #IndexError on x[0]) and '#' comment lines
   lines=[x for x in lines if x and not x.startswith('#')]
   v=[[float(x) for x in y.split(',')] for y in lines]
   #both columns are stored in milliseconds; convert to seconds
   t=numpy.array([x[0] for x in v])*1e-3
   dt=numpy.array([x[1] for x in v])*1e-3
   return t,dt


def loadData(r,xsetup,returnGeometry=False):
   """Load the dynamic NM series for r/xsetup as a 4D array.

   Frame i is read from '<node>.nrrd' in config.getLocalDir(r,xsetup), where
   <node> is config.getNodeName(r,xsetup,'NM',i). The number of frames is the
   number of entries returned by loadTime. Each frame is divided by its
   duration dt[i] (seconds, from loadTime). The frames are stacked along a
   new last axis.

   Returns:
      data with shape (*frameShape, nFrames), or (data, geo) when
      returnGeometry is True, where geo is geometry.getGeometry() of frame 0.
   """
   t,dt=loadTime(r,xsetup)
   localDir=config.getLocalDir(r,xsetup)
   files=[os.path.join(localDir,f"{config.getNodeName(r,xsetup,'NM',i)}.nrrd") for i in range(len(t))]
   #NOTE: missing frame files are not pre-checked; reporting them is left to
   #SimpleITK.ReadImage

   images=[SimpleITK.ReadImage(x) for x in files]
   geo=geometry.getGeometry(images[0])

   #stack frames on a trailing time axis; dt (one entry per frame) broadcasts
   #over that axis, dividing each frame by its own duration
   data=numpy.stack([SimpleITK.GetArrayFromImage(x) for x in images],axis=-1)/dt

   if returnGeometry:
      return data,geo
   return data

def getTACAtPixels(data,loc):
    """Return the time-activity curves of a 4D array at a set of locations.

    Args:
       data: 4D array whose last axis is time.
       loc: tuple of three index arrays over the first three axes, e.g. as
          returned by numpy.nonzero() on a 3D mask.

    Returns:
       (nTime, nLoc) array; the TAC of the i-th location is v[:,i].
    """
    #indexing the three spatial axes with the index arrays picks one full
    #time series per location -> (nLoc, nTime); transpose to (nTime, nLoc).
    #This works for any number of locations (a fixed-length time index array
    #only broadcasts for exactly 1 or 4 locations).
    return numpy.ascontiguousarray(numpy.asarray(data)[tuple(loc)].T)



def loadCT(r,xsetup,returnGeometry=False):
   """Read the CT volume for r/xsetup from its local NRRD file.

   The file is '<node>.nrrd', with <node> from config.getNodeName(r,xsetup,'CT'),
   resolved to a path by getData.getLocalPath.

   Returns:
      The voxel array, or (array, geo) when returnGeometry is True, where geo
      is geometry.getGeometry() of the image.
   """
   nodeName=config.getNodeName(r,xsetup,'CT')
   path=getData.getLocalPath(r,xsetup,f'{nodeName}.nrrd')
   image=SimpleITK.ReadImage(path)
   geo=geometry.getGeometry(image)
   voxels=SimpleITK.GetArrayFromImage(image)
   return (voxels,geo) if returnGeometry else voxels


def _writeLabelImage(labels,geo,file):
   #write an array to file as an image carrying geo's origin, spacing and direction
   img=SimpleITK.GetImageFromArray(labels)
   img.SetOrigin(geo.origin)
   img.SetSpacing(geo.spacing)
   img.SetDirection(numpy.ravel(geo.direction))
   SimpleITK.WriteImage(img,file)


def saveCenters(r,xsetup,data=None,ir=0):
   """Cluster voxel time-activity curves of the NM series and save the result.

   The series from loadData is reshaped to one row per voxel and one column
   per frame. It is then clustered with sklearn BisectingKMeans into
   xsetup['nclass'][0] classes. Outputs (paths from config.getPattern):
     - 'center' (one file per class ic): that class center as a CSV row;
     - 'centerNRRD', qaName='SPECT': label map in the SPECT geometry;
     - 'centerNRRD', qaName='CT': label map resampled to the CT geometry;
     - 'centerMap': text file with the 3 spatial dimensions followed by the
       flattened labels.

   Args:
      data: unused; the series is always reloaded with loadData.
      ir: realization index used in all output file patterns.
   """
   spect,gSPECT=loadData(r,xsetup,returnGeometry=True)
   ct,gCT=loadCT(r,xsetup,returnGeometry=True)
   #one row per voxel, one column per time frame
   A=spect.reshape(-1,spect.shape[3])
   nclass=xsetup['nclass'][0]

   kmeans=sklearn.cluster.BisectingKMeans(n_clusters=nclass,random_state=0,n_init=1).fit(A)
   centers=kmeans.cluster_centers_
   u=kmeans.labels_.reshape(spect.shape[0:3])
   print(u.shape)
   code=config.getCode(r,xsetup)
   for i in range(nclass):
      cFile=getData.getLocalPath(r,xsetup,config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i))
      numpy.savetxt(cFile,centers[i:i+1,:],delimiter=',')

   #label map in SPECT geometry
   file=getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName='SPECT'))
   _writeLabelImage(u,gSPECT,file)

   #label map in CT geometry; nearest-neighbour resampling keeps the values
   #equal to cluster ids instead of interpolating between them
   file1=getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName='CT'))
   u1=geometry.toSpace2(u,gSPECT,ct,gCT,method='nearest')
   _writeLabelImage(u1,gCT,file1)

   #flattened label map as text: 3 spatial dimensions, then the labels.
   #The map covers all classes, so no class index is passed; this matches
   #the pattern loadCenterMap reads (the leaked loop variable i was passed before)
   qFile=getData.getLocalPath(r,xsetup,config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
   usave=numpy.zeros(kmeans.labels_.shape[0]+3)
   usave[0:3]=spect.shape[0:3]
   usave[3:]=kmeans.labels_
   numpy.savetxt(qFile,usave,delimiter=',')
   
def loadCenters(r,xsetup,ir=0):
   """Load the cluster centers of realization ir, one row per class.

   Center ic is read as one comma-separated row from the file
   config.getCenter(r,xsetup,nclass,ir,ic) inside config.getLocalDir(r,xsetup).
   NOTE(review): saveCenters builds its paths with getData.getLocalPath and
   config.getPattern('center',...); confirm both resolve to the same files.
   If nclass is 0 the 0-d placeholder numpy.array(0) is returned.
   """
   nclass=xsetup['nclass'][0]
   result=numpy.array(0)
   for ic in range(nclass):
      path=os.path.join(config.getLocalDir(r,xsetup),config.getCenter(r,xsetup,nclass,ir,ic))
      row=numpy.loadtxt(path,delimiter=',')
      #allocate once the first row tells us the row length
      if result.ndim==0:
         result=numpy.zeros((nclass,len(row)))
      result[ic,:]=row
   return result

def loadCenterMap(r,xsetup,ir=0):
   """Load the text label map stored under the 'centerMap' file pattern.

   The file layout (as written by saveCenters) is the three spatial
   dimensions followed by the flattened labels. The labels are reshaped to
   those dimensions and returned as a float array (numpy.loadtxt default dtype).
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   mapFile=getData.getLocalPath(r,xsetup,config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
   values=numpy.loadtxt(mapFile,delimiter=',')
   dims=tuple(int(d) for d in values[:3])
   return values[3:].reshape(dims)

def loadCenterMapNRRD(r,xsetup,ir=0):
   """Read the NRRD label maps stored under the 'centerNRRD' file pattern.

   Returns:
      (spectMap, ctMap): voxel arrays of the maps with qaName 'SPECT' and 'CT'.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   modalities=('CT','SPECT')
   paths=[getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=q)) for q in modalities]
   images=[SimpleITK.ReadImage(p) for p in paths]
   ctMap,spectMap=[SimpleITK.GetArrayFromImage(img) for img in images]
   return spectMap,ctMap


def saveIVF(r,xsetup,ir=0,nfit=30,nbatch=30,qLambda=0):
   """Fit the IVF to the saved cluster centers of realization ir and write it.

   Requires saveCenters to have been run first. The fit is
   fitData.fitIVFGlobal(t,centers,nfit=nfit,qLambda=qLambda). The returned m
   is written as the first CSV row (repeated across all columns) and the
   returned samples as the following rows; this is the layout readIVF reads.

   NOTE(review): nbatch is accepted but not passed to fitIVFGlobal, so it has
   no effect here; confirm whether it should be forwarded.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   t,_=loadTime(r,xsetup)
   centers=loadCenters(r,xsetup,ir)
   m,samples=fitData.fitIVFGlobal(t,centers,nfit=nfit,qLambda=qLambda)

   table=numpy.vstack([m*numpy.ones(samples.shape[1]),samples])
   outFile=getData.getLocalPath(r,xsetup,config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
   print(f'Saving to {outFile}')
   numpy.savetxt(outFile,table,delimiter=',')
   
def readIVF(r,xsetup,ir=0,qLambda=0):
   """Read the IVF fit stored under the 'fitIVF' file pattern.

   Returns:
      (m, samples): m is the first header entry as an int; samples are all
      rows after the header.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   path=getData.getLocalPath(r,xsetup,config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
   table=numpy.loadtxt(path,delimiter=',')
   m=int(table[0,0])
   return m,table[1:,:]

def saveSamples(r,xsetup,samples,m,name,iseg=0,ir=0):
   """Save samples with a header row encoding the index list m.

   Written layout: row 0 has samples.shape[1] entries holding m[i]+1, followed
   by zeros. The +1 shift lets 0 mark the end of m when readSamples decodes
   it. Rows 1.. are the samples. Entries of m beyond samples.shape[1] do not
   fit in the header row and are dropped.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   fm=numpy.zeros(samples.shape[1])
   #the header row has samples.shape[1] slots; bounding by samples.shape[0]
   #(the number of rows) could index past the end of fm or drop entries that fit
   n=min(len(m),samples.shape[1])
   fm[:n]=numpy.asarray(m[:n])+1
   fw=numpy.vstack((fm,samples))
   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg))
   print(f'Saving samples to {f}')
   numpy.savetxt(f,fw,delimiter=',')

def readSamples(r,xsetup,name,iseg=0,ir=0):
   """Read samples and the index list m stored under the 'fitCompartment' pattern.

   Expected layout (as written by saveSamples): row 0 holds m[i]+1 followed by
   zeros; rows 1.. are the samples.

   Returns:
      (m, samples): m is decoded from row 0 up to its first 0, with each entry
      reduced by 1 (values stay floats as read from text); samples is every
      row after the header.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg))
   print(f'Reading from  {f}')
   #ndmin=2 keeps single-row or single-column files two-dimensional so the
   #row/column indexing below works
   fw=numpy.loadtxt(f,delimiter=',',ndmin=2)
   samples=fw[1:,:]
   m=[]
   #decode the header row fw[0]; previously samples[0] (i.e. fw[1], the first
   #data row) was decoded by mistake
   for value in fw[0,:]:
      if value==0:
         break
      m.append(value-1)
   return m,samples

def saveTAC(r,xsetup,tac,name,iseg=0,ir=0):
   """Write tac as comma-separated text under the 'fitCompartment' file pattern.

   The path uses the same pattern as saveSamples/readSamples. Only name
   (qaName) and iseg select the file, so a TAC and samples saved with the same
   name and iseg share a path.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg))
   #message previously said 'samples' (copied from saveSamples)
   print(f'Saving TAC to {f}')
   numpy.savetxt(f,tac,delimiter=',')

def readTAC(r,xsetup,name,iseg=0,ir=0):
   """Read a comma-separated array stored under the 'fitCompartment' file pattern.

   This is the counterpart of saveTAC; name (qaName) and iseg select the file.
   """
   nclass=xsetup['nclass'][0]
   code=config.getCode(r,xsetup)
   path=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg))
   print(f'Reading from {path}')
   tac=numpy.loadtxt(path,delimiter=',')
   return tac