#analysis functions for cardiacSPECT dynamic data analysis
#see clusterAnalysis.ipynb for details on usage

import os
import getData
import config
import subprocess
import SimpleITK
import numpy
import re
import json
import segmentation

def calculateCenters(db,setup):
    """Generate cluster centers (through MATLAB) for every patient selected by setup.

    Each patient row returned by getData.getPatients is handed to
    calculateRowCenters.

    A comparable shell invocation, for reference:
    $MATLAB -sd $MATLABSCRIPTDIR -batch "patientID='${PATIENTID}'; nclass='${NCLASS}'; nRealizations='${NR}'; generateCenters" >& ~/logs/dynamicSPECT.log
    """
    for patientRow in getData.getPatients(db,setup):
        calculateRowCenters(patientRow,setup)

def calculateRowCenters(r,setup):
    """Run the MATLAB ``generateCenters`` script once per class count in setup['nclass'].

    The MATLAB call receives the temp directory, the patient code, the class
    count and setup['nr'] (number of realizations). The command list and the
    captured stdout of each run are printed.
    """
    baseCmd=config.cmdMatlab()
    tempDir=config.getTempDir(setup)
    code=config.getCode(r,setup)
    nr=setup['nr']

    for nc in setup['nclass']:
        matlabCall='-r "path=\'{}\'; patientID=\'{}\'; nclass=\'{}\'; nRealizations=\'{}\'; generateCenters"'.format(tempDir,code,nc,nr)
        #fresh copy of the base command for every run
        runCmds=list(baseCmd)+[matlabCall]
        print('centers: {}/{},{}'.format(code,nc,nr))
        print(runCmds)
        #capture MATLAB output and echo it
        result=subprocess.run(runCmds, check=True, stdout=subprocess.PIPE)
        print(result.stdout)

def doAnalysis(db,setup,mode):
    """Run the cluster fit (doAnalysisRow) for every patient selected by setup.

    Parameters
    ----------
    db : database handle passed to getData.getPatients
    setup : dict-like analysis setup
    mode : str
        'global' or 'IVF'; forwarded to doAnalysisRow.
    """
    rows=getData.getPatients(db,setup)

    for r in rows:
        #forward this function's own setup (the name xsetup does not exist here)
        doAnalysisRow(r,setup,mode)

def doAnalysisRow(r,xsetup,mode):
    """Run the MATLAB cluster-fit script for every nclass/realization of one patient.

    Mode 'global' runs ``analyze`` (IVF inferred together with the k1 fits).
    Mode 'IVF' runs ``analyzeIVF`` (IVF taken from previous cluster fits).
    Any other mode prints a message and returns without running anything.
    A fit is skipped when its final fit-parameter file already exists in the
    temp directory.

    Parameters
    ----------
    r : patient row
    xsetup : dict-like setup; reads 'nr' and 'nclass'
    mode : str
        'global' or 'IVF'.
    """
    cmds=config.cmdMatlab()

    tempDir=config.getTempDir(xsetup)

    #mode -> (MATLAB script, tag used in the fit-parameter file name)
    scripts={'global':('analyze',''),'IVF':('analyzeIVF','IVF_')}
    if mode not in scripts:
        print('Mode can be one of (global,IVF))')
        return
    mScript,analysisType=scripts[mode]
    print('Setting runScript to {}'.format(mScript))

    code=config.getCode(r,xsetup)
    nr=xsetup['nr']
    nclass=xsetup['nclass']

    for nc in nclass:
        for j in numpy.arange(nr):
            print('{} [{} {}/{}]'.format(code,nc,j+1,nr))

            #avoid repetition, only do for new files
            #this is a duplicate of path generated in fitCenters.m
            fName=os.path.join(tempDir,code,config.getFitParFinalName(code,nc,j,analysisType))
            if os.path.isfile(fName):
                print('Skipping; {} available.'.format(fName))
                continue
            runCmds=list(cmds)
            #MATLAB realizations are numbered 1..nr
            runCmds.append('-r "path=\'{}\'; patientID=\'{}\'; nclass=\'{}\'; realizationId=\'{}\'; {}"'.format(tempDir,code,nc,j+1,mScript))

            subprocess.run(runCmds,check=True, stdout=subprocess.PIPE)
   

 

def doPixelAnalysis(db,setup,sigma2,mode='IVF'):
    """Run the pixel-based fit (doPixelAnalysisRow) for every patient selected by setup.

    NOTE(review): ``sigma2`` is not used. The sigma2 values that are actually
    fitted are read from ``setup['sigma2']`` inside doPixelAnalysisRow. The
    parameter is kept so existing callers keep working.

    Parameters
    ----------
    db : database handle
    setup : dict-like analysis setup
    sigma2 : ignored (see note above)
    mode : str
        'global' or 'IVF' (default); forwarded to doPixelAnalysisRow.
    """
    for r in getData.getPatients(db,setup):
        doPixelAnalysisRow(db,r,setup,mode)

def doPixelAnalysisRow(db,r,setup, mode='IVF'):
    """Run the MATLAB pixel-fit script for each sigma2 value of one patient.

    In 'global' mode (script ``analyzePixel``), IVF parameters are inferred
    together with the fits. This is essentially a repeat of the cluster
    analysis, except that the classes are time-response curves of pixels in
    the sigma2-defined neighborhood of target pixels.

    In 'IVF' mode (script ``analyzePixelIVF``), the input-function parameters
    are taken from a previous cluster fit. NOTE(review): presumably
    doAnalysis with mode='global'; confirm. The rest is as in global mode.

    Any other mode prints a message and returns. A fit is skipped when its
    final fit-parameter file already exists locally.

    Parameters
    ----------
    db : database handle, used to locate the segmentation file
    r : patient row
    setup : dict-like setup; reads 'sigma2' (iterable of values)
    mode : str
        'global' or 'IVF' (default).
    """
    cmds=config.cmdMatlab()

    tempDir=config.getTempDir(setup)

    #mode -> MATLAB script
    scripts={'global':'analyzePixel','IVF':'analyzePixelIVF'}
    if mode not in scripts:
        print('Mode can be one of (global,IVF))')
        return
    mScript=scripts[mode]
    print('Setting runScript to {}'.format(mScript))

    nc=segmentation.getNC(r,setup)
    sigma2=setup['sigma2']
    code=config.getCode(r,setup)

    for s2 in sigma2:
        f=config.getPixelFitParFinalName(code,nc,s2,mode)
        fName=getData.getLocalPath(r,setup,f)
        sFile=segmentation.getSegmentationFileName(r,setup,db=db)
        segmFile=getData.getLocalPath(r,setup,sFile)
        if os.path.isfile(fName):
            print('Skipping; {} available.'.format(fName))
            continue

        runCmds=list(cmds)
        runCmds.append('-r "path=\'{}\'; patientID=\'{}\'; sigma2=\'{}\'; segmFile=\'{}\'; {}"'.format(tempDir,code,s2,segmFile,mScript))
        print(f'Running with {s2}')

        print(subprocess.run(runCmds, check=True, stdout=subprocess.PIPE).stdout)
        print('Done')
 
            
def getIWeights(r,setup,nclass,realizationId,ic):
    """Load the per-pixel weight image of one cluster center as a numpy array.

    The file is read from the patient's local directory and named
    ``{code}_{nclass}_{realizationId+1}_center{ic+1}_centerWeigth.nrrd``.
    Both realizationId and ic are 0-based here and shifted to 1-based in the
    file name. The 'Weigth' spelling is part of the on-disk name and must stay.
    """
    localDir=config.getLocalDir(r,setup)
    patientCode=config.getCode(r,setup)
    weightFile=os.path.join(
        localDir,
        '{}_{}_{}_center{}_centerWeigth.nrrd'.format(patientCode,nclass,realizationId+1,ic+1))
    return SimpleITK.GetArrayFromImage(SimpleITK.ReadImage(weightFile))

def getGaussianWeight(nU,pt,sigma2,na):
    """Return the Gaussian-weighted average of ``nU`` around point ``pt``.

    The neighborhood spans ``na`` voxels along each of the three axes and is
    anchored at the coordinates of ``pt`` truncated with ``int()``. Each voxel
    is weighted by ``exp(-0.5*d2/sigma2)``, where ``d2`` is its squared
    distance to ``pt`` itself. The weighted sum is divided by the total weight.
    Neighborhood voxels that fall outside ``nU`` are left out, and the weights
    are renormalized over the remaining voxels.

    Parameters
    ----------
    nU : array-like
        3D volume, indexed as ``nU[i,j,k]`` with axes in the same order as the
        coordinates of ``pt``.
        NOTE(review): SimpleITK.GetArrayFromImage (used by getIWeights) returns
        arrays indexed (z,y,x), while getWeights passes ``pt`` as [x,y,z].
        Confirm the intended axis order.
    pt : sequence of 3 floats
        Target point in index coordinates.
    sigma2 : float
        Variance of the Gaussian kernel, in squared index units.
    na : int
        Neighborhood size along each axis.

    Returns
    -------
    numpy.float64
        The weighted average.

    Raises
    ------
    IndexError
        If no voxel of the neighborhood lies inside ``nU``.
    """
    nU=numpy.asarray(nU)
    pt=numpy.asarray(pt,dtype=float)
    #anchor voxel: int() truncates toward zero, it does not round
    idx=[int(x) for x in pt]
    #half of the neighborhood (a half-integer when na is even)
    na2=0.5*(na-1)

    axes=[]
    for c,n in zip(idx,nU.shape):
        a=numpy.array([int(c+(i-na2)) for i in range(na)],dtype=int)
        #drop out-of-range indices: negative ones would otherwise silently
        #wrap around to the far side of the volume
        axes.append(a[(a>=0)&(a<n)])
    if any(a.size==0 for a in axes):
        raise IndexError('neighborhood of {} lies outside array of shape {}'.format(pt.tolist(),nU.shape))

    #open mesh over the neighborhood; the expressions below broadcast over it
    grid=numpy.ix_(*axes)
    dist2=sum((g-p)**2 for g,p in zip(grid,pt))
    fw=numpy.exp(-0.5*dist2/sigma2)
    return numpy.sum(fw*nU[grid])/numpy.sum(fw)
            


#gets weights by class for a particular realization and sigma2 averaging
def getWeights(db,r,setup,nclass,realizationId,sigma2,na):
    """Return Gaussian-smoothed class weights at each segmented region point.

    Region points (regionId, x, y, z) come from the patient's segmentation rows
    in the database. For every class ``c`` in 0..nclass-1, the class weight
    image (getIWeights) is sampled around each point with getGaussianWeight.

    Parameters
    ----------
    db : database handle
    r, setup : identify the patient
    nclass : int
        Number of classes.
    realizationId : int
        0-based realization index.
    sigma2 : float
        Variance of the Gaussian used for smoothing.
    na : int
        Size of the neighborhood.

    Returns
    -------
    dict
        ``{regionId: numpy array of length nclass}``; entry ``c`` is the
        smoothed weight of class ``c`` at that region's point.
    """
    idFilter=config.getIdFilter(r,setup)
    visitFilter=config.getVisitFilter(r,setup)
    rows=getData.getSegmentation(db,setup,[idFilter,visitFilter])
    #region point coordinates, in the [x,y,z] order stored with the segmentation
    pts={row['regionId']:[float(row[key]) for key in ('x','y','z')] for row in rows}

    w={region:numpy.zeros(nclass) for region in pts}

    for c in numpy.arange(nclass):
        #one weight volume per class, read once and sampled at every region
        nU=getIWeights(r,setup,nclass,realizationId,c)
        for region,pt in pts.items():
            w[region][c]=getGaussianWeight(nU,pt,sigma2,na)

    return w


#gets pixel fitPar for a particular sigma2 value (s2)
def getPixelFitPar(fb,r,setup,nc,s2,mode):
    """Copy the pixel fit-parameter file from the server and load it.

    The file name comes from config.getPixelFitParFinalName(code,nc,s2,mode).
    It is copied with getData.copyFromServer and parsed as tab-separated
    values. Both the remote name and the local path are printed.
    """
    patientCode=config.getCode(r,setup)
    remoteName=config.getPixelFitParFinalName(patientCode,nc,s2,mode)
    print('getPixelFitPar {}'.format(remoteName))
    getData.copyFromServer(fb,r,setup,[remoteName])
    localName=getData.getLocalPath(r,setup,remoteName)
    print('getPixelFitPar {}'.format(localName))
    return numpy.genfromtxt(localName,delimiter='\t')

def getFitPar(fb,r,setup,nclass,realizationId,mode):
    """Copy the cluster fit-parameter file from the server and load it.

    The file name comes from config.getFitParFinalName(code,nclass,
    realizationId,mode). The file is parsed as tab-separated values.
    NOTE(review): realizationId appears to be 0-based (doAnalysisRow builds
    the same name with j in 0..nr-1); confirm.
    """
    remoteName=config.getFitParFinalName(config.getCode(r,setup),nclass,realizationId,mode)
    getData.copyFromServer(fb,r,setup,[remoteName])
    return numpy.genfromtxt(getData.getLocalPath(r,setup,remoteName),delimiter='\t')


def getFitParBackup(r,setup,nclass,realizationId,mode=''):
    """Load fit parameters directly from the patient's local directory.

    Nothing is copied from the server. The file name is
    ``{code}_{nclass}_{rCode}_{mode}_fitParFinal.txt``, or without the mode
    part when mode is ''. rCode is realizationId+1 (MATLAB 1..N numbering),
    except in 'PixelIVF' mode. In that mode the 4th argument is a sigma value,
    written with two decimals and '.' replaced by 'p' (e.g. 0.5 -> '0p50').

    Parameter layout: fitGoodnes A tau alpha delay [k1 BVF k2 delay]xNcenters

    Returns
    -------
    numpy array of the tab-separated file contents, or [] (after printing a
    message) if mode is not one of '', 'IVF', 'Pixel', 'PixelIVF'.
    """
    allowedModes=['','IVF','Pixel','PixelIVF']
    if mode not in allowedModes:
        print('Mode should be one of {}'.format(allowedModes))
        return []

    if mode=='PixelIVF':
        #4th parameter is sigma, not realizationId
        rCode='{:.2f}'.format(realizationId).replace('.','p')
    else:
        #add one to match matlab 1..N array indexing
        rCode='{}'.format(realizationId+1)

    if len(mode)>0:
        mode=mode+'_'

    code=config.getCode(r,setup)
    fname='{}_{}_{}_{}fitParFinal.txt'.format(code,nclass,rCode,mode)
    locDir=config.getLocalDir(r,setup)
    uFile=os.path.join(locDir,fname)
    return numpy.genfromtxt(uFile,delimiter='\t')

def getK1(fitPar,iclass):
    """Return k1 of class ``iclass`` (0-based) from a fit-parameter vector.

    Layout: fitGoodnes A tau alpha delay [k1 BVF k2 delay]xNcenters, so five
    leading values are followed by four values per class, k1 first.
    """
    nLeading=5
    perClass=4
    return fitPar[nLeading+perClass*iclass]

def calculateK1(w,fitPar):
    """Compute k1 per region from class weights and fit parameters.

    Parameters
    ----------
    w : dict
        ``{region: sequence of per-class weights}``.
    fitPar : fit-parameter vector (see getK1 for the layout).

    Returns
    -------
    dict
        ``{region: {'M': k1 of the class with the largest weight,
                    'W': sum over classes of weight*k1}}``.
        'W' is a weighted average only if the weights sum to 1; TODO confirm
        for getWeights output.
    """
    result={}
    for region,weights in w.items():
        best=numpy.argmax(weights)
        weighted=sum(weights[c]*getK1(fitPar,c) for c in numpy.arange(len(weights)))
        result[region]={'M':getK1(fitPar,best),'W':weighted}
    return result

def getPatientValuesByNclass(db,r,setup,nclass,nrealizations,sigma2,na):
    """Collect per-region k1 values over all realizations of one clustering.

    Parameters
    ----------
    db : database handle; segmentations are read from it
    r, setup : identify the patient
    nclass : int
        Number of classes.
    nrealizations : int
        Number of realizations (0-based ids 0..nrealizations-1).
    sigma2 : float
        Variance used to combine output from adjacent pixels.
    na : int
        Neighborhood size where the smoothing/combination is done.

    Returns
    -------
    dict
        ``{region: {'M': [k1 per realization], 'W': [k1 per realization]}}``;
        empty when nrealizations is 0.
    """
    k1={}
    for rId in numpy.arange(nrealizations):
        w=getWeights(db,r,setup,nclass,rId,sigma2,na)
        #getFitPar needs a server handle (fb) that this function does not
        #receive. getFitParBackup takes exactly (r,setup,nclass,realizationId,
        #mode) and reads the IVF fit parameters from the local directory.
        fitPar=getFitParBackup(r,setup,nclass,rId,'IVF')
        qk1=calculateK1(w,fitPar)
        for region in w:
            perRegion=k1.setdefault(region,{})
            for avgType,value in qk1[region].items():
                perRegion.setdefault(avgType,[]).append(value)
        print('[{}] {}/{}'.format(nclass,rId+1,nrealizations))
    return k1

def getSummaryPatientValuesByNclass(db,r,setup,nclass,nrealizations,sigma2,na):
    """Summarize per-region k1 over realizations (mean, std, median).

    Arguments are those of getPatientValuesByNclass.

    Returns
    -------
    dict
        ``{option: {region: {'mean':..., 'std':..., 'median':...}}}`` for
        option 'M' (k1 of the max-weight class) and 'W' (weighted k1).
    """
    k1=getPatientValuesByNclass(db,r,setup,nclass,nrealizations,sigma2,na)
    summaryK1={}
    for option in ['M','W']:
        byRegion={}
        for region,values in k1.items():
            series=values[option]
            byRegion[region]={
                'mean':numpy.mean(series),
                'std':numpy.std(series),
                'median':numpy.median(series)}
        summaryK1[option]=byRegion
    return summaryK1

def fullSummary(db,setup,classes,nr,sigma2,na):
    """Build k1 summaries for every patient and every class count.

    Returns
    -------
    dict
        ``{patientCode: {nclass: getSummaryPatientValuesByNclass(...)}}``.
    """
    summary={}
    for patientRow in getData.getPatients(db,setup):
        patientCode=config.getCode(patientRow,setup)
        summary[patientCode]={
            c:getSummaryPatientValuesByNclass(db,patientRow,setup,c,nr,sigma2,na)
            for c in classes}
    return summary
          
def storeSummary(db,setup,summary,sigma2,na):
    """Insert or update k1 summary rows in the summary table.

    ``summary`` is nested as
    ``{patientCode: {nclass: {option: {regionId: {statName: value}}}}}``
    (the shape built by fullSummary). For each leaf, an existing row matching
    patient id, visit, nclass, option and regionId is updated. Otherwise a new
    row is built from a copy of the decoded patient row. The statistics plus
    ``sigma2`` and ``na`` are written into the row.
    """
    for rCode,byClass in summary.items():
        r=config.decode(rCode,setup)
        idFilter=config.getIdFilter(r,setup)
        visitFilter=config.getVisitFilter(r,setup)
        for c,byOption in byClass.items():
            cFilter={'variable':'nclass','value':str(c),'oper':'eq'}
            for t,byRegion in byOption.items():
                tFilter={'variable':'option','value':t,'oper':'eq'}
                for rId,stats in byRegion.items():
                    rFilter={'variable':'regionId','value':str(rId),'oper':'eq'}
                    found=getData.getSummary(db,setup,[idFilter,visitFilter,cFilter,tFilter,rFilter])
                    if len(found)==0:
                        qrow={key:r[key] for key in r}
                        qrow['nclass']=c
                        qrow['option']=t
                        qrow['regionId']=rId
                        #'M' rows get a small extra offset so they differ from
                        #the 'W' row with the same nclass and region
                        qrow['SequenceNum']=100+config.getTargetSeqNum(r,setup)+c+0.001*rId
                        if t=='M':
                            qrow['SequenceNum']+=0.0001
                        mode='insert'
                    else:
                        qrow=found[0]
                        mode='update'
                    for statName,value in stats.items():
                        qrow[statName]=value
                    qrow['sigma2']=sigma2
                    qrow['na']=na
                    getData.updateSummary(db,setup,mode,[qrow])
                    
def summaryPixelIVF(db,fb,setup):
    """Collect pixel-IVF k1 values for every selected patient and sigma2.

    This is for the second type of analysis (pixel-based regions). Patients
    are selected with config.getFilter(setup). The sigma2 values come from
    setup['sigma2'].

    Returns
    -------
    dict
        ``{patientCode: {sigma2: getPixelIVF(...)}}``.
    """
    rows=getData.getPatients(db,setup,config.getFilter(setup))
    sigma2=setup['sigma2']
    result={}
    for patientRow in rows:
        patientCode=config.getCode(patientRow,setup)
        result[patientCode]={s2:getPixelIVF(db,fb,patientRow,setup,s2) for s2 in sigma2}
    return result

    
def storeIVF(db,setup,summary,na=7):
    """Insert or update pixel-IVF k1 rows (option 'D') in the summary table.

    ``summary`` is nested as ``{patientCode: {sigma2: {regionId: k1}}}`` (the
    shape built by summaryPixelIVF). For each entry, an existing row matching
    patient id, visit, sigma2, regionId and option 'D' is updated. Otherwise a
    new row is built from a copy of the decoded patient row.

    Parameters
    ----------
    db : database handle
    setup : dict-like setup
    summary : dict, see above
    na : int, optional
        Value written to the ``na`` column (default 7). NOTE(review):
        presumably the neighborhood size of the MATLAB pixel fit; confirm.
    """
    for rCode,bySigma in summary.items():
        r=config.decode(rCode,setup)
        idFilter=config.getIdFilter(r,setup)
        visitFilter=config.getVisitFilter(r,setup)
        for s2,byRegion in bySigma.items():
            sigmaFilter={'variable':'sigma2','value':str(s2),'oper':'eq'}
            #number of regions; stored in the nclass column
            nr=len(byRegion)
            for rId,k1 in byRegion.items():
                rFilter={'variable':'regionId','value':str(rId),'oper':'eq'}
                typeFilter={'variable':'option','value':'D','oper':'eq'}
                rows=getData.getSummary(db,setup,[idFilter,visitFilter,sigmaFilter,rFilter,typeFilter])
                if len(rows)>0:
                    qrow=rows[0]
                    mode='update'
                else:
                    qrow={qr:r[qr] for qr in r}
                    qrow['sigma2']=s2
                    qrow['regionId']=rId
                    seqNum=config.getTargetSeqNum(r,setup)
                    qrow['SequenceNum']=140+seqNum+0.01*rId+0.001*s2
                    mode='insert'
                qrow['mean']=k1
                qrow['na']=na
                qrow['nclass']=nr
                qrow['option']='D'
                getData.updateSummary(db,setup,mode,[qrow])
                
def getPixelIVF(db,fb,r,setup,sigma2):
    """Return k1 per segmented region from the pixel-based IVF fit.

    The number of regions comes from segmentation.getNC(r,setup). The fit
    parameters are fetched with getPixelFitPar(..., 'IVF') for this sigma2.
    Both the parameters and the resulting k1 dict are printed.

    Parameters
    ----------
    db : unused; kept for interface compatibility
    fb : server handle passed to getPixelFitPar
    r, setup : identify the patient
    sigma2 : float
        The sigma2 value whose pixel fit is read.

    Returns
    -------
    dict
        ``{regionId: k1}`` with regionId in 1..nclassIVF.
    """
    nclassIVF=segmentation.getNC(r,setup)
    fitPar=getPixelFitPar(fb,r,setup,nclassIVF,sigma2,'IVF')
    print(fitPar)
    #region ids are 1-based; getK1 expects a 0-based class index
    k1={regionId:getK1(fitPar,regionId-1) for regionId in range(1,nclassIVF+1)}
    print(k1)
    return k1