
Add port of the MATLAB fitting to Python; the new Python code is based on global optimization with a lambda-adjusted regularizer

Andrej 1 year ago
parent
commit
504c2ae9db

+ 340 - 0
pythonScripts/fitData.py

@@ -0,0 +1,340 @@
+import numpy
+import scipy.optimize
+import scipy.linalg
+import functools
+import fitModel
+
+class initialValueGenerator:
+   
+   def __init__(self):
+      self.map={'gaus':initialValueGenerator.drawGauss,
+      'loggaus':initialValueGenerator.drawLogGauss,
+      'poisson':initialValueGenerator.drawPoisson,
+      'flat':initialValueGenerator.drawFlat,
+      'const':initialValueGenerator.drawConst}
+
+   def add(self,gtype='gaus',parameters=[0,1],bounds=[-numpy.inf,numpy.inf]):
+      v={'type':gtype,'parameters':parameters,'bounds':bounds}
+      try:
+         self.vals.append(v)
+      except AttributeError:
+         self.vals=[v]
+
+   def getN(self):
+      return len(self.vals)
+   
+   def drawGauss(p):
+      return numpy.random.normal(p[0],p[1])
+
+   def drawLogGauss(p):
+      return numpy.power(10,initialValueGenerator.drawGauss(p))
+
+   def drawPoisson(p):
+      return numpy.random.poisson(p[0])
+      
+   def drawFlat(p):
+      return p[0]+numpy.random.random()*(p[1]-p[0])
+
+   def drawConst(p):
+      return p[0]
+
+   def draw(self,typ,p,lb,ub):
+      f=self.map[typ]
+      while True:
+         x=f(p)
+         if x<lb:
+            continue
+         if x>ub:
+            continue
+         return x
+    
+
+   def generate(self):
+      n=len(self.vals)
+      par=numpy.zeros(n)
+      lb=numpy.zeros(n)
+      ub=numpy.zeros(n)
+      for i in range(n):
+         x=self.vals[i]
+         p=x['parameters']
+         lb[i]=x['bounds'][0]
+         ub[i]=x['bounds'][1]
+         par[i]=self.draw(x['type'],p,lb[i],ub[i])
+      return par, (lb,ub)
+
+   def getBoundsScalar(self):
+      gbounds=[]
+      for x in self.vals:
+         gbounds.append(x['bounds'])
+      return gbounds
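+
+#minimal usage sketch of initialValueGenerator; the two-parameter model below is purely
+#illustrative and not one of the models fitted in this module
+def _exampleGenerator():
+   ig=initialValueGenerator()
+   ig.add('gaus',[0,1],[-5,5])
+   ig.add('flat',[0,1],[0,1])
+   par,bounds=ig.generate()
+   #par is a length-2 starting point drawn inside bounds=(lb,ub), ready for
+   #scipy.optimize.least_squares(...,x0=par,bounds=bounds)
+   return par,bounds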
+
+def generateIVF():
+    ig=initialValueGenerator()
+    ig.add('poisson',[200],[0,numpy.inf])
+    ig.add('loggaus',[1,0.5],[0,numpy.inf])
+    ig.add('loggaus',[-1,1],[0,numpy.inf])
+    ig.add('const',[6,1],[0,numpy.inf])
+    ig.add('poisson',[200],[0,numpy.inf])
+    ig.add('loggaus',[-2,1],[0,0.1])
+    return ig
+
+def generateIVFFinite():
+    ig=initialValueGenerator()
+    #A
+    ig.add('poisson',[200],[0,10000])
+    #tau
+    ig.add('loggaus',[1,0.5],[0,40])
+    #alpha
+    ig.add('loggaus',[-1,1],[0,5])
+    #dt
+    ig.add('const',[6,1],[6,50])
+    #B
+    ig.add('poisson',[200],[0,10000])
+    #gamma
+    ig.add('loggaus',[-2,1],[0,0.1])
+    return ig
+
+
+def generateCModel():
+    #generate approx candidate
+    ig=initialValueGenerator()
+    ig.add('loggaus',[-3,1],[0,numpy.inf])
+    ig.add('flat',[0,1],[0,1])
+    ig.add('loggaus',[-7,2],[0,numpy.inf])
+    ig.add('gaus',[10,5],[0,numpy.inf])
+    return ig
+
+def generateCModelFinite():
+    #generate approx candidate
+    ig=initialValueGenerator()
+    #k1
+    ig.add('loggaus',[-3,1],[0,1])
+    #BVF
+    ig.add('flat',[0,1],[0,1])
+    #k2
+    ig.add('loggaus',[-7,2],[0,1])
+    #dt
+    ig.add('gaus',[10,2],[5,50])
+    return ig
+
+def generateGauss(fit,bounds,n=1):
+   #generate n realizations of a multivariate Gaussian with mean (vector) fit.mu and
+   #covariance (matrix) fit.cov, rejecting draws that fall outside bounds=(lb,ub)
+   #output is an (r,n) matrix where r is the size of fit.mu and n is the number of realizations
+   sig=scipy.linalg.sqrtm(fit.cov)
+   r=fit.mu.shape[0]
+   out=numpy.ones((r,n))
+   i=0
+   while True:
+      s=numpy.matmul(sig,numpy.random.normal(size=r))
+      s+=fit.mu
+      if in_bounds(s,bounds[0],bounds[1]):
+         out[:,i]=s
+         i+=1
+      if i==n:
+         break
+   return out
+   #return numpy.outer(fit.mu,numpy.ones((1,n)))+numpy.matmul(sig,s)
+  
+def in_bounds(x0,lb,ub):
+   r=x0.shape[0]
+   for i in range(r):
+      if x0[i]<lb[i]:
+         return False
+      if x0[i]>ub[i]:
+         return False
+   return True
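+
+#minimal sketch of generateGauss with a hypothetical 2-parameter fit object;
+#it is not one of the fits produced by this module
+def _exampleGenerateGauss():
+   class fit:pass
+   fit.mu=numpy.array([1.0,2.0])
+   fit.cov=numpy.array([[0.1,0.0],[0.0,0.2]])
+   bounds=(numpy.array([0.0,0.0]),numpy.array([10.0,10.0]))
+   #draw 5 bounded realizations; the result has shape (2,5), one column per realization
+   return generateGauss(fit,bounds,n=5)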
+
+def getFit(samples,threshold=numpy.inf,verbose=0):
+    class fit:pass
+    chi2=samples[0,:]
+    #mScore=numpy.min(chi2)
+    
+    fit.mu=numpy.mean(samples[1:,chi2<threshold],axis=1)
+    fit.cov=numpy.cov(samples[1:,chi2<threshold])
+    if verbose>0:
+        print(fit.mu)
+        print(fit.cov)
+
+    return fit
+#fit input function
+
+def fitIVF(t, centers,nfit=10,nbatch=20):
+   #t,dt=loadData.loadTime(r,xsetup)
+   #centers=loadCenters(r,xsetup)
+   #find center with maximum content/uptake 
+   m=numpy.unravel_index(numpy.argmax(centers),centers.shape)[0]
+   #treat it as ivf sample
+   ivf=centers[m]
+   #create a partial specialization to be used in optimization
+   w=numpy.ones(t.shape)
+   fun=functools.partial(fitModel.fDiff,fitModel.fIVF,t,ivf,w)  
+   jac=functools.partial(fitModel.jacIVF,t)
+   
+   #generate approx candidate
+   ig=generateIVF()
+   n=ig.getN()
+   samples=numpy.zeros((n+1,nfit))
+   for j in range(nfit):
+      fMin=1e30
+      for i in range(nbatch):
+         par,bounds=ig.generate()
+         result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds,jac=jac)
+         #result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds)
+         #the fit is invariant under swapping tau <-> 1/alpha, so report the larger value as tau
+         xm=[x for x in result.x]
+         xm[1]=numpy.max([result.x[1],1/result.x[2]])
+         xm[2]=numpy.max([1/result.x[1],result.x[2]])
+         if result.cost<fMin:
+            fMin=result.cost
+            x1=xm
+      samples[1:,j]=x1
+      samples[0,j]=fMin
+   
+   return m,samples
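+
+#sketch of the symmetry handling above: the model value is unchanged when tau and 1/alpha are
+#swapped, so the larger of the two is always reported as tau (the values below are arbitrary)
+def _exampleOrderTauAlpha(tau=0.2,alpha=2.0):
+   xm_tau=numpy.max([tau,1/alpha])
+   xm_alpha=numpy.max([1/tau,alpha])
+   return xm_tau,xm_alpha   #(0.5, 5.0)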
+
+
+#global version with annealing
+def fitIVFGlobal(t, centers,nfit=10, qLambda=0):
+    #find the center with the maximum value
+    m=numpy.unravel_index(numpy.argmax(centers),centers.shape)[0]
+    #treat it as ivf sample
+    ivf=centers[m]
+    #this is the (relative) weight of each time point in fit
+    w=numpy.ones(t.shape)
+    #create a partial specialization to be used in optimizations
+    #function: the difference between the model prediction fIVF and the measured values ivf at points t, using point weights w
+    funScalar=functools.partial(fitModel.fDiffScalar,fitModel.fIVF,t,ivf,w)  
+    #add a regularizer on A
+    funScalarRegularized=functools.partial(fitModel.fDiffScalarRegularized,funScalar,fitModel.fIVFRegA,qLambda)
+    #Jacobian of the model
+    jac=functools.partial(fitModel.jacIVF,t)
+    #convert it to scalar for global minimization
+    jacScalar=functools.partial(fitModel.jacScalar,fitModel.fIVF,t,ivf,w,jac)
+    #add regularization on A
+    jacScalarRegularized=functools.partial(fitModel.jacScalarRegularized,jacScalar,fitModel.jacIVFRegA,qLambda)
+
+    ig=generateIVFFinite()
+    boundsScalar=ig.getBoundsScalar()
+    #par,bounds=ig.generate()
+    n=ig.getN()
+    samples=numpy.zeros((n+1,nfit))
+    
+    minSetup=dict(method='L-BFGS-B',jac=jacScalar)
+    if qLambda>0:
+       minSetup['jac']=jacScalarRegularized
+
+    for j in range(nfit):
+        if qLambda>0:
+           result=scipy.optimize.dual_annealing(func=funScalarRegularized,bounds=boundsScalar,minimizer_kwargs=minSetup)
+        else:
+           result=scipy.optimize.dual_annealing(func=funScalar,bounds=boundsScalar,minimizer_kwargs=minSetup)
+        xm=[qx for qx in result.x]
+        xm[1]=numpy.max([result.x[1],1/result.x[2]])
+        xm[2]=numpy.max([1/result.x[1],result.x[2]])
+        if result.x[2]!=xm[2]:
+           pass
+            #print(f'[{j}] switch')
+        samples[1:,j]=xm
+        samples[0,j]=result.fun
+        #print(f'[{j}] {result.fun}')
+    return m,samples
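+
+#sketch of how the lambda-adjusted regularizer is composed: with qLambda>0 the scalar objective
+#becomes chi2(par)+qLambda*A, penalizing large input-function amplitudes; the time points and
+#parameter values below are made up for illustration
+def _exampleRegularizedObjective(qLambda=0.1):
+    t=numpy.linspace(0,40,20)
+    ivf=numpy.exp(-0.1*t)
+    w=numpy.ones(t.shape)
+    funScalar=functools.partial(fitModel.fDiffScalar,fitModel.fIVF,t,ivf,w)
+    funReg=functools.partial(fitModel.fDiffScalarRegularized,funScalar,fitModel.fIVFRegA,qLambda)
+    par=numpy.array([100.0,10.0,0.1,6.0,100.0,0.01])
+    #funReg(par)==funScalar(par)+qLambda*par[0]
+    return funScalar(par),funReg(par)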
+
+
+#fit compartment
+
+def fitCompartment(ivfFit, t, qf ,nfit=10,nbatch=20, useJac=False):
+   
+   igIVF=generateIVF()
+   nIVF=igIVF.getN()
+   pars,boundsIVF=igIVF.generate()
+   ivfSamples=generateGauss(ivfFit,boundsIVF,nfit)
+   ig=generateCModel()
+   nC=ig.getN()
+   boundsScalar=ig.getBoundsScalar()
+   samples=numpy.zeros((1+nC+nIVF,nfit))
+   for j in range(nfit):
+      #create a partial specialization to be used in optimization
+      ivfPar=ivfSamples[:,j]
+      w=numpy.ones(t.shape)
+      fc1=functools.partial(fitModel.fCompartment,ivfPar)
+      fun=functools.partial(fitModel.fDiff,fc1,t,qf,w)  
+      jac=functools.partial(fitModel.jacDep,ivfPar,t)
+      #minimize requires scalar function
+      funScalar=functools.partial(fitModel.fDiffScalar,fc1,t,qf,w)
+      jacScalar=functools.partial(fitModel.jacScalar,fc1,t,qf,w,jac)
+   
+      fMin=1e30
+      for i in range(nbatch):
+         #generate approx candidate
+         par,bounds=ig.generate()
+         
+         if useJac:
+            result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds,jac=jac)
+            #scalar, just for reference
+            #result=scipy.optimize.minimize(fun=funScalar,x0=par,bounds=boundsScalar,
+            #                                           jac=jacScalar,method='L-BFGS-B')
+         else:
+            result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds)
+         if result.cost<fMin:
+            fMin=result.cost
+            x1=result.x
+      
+      samples[0,j]=fMin
+      samples[1:nC+1,j]=x1
+      samples[(1+nC):,j]=ivfPar
+        
+   
+   return samples
+
+
+def fitCompartmentGlobal(ivfFit, t, qf ,nfit=10,nbatch=20, useJac=False):
+    #nclass=setup['nclass'][0]
+    #t,dt=loadData.loadTime(r,setup)
+    #centers=loadData.loadCenters(r,setup,ir)
+    #qf=centers[m]
+
+    igIVF=generateIVF()
+    xs,boundsIVF=igIVF.generate()
+    ivfSamples=generateGauss(ivfFit,boundsIVF,nfit)
+
+    #samples=numpy.zeros((4+1,nfit))
+    nIVF=igIVF.getN()
+    ig=generateCModelFinite()
+    nC=ig.getN()
+    samples=numpy.zeros((1+nC+nIVF,nfit))
+   
+    boundsScalar=ig.getBoundsScalar()
+
+    #optimize has a hard time dealing with small values, so scale the target up and later scale the parameters (approximately correct)
+    scale=1
+    fi=qf.sum()
+    if fi<0.1:
+        scale=10/fi
+        print('scale={}'.format(scale))
+    #print(qf.sum()*scale)
+    qf*=scale
+    for j in range(nfit):
+        ivfPar=ivfSamples[:,j]
+        w=numpy.ones(t.shape)
+        fc1=functools.partial(fitModel.fCompartment,ivfPar)
+        funScalar=functools.partial(fitModel.fDiffScalar,fc1,t,qf,w)  
+        jac=functools.partial(fitModel.jacDep,ivfPar,t)
+        jacScalar=functools.partial(fitModel.jacScalar,fc1,t,qf,w,jac)
+        minSetup=dict(method='L-BFGS-B',jac=jacScalar)
+        
+        result=scipy.optimize.dual_annealing(func=funScalar,bounds=boundsScalar,minimizer_kwargs=minSetup)
+
+        print(f'[{j}/{nfit}]')
+        
+        samples[0,j]=result.fun/scale
+        qx=result.x
+        qx[0]/=scale
+        qx[1]/=scale
+        print(qx)
+        samples[1:nC+1,j]=qx
+        samples[(1+nC):,j]=ivfPar
+        
+    
+    return samples
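+
+#sketch of the rescaling trick used above: a very small target TAC is scaled up before fitting,
+#and the parameters the code divides by scale afterwards (qx[0] and qx[1]) are scaled back;
+#the values below are made up
+def _exampleRescale():
+    qf=numpy.array([0.001,0.002,0.003])
+    scale=1
+    fi=qf.sum()
+    if fi<0.1:
+        scale=10/fi
+    qfScaled=qf*scale
+    #after fitting against qfScaled, divide the fitted amplitude-like parameters by scale
+    return scale,qfScaled.sum()   #(1666.66..., 10.0)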
+
+
+

+ 323 - 0
pythonScripts/fitModel.py

@@ -0,0 +1,323 @@
+import numpy
+
+#adapters
+def fDiff(f,t,y,w,par):
+   fv=f(t,par)
+   return (fv-y)*w
+
+def fDiffScalar(f,t,y,w,par):
+   df=fDiff(f,t,y,w,par)
+   return (df*df).sum()
+
+def jacScalar(f,t,y,w,jac,par):
+   #(m,) array
+   #m number of time points
+   b=2*fDiff(f,t,y,w,par)
+   J=jac(par)
+   return numpy.dot(b,J)
+
+#add regularization with explicit lambda
+def fDiffScalarRegularized(fDiffScalar,fScalarReg,qLambda,par):
+   #fDiffScalar and fScalar reg should return a scalar to be minimized
+   return fDiffScalar(par)+qLambda*fScalarReg(par)
+
+def jacScalarRegularized(jacScalar,jacScalarReg,qLambda,par):
+   #jac scalar (reg) should return (n,) vector with n number of parameters
+   return jacScalar(par)+qLambda*jacScalarReg(par)
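+
+#the adapters above keep the parameter vector as the last argument so callers can bind everything
+#else with functools.partial and hand scipy.optimize a function of par only; a minimal sketch
+#with a made-up linear model:
+def _exampleAdapter():
+   import functools
+   t=numpy.linspace(0,10,11)
+   y=2*t+1
+   w=numpy.ones(t.shape)
+   model=lambda t,par: par[0]*t+par[1]
+   funScalar=functools.partial(fDiffScalar,model,t,y,w)
+   #funScalar now maps a parameter vector to the weighted sum of squared residuals
+   return funScalar(numpy.array([2.0,1.0]))   #exactly zero for the true parameters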
+
+#input function
+def fIVF(t,par):
+   A=par[0]
+   tau=par[1]
+   alpha=par[2]
+   dt=par[3]
+   try:
+      B=par[4]
+      gamma=par[5]
+   except IndexError:
+      print('IndexError')
+      B=0
+      gamma=0
+
+   t1=t-dt
+   x=t1/tau
+   if tau==1/alpha:
+      et=getExp(x,1)
+      fv=A*alpha*getX1(x,et)+B*gamma*getExp(t1,gamma)
+   else:
+      et=getExp(x,1)
+      ea=getExp(t1,alpha)
+      fv=A*alpha*getE(x,ea,et,1-alpha*tau)+B*gamma*getExp(t1,gamma)
+   #fv*=A
+   #fv[t1<0]=0
+   return fv
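+
+#quick evaluation sketch of the input-function model; the parameter values are arbitrary and
+#only illustrate the expected shapes
+def _exampleFIVF():
+   t=numpy.linspace(0,60,121)
+   #par = [A, tau, alpha, dt, B, gamma]
+   par=numpy.array([200.0,10.0,0.05,6.0,200.0,0.01])
+   fv=fIVF(t,par)
+   #fv has t.shape; values vanish for t<=dt because the helpers mask t1<=0
+   return fv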
+
+
+def jacIVF(t,par):
+   #return m by n matrix, where m is t.shape and n is par.shape
+   jac=numpy.zeros((t.shape[0],par.shape[0]))
+   A=par[0]
+   tau=par[1]
+   alpha=par[2]
+   dt=par[3]
+   t1=t-dt
+   x=t1/tau
+   et=getExp(x,1)
+   try:
+      B=par[4]
+      gamma=par[5]
+   except IndexError:
+      print('IndexError')
+      B=0
+      gamma=0
+   
+
+   eg=getExp(t1,gamma)
+   fb=B*gamma*eg
+
+   if tau==1/alpha:
+      fv=A*alpha*getX1(x,et)
+      #first column, df/dA
+      jac[t1>0,0]=fv[t1>0]/A
+      #second column, df/dtau
+      jac[t1>0,1]=-fv[t1>0]/tau*(1+0.5*x[t1>0])
+      #third column df/dalpha
+      jac[t1>0,2]=fv[t1>0]*(1/alpha-0.5*t1[t1>0])
+      #last column df/d(dt)
+      jac[:,3]=fv*alpha-A*alpha*et/tau+gamma*fb
+
+   else:
+      Q=A*alpha/(1-alpha*tau)
+      ea=getExp(t1,alpha)
+      fv=Q*(ea-et)
+      #first column, df/dA
+      jac[:,0]=fv/A
+      #second column, df/dtau
+      jac[:,1]=fv*alpha/(1-alpha*tau)+Q*getX1(x,et)/tau
+      #third column df/dalpha
+      jac[:,2]=fv/alpha/(1-alpha*tau)-Q*getX1(t1,ea)
+      #last column df/d(dt)
+      jac[:,3]=Q*getF(x,alpha,ea,1/tau,et,1)+gamma*fb
+
+   try:
+      jac[:,4]=fb/B
+      jac[:,5]=fb/gamma+B*getX1(t1,eg)
+   except IndexError:
+      pass
+
+   return jac
+
+#regularizers, require A to be as small as possible
+def fIVFRegA(par):
+   return par[0]
+
+def jacIVFRegA(par):
+   jac=numpy.zeros(par.shape[0])
+   jac[0]=1
+   return jac
+
+#helper functions
+
+def getE(x,e,ek,w):
+   #get Ea or Et
+   #first argument is ea for Ea 
+   #or et for Et
+   #last argument is w0 for Ea
+   #or wk for Et
+   E=numpy.zeros(x.shape)
+   E[x>0]=(e[x>0]-ek[x>0])/w
+   return E
+
+def getF(x,a,e,b,eb,w):
+   F=numpy.zeros(x.shape)
+   F[x>0]=(a*e[x>0]-b*eb[x>0])/w
+   return F
+
+def getH(x,a,b,ea=numpy.empty(0),eb=numpy.empty(0)):
+   if ea.size==0:
+      ea=getExp(x,a)
+   if a==b:
+      return getX1(x,ea)
+   if eb.size==0:
+      eb=getExp(x,b)
+   return (ea-eb)/(b-a)
+
+def getH2(x,a,b):
+   #derivative with respect to the second parameter
+   eb=getExp(x,b)
+   if a==b:
+      return -0.5*getX2(x,eb)
+   
+   return -(getH(x,a,b,eb=eb)+getX1(x,eb))/(b-a)
+
+def getHD(x,a,b):
+   #derivative with respect to offset dt
+   ea=getExp(x,a)
+   if a==b:
+      return -ea+a*getX1(x,ea)
+   eb=getExp(x,b)
+   return (a*ea-b*eb)/(b-a)
+
+def getX1(x,e):
+   #returns x1 or y1
+   #for x1, use ea as the 2nd argument
+   #for y1, use et as the 2nd argument
+   x1=numpy.zeros(x.shape)
+   x1[x>0]=x[x>0]*e[x>0]
+   return x1
+
+def getX2(x,e):
+   x2=numpy.zeros(x.shape)
+   x2[x>0]=x[x>0]*x[x>0]*e[x>0]
+   return x2
+
+def getX3(x,e):
+   x2=numpy.zeros(x.shape)
+   x2[x>0]=x[x>0]*x[x>0]*x[x>0]*e[x>0]
+   return x2
+
+def getExp(x,k):
+   e=numpy.zeros(x.shape)
+   e[x>0]=numpy.exp(-x[x>0]*k)
+   return e
+
+def fCompartment(ivfPar,t,par):
+   A=ivfPar[0]
+   tau=ivfPar[1]
+   alpha=ivfPar[2]
+
+   k1=par[0]
+   BVF=par[1]
+   k2=par[2]
+   dt=par[3]
+   try:
+      B=ivfPar[4]
+      gamma=ivfPar[5]
+   except IndexError:
+      B=0
+      gamma=0
+   return BVF*fIVFStar(t,A,tau,alpha,dt,B,gamma)+(1-BVF)*fC1(t,A,tau,alpha,dt,k1,k2,B,gamma)
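+
+#sketch of evaluating the tissue model above: a BVF-weighted mix of the input function and the
+#one-compartment response; the parameter values are illustrative only
+def _exampleCompartment():
+   t=numpy.linspace(0,60,121)
+   #ivfPar = [A, tau, alpha, dt, B, gamma]
+   ivfPar=numpy.array([200.0,10.0,0.05,6.0,200.0,0.01])
+   #par = [k1, BVF, k2, dt]
+   par=numpy.array([0.01,0.1,0.001,10.0])
+   return fCompartment(ivfPar,t,par)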
+
+def fIVFStar(t,A,tau,alpha,dt,B=0,gamma=0):
+   t1=t-dt
+   x=t1/tau
+   
+   et=getExp(x,1)
+   ea=getExp(t1,alpha)
+   eg=getExp(t1,gamma)
+   AT=alpha*tau
+   wa=1-AT
+
+   
+   fv=numpy.zeros(t1.shape)
+
+   if AT==1:
+      fv[x>0]=A*alpha*x[x>0]*ea[x>0]+B*gamma*eg[x>0]
+   
+   else:
+      fv=A*alpha*getE(x,ea,et,wa)+B*gamma*getExp(t1,gamma)
+
+   return fv
+
+def fC1(t,A,tau,alpha,dt,k1,k2,B=0,gamma=0):
+   #apply time shift
+   t1=t-dt
+   x=t1/tau
+   
+   #place to store results
+   r0=numpy.zeros(t1.shape)
+
+   #helper values
+   et=getExp(x,1)
+   ea=getExp(t1,alpha)
+   ek=getExp(t1,k2)
+
+   AT=alpha*tau
+   K2T=k2*tau
+   K1T=k1*tau
+   w0=AT-K2T
+   wk=1-K2T
+   wa=1-AT
+
+   #stratify by parameter values
+   #options A and D: alpha*tau==1
+   if AT==1:
+      #option D: k2*tau==1 as well
+      if K2T==1:
+         r0=0.5*A*K1T*alpha*getX2(x,ea)+B*gamma*getH(t1,gamma,k2)
+      else:
+         #option A
+         Ea=getE(x,ea,ek,-w0)
+         x1=getX1(x,ea)
+         r0=-K1T*A*alpha/w0*(-Ea+x1)+B*gamma*getH(t1,gamma,k2)
+   else:
+      #options 0, B and C; alpha*tau is not 1
+      D1=getH(t1,alpha,k2)
+      D2=getH(t1,1/tau,k2)
+      r0=A*alpha*K1T/wa*(D1-D2)+k1*B*gamma*getH(t1,gamma,k2)
+
+   return r0
+
+def jacDep(ivfPar,t,par):
+   jac=numpy.zeros((t.shape[0],par.shape[0]))
+   A=ivfPar[0]
+   tau=ivfPar[1]
+   alpha=ivfPar[2]
+   try:
+      B=ivfPar[4]
+      gamma=ivfPar[5]
+   except IndexError:
+      B=0
+      gamma=0
+  
+   k1=par[0]
+   BVF=par[1]
+   k2=par[2]
+   dt=par[3]
+   
+   t1=t-dt
+   x=t1/tau
+   et=getExp(x,1)
+   ea=getExp(t1,alpha)
+   ek=getExp(t1,k2)
+
+   AT=alpha*tau
+   K2T=k2*tau
+   K1T=k1*tau
+   w0=AT-K2T
+   wk=1-K2T
+   wa=1-AT
+
+   c1=fC1(t,A,tau,alpha,dt,k1,k2,B,gamma)
+   c0=fIVFStar(t,A,tau,alpha,dt,B,gamma)
+   #first column, df/dk1
+   jac[t1>0,0]=c1[t1>0]/k1
+   #second column, df/dBVF
+   jac[t1>0,1]=c0[t1>0]-c1[t1>0]
+
+   #more effort for k2 and dt
+   #column 2 is df/dk2, column 3 is df/d(dt)
+   if AT==1:
+      jac[:,3]=BVF*(A*(AT*getX1(x,ea)-ea)+B*gamma*gamma*getExp(t1,gamma))
+      if K2T==1:
+         #option D
+         jac[:,2]=-0.5*K1T*A*tau*tau*getX3(x,ea)+B*k1*gamma*getH2(t1,gamma,k2)
+         jac[:,3]+=(1-BVF)*(0.5*K1T*A*(AT*getX2(x,ea)-2*getX1(x,ea))+B*k1*gamma*getHD(t1,gamma,k2))
+      else:
+         #option A
+         jac[:,2]=c1*tau/w0+K1T*A*alpha/w0*getHD(t1,alpha,k2)+B*k1*gamma*getHD(t1,gamma,k2)
+         jac[:,3]+=(1-BVF)*(K1T*A/w0*(tau*getF(x,alpha,ea,k2,ek,-w0)-AT*getX1(x,ea)+ea)+B*k1*gamma*getHD(t1,gamma,k2))
+   else:
+      #options 0, B and C
+      jac[:,2]=A*K1T*alpha/wa*(getH2(t1,alpha,k2)-getH2(t1,1/tau,k2))+B*k1*gamma*getH2(t1,gamma,k2)
+      jac[:,3]=BVF*A*alpha*getHD(t1,alpha,1/tau)
+      jac[:,3]+=(1-BVF)*(K1T*A*alpha/wa*(getHD(t1,alpha,k2)-getHD(t1,1/tau,k2))+B*k1*gamma*getHD(t1,gamma,k2))
+      
+   jac[:,2]*=(1-BVF)
+
+   return jac
+

+ 55 - 0
pythonScripts/geometry.py

@@ -0,0 +1,55 @@
+import numpy
+import scipy.interpolate
+
+def getGeometry(img):
+   origin=numpy.array(img.GetOrigin())
+   direction=numpy.array(img.GetDirection())
+   direction=numpy.reshape(direction,(3,3))
+   spacing=numpy.array(img.GetSpacing())
+   #print(origin)
+   #print(direction)
+   #print(spacing)
+   class geometry:pass
+   geometry.origin=origin
+   geometry.spacing=spacing
+   geometry.direction=direction
+   return geometry
+
+def pixelToVector(geometry,pixel):
+   #accounts for the reverse order of coordinates in a numpy array relative to the SimpleITK image
+   return numpy.dot(geometry.direction,numpy.flip(pixel)*geometry.spacing)+geometry.origin
+
+def vectorToPixel(geometry,vector):
+
+   #accounts for reverse order of coordinates in numpy array relative to SimpleITK image
+   return numpy.flip(numpy.dot(geometry.direction.transpose(),vector-geometry.origin)/geometry.spacing)
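+
+#round-trip sketch for the two transforms above, using a made-up identity-direction geometry;
+#pixelToVector followed by vectorToPixel reproduces the pixel index
+def _exampleRoundTrip():
+   class geometry:pass
+   geometry.origin=numpy.array([-100.0,-100.0,-50.0])
+   geometry.spacing=numpy.array([2.0,2.0,3.0])
+   geometry.direction=numpy.eye(3)
+   pixel=numpy.array([10,20,30])
+   v=pixelToVector(geometry,pixel)
+   return vectorToPixel(geometry,v)   #approximately [10. 20. 30.]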
+
+def toSpace2(spect,gSPECT,ct,gCT,method='linear'):
+   #convert array spect with geometry gSPECT to an array of size equal to CT w/ corresponding geometry gCT
+   out=numpy.zeros(ct.shape)
+   pixels=[]
+   for i in range(out.shape[0]):
+      print('{}/{}'.format(i,out.shape[0]))
+      for j in range(out.shape[1]):
+         for k in range(out.shape[2]):
+            pixel=[i,j,k]
+            v=pixelToVector(gCT,pixel)
+            pixelSPECT=vectorToPixel(gSPECT,v)
+            pixels.append(pixelSPECT)
+   print('Interpolating {} pixels'.format(len(pixels)))
+   outs=interpolate(spect,pixels,method=method)
+   print('Done')
+   m=0
+   for i in range(out.shape[0]):
+      for j in range(out.shape[1]):
+         for k in range(out.shape[2]):
+            out[i,j,k]=outs[m]
+            m+=1
+    
+   return out
+
+def interpolate(ar,c,method='linear'):
+
+   points=[numpy.linspace(0.5,ar.shape[i]-0.5,ar.shape[i]) for i in range(3)]
+   return scipy.interpolate.interpn(points,ar,c,method=method,fill_value=-1,bounds_error=False)
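+
+#small sketch of the interpolate helper on a toy 3D array; coordinates use the half-pixel
+#offset defined above, so [1.5,1.5,1.5] is the centre of pixel (1,1,1)
+def _exampleInterpolate():
+   ar=numpy.arange(27,dtype=float).reshape((3,3,3))
+   return interpolate(ar,[[1.5,1.5,1.5]])   #ar[1,1,1]==13.0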
+

+ 232 - 0
pythonScripts/loadData.py

@@ -0,0 +1,232 @@
+import SimpleITK
+import config
+import os
+import re
+import numpy
+import sklearn.cluster
+import fitData
+import getData
+import geometry
+
+
+def loadTime(r,xsetup):
+
+   tempDir=config.getTempDir(xsetup)
+   code=config.getCode(r,xsetup)
+
+   timeFile=os.path.join(tempDir,code,f'{code}_Dummy.csv')
+   if not os.path.isfile(timeFile):
+      timeFile=os.path.join(tempDir,code,f'{code}_Dummy.mcsv')
+   with open(timeFile,'r') as f:
+      lines=[re.sub('\n','',x) for x in f.readlines()]
+      lines=[x for x in lines if x[0]!='#']
+      v=[[float(x) for x in y.split(',')] for y in lines]
+      t=numpy.array([x[0] for x in v])
+      dt=numpy.array([x[1] for x in v])
+      #convert from milliseconds to seconds
+      t*=1e-3
+      dt*=1e-3
+   return t,dt
+
+
+def loadData(r,xsetup,returnGeometry=False):
+
+   #load data from nrrd
+
+   t,dt=loadTime(r,xsetup)
+   c1=len(t)
+   nodes=[config.getNodeName(r,xsetup,'NM',i) for i in range(0,c1)]
+   files=[f'{x}.nrrd' for x in nodes]
+   files=[os.path.join(config.getLocalDir(r,xsetup),x) for x in files]
+   filesPresent=[os.path.isfile(x) for x in files]
+   #possible side exit when missing files are encountered
+
+   xdata=[SimpleITK.ReadImage(x) for x in files]
+   geo=geometry.getGeometry(xdata[0])
+
+   xdata=[SimpleITK.GetArrayFromImage(x) for x in xdata]
+   #create new array to hold all data
+   data=numpy.zeros((*xdata[0].shape,len(xdata)))
+   for i in range(len(xdata)):
+      data[...,i]=numpy.array(xdata[i])/dt[i]
+
+   if returnGeometry:
+      return data,geo
+   return data
+
+def getTACAtPixels(data,loc):
+    #data is 4D array, loc are indices as returned by numpy.nonzero()
+    #return nxm array where n is number of time points and m is number of locations
+    #to get TAC for i-th location, do v[:,i]
+    loc1=[loc+(i,) for i in range(data.shape[3])]
+    v=[data[x] for x in loc1]
+    return numpy.vstack(v)
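+
+#small sketch of the indexing convention used by getTACAtPixels: loc comes from numpy.nonzero
+#of a 3D mask and data is a toy 4D array with 5 time points
+def _exampleGetTAC():
+   data=numpy.random.random((4,4,4,5))
+   mask=numpy.zeros((4,4,4))
+   mask[0,0,0]=1
+   mask[1,2,3]=1
+   mask[2,2,2]=1
+   mask[3,1,0]=1
+   loc=numpy.nonzero(mask)
+   v=getTACAtPixels(data,loc)
+   #v has shape (5,4): column i is the TAC of the i-th nonzero voxel
+   return v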
+
+
+
+def loadCT(r,xsetup,returnGeometry=False):
+   file='{}.nrrd'.format(config.getNodeName(r,xsetup,'CT'))
+   file=getData.getLocalPath(r,xsetup,file)
+   xd=SimpleITK.ReadImage(file)
+   geo=geometry.getGeometry(xd)
+   xd=SimpleITK.GetArrayFromImage(xd)
+   if returnGeometry:
+      return xd,geo
+   return xd
+
+
+def saveCenters(r,xsetup,data=None,ir=0):
+   
+   #if not data:
+   spect,gSPECT=loadData(r,xsetup,returnGeometry=True)
+   ct,gCT=loadCT(r,xsetup,returnGeometry=True)
+   A=spect.reshape(-1,spect.shape[3])
+   nclass=xsetup['nclass'][0]
+
+   #kmeans0 = sklearn.cluster.KMeans(n_clusters=k, random_state=0, n_init="auto").fit(A)
+   #cmeans = sklearn.mixture.GaussianMixture(n_components=k, random_state=0, n_init=1).fit(A)
+   kmeans = sklearn.cluster.BisectingKMeans(n_clusters=nclass, random_state=0, n_init=1).fit(A)
+   centers=kmeans.cluster_centers_
+   u=kmeans.labels_
+   u=u.reshape(spect.shape[0:3])
+   print(u.shape)
+   code=config.getCode(r,xsetup)
+   for i in range(nclass):
+      #ui=(u==i)*numpy.ones(u.shape)
+      #file=getData.getLocalPath(r,xsetup,config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i))
+      #img=SimpleITK.GetImageFromArray(ui)
+      #SimpleITK.WriteImage(img, file)
+      cFile=getData.getLocalPath(r,xsetup,config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i))
+      numpy.savetxt(cFile,centers[i:i+1,:],delimiter=',')
+   #write center map as NRRD file in spect geometry:
+   file=getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName='SPECT'))
+   img=SimpleITK.GetImageFromArray(u)
+   img.SetOrigin(gSPECT.origin)
+   img.SetSpacing(gSPECT.spacing)
+   img.SetDirection(numpy.ravel(gSPECT.direction))
+   SimpleITK.WriteImage(img, file)
+
+   #also in CT geometry
+   if True:
+      file1=getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName='CT'))
+      #nearest-neighbour interpolation preserves the content, which should be the id of the k-means group center
+      u1=geometry.toSpace2(u,gSPECT,ct,gCT,method='nearest')
+      img1=SimpleITK.GetImageFromArray(u1)
+      img1.SetOrigin(gCT.origin)
+      img1.SetSpacing(gCT.spacing)
+      img1.SetDirection(numpy.ravel(gCT.direction))
+      SimpleITK.WriteImage(img1, file1)
+
+   #write center map as numpy array
+   qFile=getData.getLocalPath(r,xsetup,config.getPattern('centerMap',code=code,nclass=nclass,ir=ir,ic=i))
+   usave=numpy.zeros(kmeans.labels_.shape[0]+3)
+   usave[0:3]=spect.shape[0:3]
+   usave[3:]=kmeans.labels_
+   numpy.savetxt(qFile,usave,delimiter=',')
+   
+def loadCenters(r,xsetup,ir=0):
+   nclass=xsetup['nclass'][0]
+   centers=numpy.array(0)
+   for i in range(nclass):
+      cFile=os.path.join(config.getLocalDir(r,xsetup),config.getCenter(r,xsetup,nclass,ir,i))
+      #row
+      c=numpy.loadtxt(cFile,delimiter=',')
+      if len(centers.shape)==0:
+         centers=numpy.zeros((nclass,len(c)))
+      centers[i,:]=c
+   return centers
+
+def loadCenterMap(r,xsetup,ir=0):
+
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   qFile=getData.getLocalPath(r,xsetup,config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
+   usave=numpy.loadtxt(qFile,delimiter=',')
+   shape=[int(x) for x in usave[0:3]]
+   u=numpy.reshape(usave[3:],shape)
+   return u
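+
+#sketch of the flat storage convention shared by saveCenters and loadCenterMap:
+#the first three entries hold the 3D shape, the remaining entries are the per-voxel labels
+def _exampleCenterMapRoundTrip():
+   labels=numpy.arange(24)
+   usave=numpy.zeros(labels.shape[0]+3)
+   usave[0:3]=(2,3,4)
+   usave[3:]=labels
+   shape=[int(x) for x in usave[0:3]]
+   u=numpy.reshape(usave[3:],shape)
+   return u.shape   #(2, 3, 4)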
+
+def loadCenterMapNRRD(r,xsetup,ir=0):
+
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   md=['CT','SPECT']
+   files={x:getData.getLocalPath(r,xsetup,config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x)) for x in md}
+   xd={x:SimpleITK.ReadImage(files[x]) for x in files}
+   nd={x:SimpleITK.GetArrayFromImage(xd[x]) for x in xd}
+   return nd['SPECT'],nd['CT']
+
+
+def saveIVF(r,xsetup,ir=0,nfit=30,nbatch=30,qLambda=0):
+   #fit IVF from centers in realization ir, perform nfit optimized fits where nbatch is used
+   #to find best among nbatch trials (in total, nfit*nbatch fits will be made)
+
+   #requires saveCenters to be run prior to execution 
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   t,dt=loadTime(r,xsetup)
+   centers=loadCenters(r,xsetup,ir)
+   m,samples=fitData.fitIVFGlobal(t,centers,nfit=nfit,qLambda=qLambda)
+
+   fm=m*numpy.ones(samples.shape[1])
+   fw=numpy.vstack((fm,samples))
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+   print(f'Saving to {f}')
+   numpy.savetxt(f,fw,delimiter=',')
+   
+def readIVF(r,xsetup,ir=0,qLambda=0):
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+   fw=numpy.loadtxt(f,delimiter=',')
+   m=int(fw[0,0])
+   samples=fw[1:,:]
+   return m,samples
+
+def saveSamples(r,xsetup,samples,m,name,iseg=0,ir=0,qLambda=0):
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   fm=numpy.zeros(samples.shape[1])
+   n=numpy.min(numpy.array([len(m),samples.shape[0]]))
+   for i in range(n):
+      fm[i]=m[i]+1
+   fw=numpy.vstack((fm,samples))
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg,qLambda=qLambda))
+   print(f'Saving samples to {f}')
+   numpy.savetxt(f,fw,delimiter=',')
+
+def readSamples(r,xsetup,name,iseg=0,ir=0,qLambda=0):
+   m=[]
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg,qLambda=qLambda))
+   print(f'Reading from {f}')
+   fw=numpy.loadtxt(f,delimiter=',')
+   samples=fw[1:,:]
+   #the (shifted) center ids are stored in the first row written by saveSamples
+   fm=fw[0,:]
+   for i in range(fm.shape[0]):
+      if fm[i]==0:
+         break
+      m.append(fm[i]-1)
+   return m,samples
+
+def saveTAC(r,xsetup,tac,name,iseg=0,ir=0,qLambda=0):
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg,qLambda=qLambda))
+   print(f'Saving TAC to {f}')
+   numpy.savetxt(f,tac,delimiter=',')
+
+def readTAC(r,xsetup,name,iseg=0,ir=0,qLambda=0):
+   nclass=xsetup['nclass'][0]
+   code=config.getCode(r,xsetup)
+   f=getData.getLocalPath(r,xsetup,config.getPattern('fitCompartment',code=code,nclass=nclass,ir=ir,qaName=name,iseg=iseg,qLambda=qLambda))
+   print(f'Reading from {f}')
+   return numpy.loadtxt(f,delimiter=',')
+
+
+
+
+   
+

+ 172 - 0
pythonScripts/plotData.py

@@ -0,0 +1,172 @@
+import matplotlib.pyplot
+import fitData
+import fitModel
+import numpy
+import functools
+import geometry
+
+def plotIVF(t,ivf,samples,threshold,file0=None,file1=None):
+    
+    fit=fitData.getFit(samples,threshold)
+    
+    t1=numpy.linspace(0,numpy.max(t),300)
+    
+    w=numpy.ones(t.shape)
+    fun=functools.partial(fitModel.fDiff,fitModel.fIVF,t,ivf,w)
+    
+    matplotlib.pyplot.figure(0)
+    matplotlib.pyplot.scatter(t,ivf)
+    n=samples.shape[1]
+    alpha=1/n
+    alpha1=1
+    for j in range(samples.shape[1]):
+        cPar=samples[1:,j]
+        df=fun(cPar)
+        chi2=(df*df).sum()
+        cost=samples[0,j]
+        if chi2!=cost:
+            print('{}/{} {}'.format(samples[0,j],chi2,cPar))
+        
+        fv=fitModel.fIVF(t1,cPar)
+        
+        if chi2>threshold:
+            matplotlib.pyplot.figure(0)
+            matplotlib.pyplot.plot(t1,fv,color='silver',alpha=alpha)
+            matplotlib.pyplot.figure(1)
+            matplotlib.pyplot.plot(t,df,color='silver',alpha=alpha)
+        else:
+            matplotlib.pyplot.figure(0)
+            matplotlib.pyplot.plot(t1,fv,alpha=alpha1)
+            matplotlib.pyplot.figure(1)
+            matplotlib.pyplot.plot(t,df,alpha=alpha1)
+    if file0:
+        matplotlib.pyplot.figure(0)
+        matplotlib.pyplot.savefig(file0)
+    
+    if file1:
+        matplotlib.pyplot.figure(1)
+        matplotlib.pyplot.savefig(file1)
+
+def plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=None,file1=None):
+   um=centerMapSPECT==m
+   loc=numpy.nonzero(centerMapSPECT==m)
+
+   #print(loc)
+   #get center of weight of voxels that belong to center indicated as IVF sample
+   x=[numpy.mean(t) for t in loc]
+   #print(x)
+   #convert it to index
+   iSlice=[int(y) for y in x]
+
+   #print(iSlice)
+
+   d20=data[...,-1]
+   fig=matplotlib.pyplot.figure(figsize=(12, 3))
+   axs = fig.subplots(1, 3)
+   axs[0].imshow(d20[iSlice[0],...],cmap='gray_r')
+   axs[0].imshow(um[iSlice[0],...],alpha=0.2,cmap='Reds')
+   axs[1].imshow(d20[:,iSlice[1],:],cmap='gray_r')
+   axs[1].imshow(um[:,iSlice[1],:],alpha=0.2,cmap='Reds')
+   axs[2].imshow(d20[...,iSlice[2]],cmap='gray_r')
+   axs[2].imshow(um[...,iSlice[2]],alpha=0.2,cmap='Reds')
+   if file0:
+      matplotlib.pyplot.savefig(file0)
+
+   #overlay it over CT
+   um1=centerMapCT==m
+   loc1=numpy.nonzero(um1)
+   #print(loc1)
+   x1=[numpy.mean(t) for t in loc1]
+   #print(x1)
+   iSlice1=[int(y) for y in x1]
+   #print(iSlice1)
+   fig=matplotlib.pyplot.figure(figsize=(12, 3))
+   axs = fig.subplots(1, 3)
+   axs[0].imshow(ct[iSlice1[0],...],cmap='gray')
+   axs[0].imshow(um1[iSlice1[0],...],alpha=0.4,cmap='Reds')
+   axs[1].imshow(ct[:,iSlice1[1],:],aspect='auto',origin='lower',cmap='gray')
+   axs[1].imshow(um1[:,iSlice1[1],:],alpha=0.4,aspect='auto',origin='lower',cmap='Reds')
+   axs[2].imshow(ct[...,iSlice1[2]],aspect='auto',origin='lower',cmap='gray')
+   axs[2].imshow(um1[...,iSlice1[2]],alpha=0.4,aspect='auto',origin='lower',cmap='Reds')
+   if file1:
+      matplotlib.pyplot.savefig(file1)
+
+def plotIVFRealizations(t,ivf,samples,threshold,nplot=50,file=None):
+    fit=fitData.getFit(samples,threshold)
+    
+    t1=numpy.linspace(0,numpy.max(t),300)
+    #cPar=fit.mu
+    #fv=fitModel.fIVF(t1,cPar)
+    w=numpy.ones(t.shape)
+    fun=functools.partial(fitModel.fDiff,fitModel.fIVF,t,ivf,w)
+    ig=fitData.generateIVF()
+    par,bounds=ig.generate()
+              
+    ivfSamples=fitData.generateGauss(fit,bounds,nplot)
+    matplotlib.pyplot.figure()
+    matplotlib.pyplot.scatter(t,ivf)
+    for i in range(nplot):
+        cPar=ivfSamples[:,i]
+        fv=fitModel.fIVF(t1,cPar)
+        #df=fun(cPar)
+        matplotlib.pyplot.plot(t1,fv,alpha=0.05,color='blue')
+        #matplotlib.pyplot.plot(t,df)  
+    if file:
+        matplotlib.pyplot.savefig(file)
+
+
+def plotSamples(t,evalArray,file0=None,file1=None):   
+
+   igIVF=fitData.generateIVF()
+   ig=fitData.generateCModel()
+   nIVF=igIVF.getN()
+   nC=ig.getN()
+    
+   t1=numpy.linspace(0,numpy.max(t),300)
+   f1=matplotlib.pyplot.figure()
+   f2=matplotlib.pyplot.figure()
+   matplotlib.pyplot.figure(f1.number)
+   for x in evalArray:
+      matplotlib.pyplot.scatter(t,x[1],color=x[2])
+    
+   for x in evalArray:
+      samples=x[0]
+      color=x[2]
+      qf=x[1]
+      chi2=samples[0,:]
+      median=numpy.median(chi2)
+      n=chi2.shape[0]
+      alpha=1/n
+      alpha1=0.5
+      for j in range(n):
+         cPar=samples[1:1+nC,j]
+         ivfPar=samples[1+nC:,j]
+         fv=fitModel.fCompartment(ivfPar,t1,cPar)
+
+         fc1=functools.partial(fitModel.fCompartment,ivfPar)
+         fun=functools.partial(fitModel.fDiff,fc1,t,qf,numpy.ones(t.shape))
+
+
+         df=fun(cPar)
+         cost=chi2[j]
+         c1=(df*df).sum()
+         print(f'{cost}/{c1}')
+         if cost>median:
+            matplotlib.pyplot.figure(f1.number)
+            matplotlib.pyplot.plot(t1,fv,color='silver',alpha=alpha)
+            matplotlib.pyplot.figure(f2.number)
+            matplotlib.pyplot.plot(t,df,color='silver',alpha=alpha)
+        
+         else:
+            matplotlib.pyplot.figure(f1.number)
+            matplotlib.pyplot.plot(t1,fv,color=color,alpha=alpha1)
+            matplotlib.pyplot.figure(f2.number)
+            matplotlib.pyplot.plot(t,df,color=color,alpha=alpha1)
+   if file0:
+      matplotlib.pyplot.figure(f1.number)
+      matplotlib.pyplot.savefig(file0)
+    
+   if file1:
+      matplotlib.pyplot.figure(f2.number)
+      matplotlib.pyplot.savefig(file1)
+

File diff suppressed because it is too large
+ 177 - 0
pythonScripts/test.ipynb


+ 407 - 0
pythonScripts/workflow.py

@@ -0,0 +1,407 @@
+import config
+import getData
+import loadData
+import fitData
+import numpy
+import segmentation
+import plotData
+import datetime
+import os
+
+
+
+def listRequiredFiles(stage,r,setup):
+    code=config.getCode(r,setup)
+    nclass=setup['nclass'][0]
+    nr=setup['nr']
+    nt=20
+    if stage=='setCenters':
+        names={x:[config.getPattern(x,code)] for x in ['CT','Dummy']}
+        names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
+        return names
+    if stage=='fitIVF':
+        names={x:[config.getPattern(x,code)] for x in ['Dummy']}
+        names['center']=[]
+        for ir in range(nr):
+            rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            names['center'].extend(rel)
+        return names
+    if stage=='plotIVF':
+        names={x:[config.getPattern(x,code)] for x in ['Dummy']}
+        names['center']=[]
+        names['fitIVF']=[]
+        qLambda=setup['qLambda']
+        for ir in range(nr):
+            rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            names['center'].extend(rel)
+            names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+            #names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
+            rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
+            names['center'].extend(rel)
+            
+        #names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
+        names.update({x:[config.getPattern(x,code)] for x in ['CT']})
+        names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
+        return names
+    
+    if stage=='fitCompartment':
+        names={}
+        names['center']=[]
+        names['fitIVF']=[]
+        qLambda=setup['qLambda']
+        for ir in range(nr):
+            rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            names['center'].extend(rel)
+            names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+            names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
+        names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
+        return names
+    
+    if stage=='plotCompartment':
+        names={}
+        names['center']=[]
+        names['fitIVF']=[]
+        names['fitCompartment']=[]
+        nseg=setup['nseg']
+        qLambda=setup['qLambda']
+        #names of the compartment-fit outputs written by the fitCompartment stage
+        xc='fitCompartment'
+        sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
+        for ir in range(nr):
+            rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            names['center'].extend(rel)
+            names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+            names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
+            for iseg in range(nseg):
+                rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
+                names['fitCompartment'].extend(rel)
+        names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
+        names.update({x:[config.getPattern(x,code)] for x in ['CT','Dummy']})
+        names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
+        return names
+        
+    return {}
+
+def listCreatedFiles(stage,r,setup):
+    code=config.getCode(r,setup)
+    nclass=setup['nclass'][0]
+    qLambda=setup['qLambda']
+    nr=setup['nr']
+    try:
+        nseg=setup['nseg']
+    except KeyError:
+        nseg=0
+    names={}
+                                
+    if stage=='setCenters':
+        names['center']=[]
+        for ir in range(nr):
+            rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            names['center'].extend(rel)
+            #rel=[config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
+            #names['center'].extend(rel)
+            names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
+            rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
+            names['center'].extend(rel)
+            
+        return names
+    if stage=='fitIVF':
+        names['fitIVF']=[]
+        for ir in range(nr):
+            names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
+        return names
+    if stage=='plotIVF':
+        names['plotIVF']=[]
+        for ir in range(nr):
+            x=[config.getPattern('plotIVF',code=code,nclass=nclass,ir=ir,qaName=y,qLambda=qLambda) 
+               for y in ['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']]
+            names['plotIVF'].extend(x)
+        return names
+    if stage=='fitCompartment':
+        xc='fitCompartment'
+        names[xc]=[]
+        sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
+        for ir in range(nr):
+            for iseg in range(nseg):
+                rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
+                names[xc].extend(rel)
+    if stage=='plotCompartment':
+        xc='plotCompartment'
+        names[xc]=[]
+        sNames=['realizations','diff']
+        for ir in range(nr):
+            for iseg in range(nseg):
+                rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
+                names[xc].extend(rel)
+    
+    return names
+    
+
+def getRequiredFiles(stage,r,setup,fb,names=None):
+    #fb,r=getRow(setup,True)
+    if not names:
+        names=listRequiredFiles(stage,r,setup)
+    for f in names:
+        _copyFromServer=getData.copyFromServer
+        if f=='segmentation':
+            _copyFromServer=segmentation.copyFromServer
+        _copyFromServer(fb,r,setup,names[f])
+    return fb,r
+
+def checkRequiredFiles(stage,r,setup,names=None,fb=None,doPrint=False):
+    ok=True
+    if not names:
+        names=listRequiredFiles(stage,r,setup)
+    for f in names:
+        nm=names[f]
+        for x in nm:
+            avail=os.path.isfile(getData.getLocalPath(r,setup,x))
+            if not avail:
+                print(f'Missing {x}')
+                if fb:
+                    _getURL=getData.getURL
+                    if f=='segmentation':
+                        _getURL=segmentation.getURL
+                    availRemote=fb.entryExists(_getURL(fb,r,setup,x))
+                    print(f'Available remote: {availRemote}')
+                ok=False
+            if doPrint:
+                print(f'[{avail}] {x}')
+    return ok
+
+def uploadCreatedFiles(stage,fb,r,setup,names=None):
+   if not names:
+      names=listCreatedFiles(stage,r,setup)
+   for f in names:
+      _copyToServer=getData.copyToServer
+      _getURL=getData.getURL
+      if f=='segmentation':
+         _copyToServer=segmentation.copyToServer
+         _getURL=segmentation.getURL
+      _copyToServer(fb,r,setup,names[f])
+      for x in names[f]:
+         print('[{}] Uploaded {}'.format(fb.entryExists(_getURL(fb,r,setup,x)),x))
+
+
+#this is a poor fit for workflow, but no better logical unit was found, so here it is
+def makeMap(segs,kClass,tac):
+    map={}
+    vals=[(kClass[i],tac[:,i]) for i in range(len(kClass))]
+    
+    for (i,v) in zip(segs,vals):
+        try:
+            map[i].append(v)
+        except KeyError:
+            map[i]=[v]
+    return map
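+
+#toy sketch of makeMap: two voxels in segment 1 and one voxel in segment 2, three time points;
+#the returned map groups (classId, TAC) pairs by segment id
+def _exampleMakeMap():
+    segs=[1,1,2]
+    kClass=[0,3,5]
+    tac=numpy.arange(9).reshape(3,3)
+    #tac[:,i] is the time-activity curve of voxel i
+    return makeMap(segs,kClass,tac)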
+
+#def getDataAtPixels(data,loc) replaced by loadData.getTACAtPixels(data,loc)
+
+def updateDatabase(r,setup,stage,fb=None,db=None,categories=[]):
+   #set database entry
+   try:
+      qLam=setup['qLambda']
+   except KeyError:
+      qLam=0
+
+   nclass=setup['nclass'][0]
+   code=config.getCode(r,setup) 
+   if stage=='plotIVF':
+      m,samples=loadData.readIVF(r,setup,qLambda=qLam)
+      chi2=samples[0,:]
+      threshold=numpy.median(chi2)
+      fit=fitData.getFit(samples,threshold)
+      row={x:r[x] for x in ['PatientId','visitCode']}
+      row['nclass']=nclass
+      row['mean']=fit.mu[0]
+      row['std']=fit.cov[0,0]
+      row['qLambda']=qLam
+
+      fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLam) for x in categories}
+      row.update(fNames)
+      if db:
+         db.modifyRows('insert',setup['project'],'lists','SummaryIVF',[row])
+
+
+def workflow(r,setup,stage,fb=None,db=None):
+    setCenters=False
+    setIVF=False
+    plotIVF=False
+    setC=True
+    try:
+        qLambda=setup['qLambda']
+    except KeyError:
+        qLambda=0
+                                   
+    if stage=='setCenters':
+        names=listRequiredFiles(stage,r,setup)
+        if fb:
+            getRequiredFiles(stage,r,setup,fb,names=names)
+        if not checkRequiredFiles(stage,r,setup,names=names,fb=fb,doPrint=True):
+            return
+        
+        loadData.saveCenters(r,setup)
+
+    if stage=='fitIVF':
+        #get required files
+        #stage='fitIVF'
+        if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True):
+            return
+        
+        loadData.saveIVF(r,setup,nfit=30,qLambda=qLambda)
+
+    
+    if stage=='plotIVF':
+        ir=0
+        names=listRequiredFiles(stage,r,setup)
+        if fb:
+            getRequiredFiles(stage,r,setup,fb,names=names)
+        if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
+            return
+        
+        print('Loading files to memory')
+        m,samples=loadData.readIVF(r,setup,qLambda=qLambda)
+
+        data=loadData.loadData(r,setup)
+        ct=loadData.loadCT(r,setup)
+        centerMapSPECT,centerMapCT=loadData.loadCenterMapNRRD(r,setup,ir=ir)
+
+        t,dt=loadData.loadTime(r,setup)
+        centers=loadData.loadCenters(r,setup,ir=ir)
+        ivf=centers[m]
+        chi2=samples[0,:]
+        threshold=numpy.median(chi2)
+        code=config.getCode(r,setup)
+        nclass=setup['nclass'][0]
+        ir=0
+          
+        categories=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
+
+        fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLambda) for x in categories}
+        files={x:getData.getLocalPath(r,setup,fNames[x]) for x in fNames}
+                
+        
+        plotData.plotIVF(t,ivf,samples,threshold,file0=files['fits'],file1=files['diff'])
+        
+        plotData.plotIVFRealizations(t,ivf,samples,threshold,file=files['generatedIVF'])
+        
+        #overlay the selected IVF center over the SPECT and CT volumes
+        plotData.plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=files['centerIVFSPECT'],
+                               file1=files['centerIVFCT'])
+
+        updateDatabase(r,setup,stage,db=db,fb=fb,categories=categories)
+
+    if stage=='fitCompartment':
+        ir=0
+        names=listRequiredFiles(stage,r,setup)
+        if fb:
+            getRequiredFiles(stage,r,setup,fb,names=names)
+        if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
+            return
+        
+        #load class classification
+        u=loadData.loadCenterMap(r,setup)
+        print(u.shape)
+        
+        #load segmentation
+        seg=segmentation.getNRRDImage(r,setup,names)
+        loc=numpy.nonzero(seg)
+        vClass=[int(x) for x in u[loc]]
+        segments=[int(x) for x in seg[loc]]
+        print(segments)
+        data=loadData.loadData(r,setup)
+        tac=loadData.getTACAtPixels(data,loc)
+        
+        segMap=makeMap(segments,vClass,tac)
+        #for x in segMap:
+        #    print('{} {}'.format(x,segMap[x]))
+        #return
+                  
+        m1,samples=loadData.readIVF(r,setup,qLambda=qLambda)
+        chi2=samples[0,:]
+        threshold=numpy.median(chi2)
+        ivfFit=fitData.getFit(samples,threshold)
+        t,dt=loadData.loadTime(r,setup)
+        centers=loadData.loadCenters(r,setup)
+        #save segmentation pixels
+        setup['nseg']=len(segMap.keys())
+        for x in segMap:
+            mArray=segMap[x]
+            qCenter=numpy.zeros(t.shape[0])
+            qData=numpy.zeros(t.shape[0])
+            s=0
+            #average over contributions for each segmentation included in map
+            kCenters=[]
+            for m in mArray:
+                #m is a tuple of classId and tac
+                kCenters.append(m[0])
+                qCenter+=centers[m[0]]
+                qData+=m[1]
+                s+=1
+            qCenter/=s
+            qData/=s
+            samplesC=fitData.fitCompartmentGlobal(ivfFit,t,qCenter,useJac=True,nfit=20)
+            samplesC1=fitData.fitCompartmentGlobal(ivfFit,t,qData,nfit=20,useJac=True)
+            
+            loadData.saveSamples(r,setup,samplesC,kCenters,'kmeansFit',iseg=x,ir=ir,qLambda=qLambda)
+            loadData.saveSamples(r,setup,samplesC1,[-1],'localFit',iseg=x,ir=ir,qLambda=qLambda)
+            loadData.saveTAC(r,setup,qCenter,'kmeansTAC',iseg=x,ir=ir,qLambda=qLambda)
+            loadData.saveTAC(r,setup,qData,'localTAC',iseg=x,ir=ir,qLambda=qLambda)
+
+            
+            
+    if stage=='plotCompartment':
+        ir=0
+        names=listRequiredFiles(stage,r,setup)
+        if fb:
+            getRequiredFiles(stage,r,setup,fb,names=names)
+        if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
+            return
+        
+        tag='plotCompartment'
+        seg=segmentation.getNRRDImage(r,setup,names)
+        loc=numpy.nonzero(seg)
+        segmentIds=list(set([int(x) for x in seg[loc]]))
+        nclass=setup['nclass'][0]
+        code=config.getCode(r,setup)
+        setup['nseg']=len(segmentIds)
+        t,dt=loadData.loadTime(r,setup)
+        for iseg in segmentIds:
+            m,samplesC=loadData.readSamples(r,setup,'kmeansFit',ir=ir,iseg=iseg,qLambda=qLambda)
+            m1,samplesC1=loadData.readSamples(r,setup,'localFit',ir=ir,iseg=iseg,qLambda=qLambda)
+            qCenter=loadData.readTAC(r,setup,'kmeansTAC',ir=ir,iseg=iseg,qLambda=qLambda)
+            qData=loadData.readTAC(r,setup,'localTAC',ir=ir,iseg=iseg,qLambda=qLambda)
+            chi2C=samplesC[0,:]
+            threshold=numpy.median(chi2C)
+            chi2C1=samplesC1[0,:]
+            threshold1=numpy.median(chi2C1)
+            fit=fitData.getFit(samplesC,threshold)
+            fit1=fitData.getFit(samplesC1,threshold1)
+            k1=fit.mu[0]
+            stdK1=fit.cov[0,0]
+            k11=fit1.mu[0]
+            stdK11=fit1.cov[0,0]
+            #update database with entries
+            row={x:r[x] for x in ['PatientId','visitCode']}
+            row['Date']=datetime.datetime.now().isoformat()
+            row['nclass']=nclass
+            row['option']='kmeansFit'
+            row['mean']=k1
+            row['std']=stdK1
+            row['regionId']=iseg
+            row['fitPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='realizations',iseg=iseg,qLambda=qLambda)
+            row['diffPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='diff',iseg=iseg,qLambda=qLambda)
+            row['qLambda']=qLambda
+            row1={x:row[x] for x in row}
+            row1['option']='localFit'
+            row1['mean']=k11
+            row1['std']=stdK11
+            if db:
+                db.modifyRows('insert',setup['project'],'lists','Summary',[row,row1])
+            
+                            
+            evalArray=[(samplesC,qCenter,'blue'),
+                       (samplesC1,qData,'orange')]
+            
+            file0=getData.getLocalPath(r,setup,row['fitPlot'])
+            file1=getData.getLocalPath(r,setup,row['diffPlot'])
+            plotData.plotSamples(t,evalArray,file0=file0,file1=file1)
+
+
+
+    if fb:
+        uploadCreatedFiles(stage,fb,r,setup)

Some files were not shown because too many files changed in this diff