import numpy
import config
import SimpleITK
import os
import getData
import matplotlib.pyplot

def guessPixelPosition15(sx=-1,sy=-1,sz=-1):
    """Guess initial voxel positions of the 16 segmentation points ('0'..'15').

    The points are placed around the centre (sx, sy, sz). A negative value
    for any coordinate selects the built-in default (12, 28 and 32
    respectively).

    Three display slices are defined: slice '0' holds points with y == sy,
    slice '1' points with z == sz, and slice '2' points with x == sx. The
    centre point '8' belongs to all of them.

    Returns:
        dict mapping slice id ('0', '1', '2') to the list of [x, y, z]
        positions of the points of that slice, in the order listed in
        `slices`.

    Debug output (the slice list, the result and the slice membership of
    every point) is printed.
    """
    #guess position of segments
    if sx<0:
        sx=12
    if sy<0:
        sy=28
    if sz<0:
        sz=32
    #distance of the outer points from the centre
    rz=4
    #extra y-shift applied to points '5'-'7' and '9'-'11'
    oz=0

    pts={
      '0':[sx-5,sy,sz],
      '1':[sx-2,sy,sz-rz],
      '2':[sx-2,sy,sz+rz-1],
      '3':[sx-1,sy-rz,sz],
      '4':[sx-1,sy+rz-1,sz],
      '5':[sx,sy-rz+oz,sz],
      '6':[sx,sy-0.3*rz+oz,sz-rz],
      '7':[sx,sy-0.3*rz+oz,sz+rz],
      '8':[sx,sy,sz],
      '9':[sx,sy+0.3*rz+oz,sz-rz],
      '10':[sx,sy+0.3*rz+oz,sz+rz],
      '11':[sx,sy+rz+oz,sz],
      '12':[sx+3,sy-rz,sz],
      '13':[sx+3,sy,sz-rz],
      '14':[sx+3,sy,sz+rz],
      '15':[sx+3,sy+rz,sz]}
    #point ids shown in each slice, in display order
    slices={
      '0':['8','0','1','13','2','14'],
      '1':['8','0','3','4','12','15'],
      '2':['8','11','9','6','5','7','10']}

    print(slices['0'])

    fp={x:[pts[q] for q in slices[x]] for x in slices}

    #for each point, the ';'-joined ids of the slices it appears in (debug only)
    sliceCode={p:';'.join(s for s in slices if p in slices[s]) for p in pts}
    print(fp)
    print(sliceCode)
    return fp

def guessPixelPosition4(sx=-1,sy=-1,sz=-1):
    """Guess initial voxel positions of a 5-point (centre plus cross) segmentation.

    A negative coordinate selects the default centre (32, 31, 31). Point '0'
    is the centre; '1'/'2' lie 4 voxels below/above it along z and '3'/'4'
    lie 4 voxels below/above it along y.

    Returns:
        list (one entry per point, ids '0'..'4' in order) of dicts with keys
        'regionId', 'x', 'y', 'z' and 'sliceId', where 'sliceId' is the
        ';'-joined list of display slices the point belongs to.

    Debug output (the first slice, the per-slice positions and the slice
    codes) is printed.
    """
    cx=32 if sx<0 else sx
    cy=31 if sy<0 else sy
    cz=31 if sz<0 else sz
    step=4

    points={
      '0':[cx,cy,cz],
      '1':[cx,cy,cz-step],
      '2':[cx,cy,cz+step],
      '3':[cx,cy-step,cz],
      '4':[cx,cy+step,cz]}
    #point ids displayed in each slice
    members={
      '0':['0','1','2'],
      '1':['0','3','4'],
      '2':['0','1','2','3','4']}

    print(members['0'])

    bySlice={sid:[points[pid] for pid in ids] for sid,ids in members.items()}
    codes={pid:';'.join(sid for sid,ids in members.items() if pid in ids) for pid in points}
    print(bySlice)
    print(codes)

    result=[]
    for pid,(px,py,pz) in points.items():
        result.append({'regionId':pid,'x':px,'y':py,'z':pz,'sliceId':codes[pid]})
    return result

def updateSegmentation(db,setup,r,pixels):
   """Insert or update one row of study.Segmentation for every entry of pixels.

   Each entry of pixels must be a dict with a 'regionId'; guessPixelPosition4
   produces entries of this form. Entries are modified in place:
   'PatientId' and 'visitCode' are copied from r, and 'SequenceNum' is set
   to r['SequenceNum']+0.01*int(regionId).

   A row with the same PatientId and SequenceNum is looked up. If one exists,
   it is updated with the entry's fields; otherwise a new row is inserted.

   Parameters:
      db     - database handle offering selectRows and modifyRows
      setup  - configuration dict; setup['project'] names the project
      r      - record dict with 'PatientId', 'visitCode' and 'SequenceNum'
      pixels - iterable of point dicts (see above)
   """
   for pixel in pixels:
      pixel['PatientId']=r['PatientId']
      pixel['visitCode']=r['visitCode']
      #give every region its own sequence number derived from the record's
      pixel['SequenceNum']=r['SequenceNum']+0.01*int(pixel['regionId'])
      lookup=[{'variable':field,'value':f'{pixel[field]}','oper':'eq'}
              for field in ('PatientId','SequenceNum')]
      found=db.selectRows(setup['project'],'study','Segmentation',lookup)['rows']
      if len(found)>0:
         mode,entry='update',found[0]
      else:
         mode,entry='insert',{}
      for key,value in pixel.items():
         entry[key]=value
      db.modifyRows(mode,setup['project'],'study','Segmentation',[entry])
   print('Done')
    
def getSegmentationFileName(r,setup,db=None):
   """Return the segmentation file name for record r.

   Parameters:
      r     - record dict. TXT mode passes it to config.getCode; NRRD mode
              reads r['PatientId'] and r['visitCode'].
      setup - configuration dict. Reads 'segmentationMode'; in NRRD mode also
              'targetUser', and 'network' when db has to be opened.
      db    - optional database handle. It is opened from setup['network']
              only when it is actually needed (NRRD mode) and none was given.

   Returns:
      TXT mode:  '<code>_Segmentation.txt'
      NRRD mode: the 'latestFile' field of the first matching Segmentation row
      other modes: None

   Raises:
      IndexError if no Segmentation row matches in NRRD mode.
   """
   mode=setup['segmentationMode']
   if mode=='TXT':
      return '{}_Segmentation.txt'.format(config.getCode(r,setup))
   if mode=='NRRD':
      if not db:
         #connect lazily: the TXT branch never touches the database
         db,_=getData.connectDB(setup['network'])
      copyFields=['PatientId','visitCode']
      qFilter=[{'variable':x,'value':r[x],'oper':'eq'} for x in copyFields]
      qFilter.append({'variable':'User','value':setup['targetUser'],'oper':'eq'})
      rows=getData.getSegmentation(db,setup,qFilter)
      if len(rows)==0:
         raise IndexError('No segmentation found for {}/{} (user {})'.format(
            r['PatientId'],r['visitCode'],setup['targetUser']))
      return rows[0]['latestFile']
   return None

def getURL(fb,r,setup,name):
   """Return the remote URL of file `name` in the record's 'Segmentations' folder.

   The record's base URL comes from fb.buildPathURL applied to
   setup['project'] and config.getPathList(r,setup).
   """
   pieces=[fb.buildPathURL(setup['project'],config.getPathList(r,setup))]
   pieces+=['Segmentations',name]
   return '/'.join(pieces)
 
def copyFromServer(fb,r,setup,names):
   """Download files from the record's remote 'Segmentations' folder.

   The local directory is created if missing. A file that already exists
   locally is skipped unless setup['forceReload'] is set (a missing key
   counts as False).

   Parameters:
      fb    - file-server handle offering buildPathURL and readFileToFile
      r     - record dict identifying the patient/visit
      setup - configuration dict
      names - iterable of file names to fetch
   """
   try:
      force=setup['forceReload']
   except KeyError:
      force=False

   getData.getLocalDir(r,setup,createIfMissing=True)
   base=fb.buildPathURL(setup['project'],config.getPathList(r,setup))
   for fileName in names:
      target=getData.getLocalPath(r,setup,fileName)
      haveLocal=os.path.isfile(target)
      if force or not haveLocal:
         fb.readFileToFile('/'.join([base,'Segmentations',fileName]),target)

def copyToServer(fb,r,setup,names):
   """Upload local files into the record's remote 'Segmentations' folder.

   Parameters:
      fb    - file-server handle offering buildPathURL and writeFileToFile
      r     - record dict identifying the patient/visit
      setup - configuration dict
      names - iterable of file names; each is read from
              getData.getLocalPath(r,setup,name)
   """
   remoteDir=fb.buildPathURL(setup['project'],config.getPathList(r,setup))
   for n in names:
      #getLocalPath lives in getData (as in copyFromServer); the bare name is undefined here
      localPath=getData.getLocalPath(r,setup,n)
      remotePath='/'.join([remoteDir,'Segmentations',n])
      fb.writeFileToFile(localPath,remotePath)


def writeSegmentation(db,fb,r,setup):
   """Write the segmentation points stored in the database to a text file and upload it.

   Only non-NRRD modes are handled; in NRRD mode a message is printed and
   nothing is written.

   The rows matching the record's patient and visit ids are stored in an
   array with one row per region. The array row index is int(regionId), and
   the columns are ordered (z, y, x). Region ids are presumably
   0..len(rows)-1; otherwise indexing fails (NOTE(review): confirm).

   The array is saved with numpy.savetxt to the local segmentation file and
   copied to the server via getData.copyToServer.
   """
   if setup['segmentationMode']=='NRRD':
      print('Failed to load segmentation')
      return

   fileName=getSegmentationFileName(r,setup,db)
   idFilter={'variable':'PatientId','value':config.getPatientId(r,setup),'oper':'eq'}
   visitFilter={'variable':'visitCode','value':config.getVisitId(r,setup),'oper':'eq'}
   rows=getData.getSegmentation(db,setup,[idFilter,visitFilter])
   if len(rows)==0:
      print(f'No segmentation rows to write for {fileName}')
      return

   v=numpy.zeros((len(rows),3))
   for qr in rows:
      region=int(qr['regionId'])
      #columns are stored in (z,y,x) order
      v[region,2]=float(qr['x'])
      v[region,1]=float(qr['y'])
      v[region,0]=float(qr['z'])

   #save and upload once, after every region has been filled in
   numpy.savetxt(getData.getLocalPath(r,setup,fileName),v)
   getData.copyToServer(fb,r,setup,[fileName])

def getNC(r,xsetup):
   """Return the number of segmentation points (NC) for record r.

   TXT mode: the number of rows in the local segmentation text file
   (computed by getNCTxt).
   NRRD mode: the value configured in xsetup['NC'].
   Any other mode: None.
   """
   mode=xsetup['segmentationMode']
   if mode=='TXT':
      #the count must be returned, not just computed
      return getNCTxt(r,xsetup)
   if mode=='NRRD':
      return xsetup['NC']
   return None

def getNCTxt(r,xsetup):
   """Return the number of rows in the record's local TXT segmentation file.

   The file name comes from getSegmentationFileName; the file is read with
   numpy.loadtxt from the local data directory.
   """
   localFile=getData.getLocalPath(r,xsetup,getSegmentationFileName(db=None,r=r,setup=xsetup))
   table=numpy.loadtxt(localFile)
   return table.shape[0]
   
def loadSegmentation(db,fb,r,setup):
   """Make sure the record's segmentation file is available locally and load it.

   If the local file is missing, it is copied from the server when it exists
   there. Otherwise writeSegmentation is called to build it from the database.

   Returns:
      numpy array loaded from the file in TXT mode; None in any other mode.
   """
   sName=getSegmentationFileName(r,setup,db)
   print(f'Looking for {sName}')
   fName=getData.getLocalPath(r,setup,sName)
   print(f'Local {fName}')
   if not os.path.isfile(fName):
      fURL=getData.getURL(fb,r,setup,sName)
      if fb.entryExists(fURL):
         getData.copyFromServer(fb,r,setup,[sName])
         if os.path.isfile(fName):
            print(f'Copied {fURL} to {fName}')
         else:
            print(f'Failed to load {fName} from {fURL}')
      else:
         #this creates local and global file
         writeSegmentation(db,fb,r,setup)

   if setup['segmentationMode']=='TXT':
      return numpy.loadtxt(fName)

   
def plotSegmentation(db,fb,r,setup,vmax=1000):
   """Overlay the stored segmentation points on the patient image and upload the plot.

   Segmentation rows matching r['PatientId'] and r['visitCode'] are loaded.
   The row with regionId 0 fixes the central slice indices (slc).

   A 3 x (2*nd+1) grid shows neighbouring slices around slc, one row of
   panels per image axis. The axis labels treat the image as nim[x, y, z].
   Each panel is cropped to voxels [20, 40) along both in-plane axes. The
   points whose ';'-separated sliceId contains '0', '1' or '2' are scattered
   on the central panel of rows 0, 1 and 2 respectively. The first panel is
   labelled with the patient id.

   The figure is saved as '<code>_segmentation.png' in the local directory,
   closed, and copied to the server.

   Parameters:
      vmax - upper limit of the grey-scale display window (the lower limit is 0).
   """
   pId=r['PatientId']
   visitCode=r['visitCode']
   copyFields=['PatientId','visitCode']
   qFilter=[{'variable':x,'value':r[x],'oper':'eq'} for x in copyFields]
   nim=getData.getPatientNIM(fb,r,setup)
   rows=getData.getSegmentation(db,setup,qFilter)

   if len(rows)==0:
      print(f'Not found for id={pId}/{visitCode}')
      return

   #slc: integer voxel position of region 0, the centre of the displayed slices
   slc=None
   #fp: slice id -> list of [x,y,z] points to overlay on that slice
   fp={}
   for q in rows:
      coords=[q['x'],q['y'],q['z']]
      #regionId may be stored as a string, so compare numerically
      if int(q['regionId'])==0:
         slc=[int(float(x)) for x in coords]
      for s in q['sliceId'].split(';'):
         fp.setdefault(s,[]).append([float(x) for x in coords])

   if slc is None:
      print(f'No regionId 0 point for id={pId}/{visitCode}')
      return

   #crop window (start, width) per panel axis
   cut0=20
   w0=20
   cut1=20
   w1=20
   cut2=20
   w2=20
   vmin=0
   #number of neighbouring slices shown on each side of the central one
   nd=3
   fig,ax=matplotlib.pyplot.subplots(3,2*nd+1,figsize=(20,12))
   for i in range(2*nd+1):
      ax[0,i].set_xlabel('z')
      ax[0,i].set_ylabel('x')
      ax[0,i].imshow(nim[cut2:cut2+w2,slc[1]-nd+i,cut0:cut0+w0],cmap='gray_r',vmax=vmax,vmin=vmin)
      ax[1,i].set_xlabel('x')
      ax[1,i].set_ylabel('y')
      ax[1,i].imshow(nim[cut2:cut2+w2,cut0:cut0+w0,slc[2]-nd+i].T,cmap='gray_r',vmax=vmax,vmin=vmin)
      ax[2,i].set_xlabel('z')
      ax[2,i].set_ylabel('y')
      ax[2,i].imshow(nim[slc[0]-nd+i,cut1:cut1+w1,cut1:cut1+w1],cmap='gray_r',vmax=vmax,vmin=vmin)
      if i==0:
         ax[0,i].text(2,2,pId,fontsize='large')
      if i==nd:
         #overlay points on the central slice; shift by the crop offsets
         pt=fp.get('0',[])
         ax[0,i].scatter([x[2]-cut0 for x in pt],[x[0]-cut2 for x in pt])
         pt=fp.get('1',[])
         ax[1,i].scatter([x[0]-cut2 for x in pt],[x[1]-cut0 for x in pt])
         pt=fp.get('2',[])
         ax[2,i].scatter([x[2]-cut1 for x in pt],[x[1]-cut1 for x in pt])

   name='{}_segmentation.png'.format(config.getCode(r,setup))
   fPath=getData.getLocalPath(r,setup,name)
   fig.savefig(fPath)
   #pyplot keeps every figure alive until closed; release it after saving
   matplotlib.pyplot.close(fig)
   getData.copyToServer(fb,r,setup,[name])


def getNRRDImage(r,setup,names=None):
   """Read the record's segmentation image with SimpleITK and return it as a numpy array.

   The file name is names['segmentation'][0] when names is given (and
   truthy); otherwise it comes from getSegmentationFileName(r,setup). The
   file is read from the local data directory.
   """
   segName=names['segmentation'][0] if names else getSegmentationFileName(r,setup)
   image=SimpleITK.ReadImage(getData.getLocalPath(r,setup,segName))
   return SimpleITK.GetArrayFromImage(image)