import os
import sys
import json
import pathlib
import slicer
import subprocess
import shutil

# For every row of xconfig['queryName'], import the DICOM series and convert them to NRRD.

def main(configFile):
   """Convert the DICOM series of every selected LabKey row to NRRD.

   Reads the JSON job configuration ``configFile`` (called ``xconfig``) and
   the machine setup ``~/.labkey/setup.json``. It then selects rows from
   ``xconfig['queryName']`` using ``config.getFilter(xconfig)``. For each row
   it does the following:

   * downloads the row's DICOM series from Orthanc (loadPatient);
   * builds Slicer nodes (addCT, addFrames, addDummyInputFunction);
   * stores every scalar volume, double-array and table node and uploads it
     to LabKey (storeNode), then clears the scene;
   * writes the row back, with ``ct``/``spect`` filled in by addCT and
     addFrames, and updates its status via ``getData.updateStatus``.

   A row is skipped when its CT ``.nrrd`` already exists remotely, unless
   ``xconfig['recalculate']`` is present and truthy.

   Args:
      configFile: path to the JSON job configuration.
   """
   print('Running with {}'.format(configFile))
   with open(configFile) as f:
      xconfig=json.load(f)

   fsetup=os.path.join(os.path.expanduser('~'),'.labkey','setup.json')
   with open(fsetup) as f:
      setup=json.load(f)

   # nixWrapper (location taken from setup.json) makes the LabKey and
   # Orthanc interface libraries importable.
   sys.path.append(setup['paths']['nixWrapper'])
   import nixWrapper

   nixWrapper.loadLibrary('labkeyInterface')
   import labkeyInterface
   import labkeyDatabaseBrowser
   import labkeyFileBrowser

   nixWrapper.loadLibrary('orthancInterface')
   import orthancInterface
   import orthancDatabaseBrowser
   import orthancFileBrowser

   # Project helpers (config, getData, ...) live in ../pythonScripts relative
   # to this file; resolve from __file__ so the working directory is irrelevant.
   pwd=os.path.dirname(os.path.abspath(__file__))
   pwdUp=os.path.dirname(pwd)
   pythonScripts=os.path.join(pwdUp,'pythonScripts')
   sys.path.append(pythonScripts)
   import config
   import getData
   print('Config loaded')

   # One network file (~/.labkey/<xconfig['network']>) initialises both the
   # LabKey and the Orthanc connection.
   net=labkeyInterface.labkeyInterface()
   fnet=os.path.join(os.path.expanduser('~'),'.labkey',xconfig['network'])
   net.init(fnet)
   net.getCSRF()
   fb=labkeyFileBrowser.labkeyFileBrowser(net)
   db=labkeyDatabaseBrowser.labkeyDB(net)

   onet=orthancInterface.orthancInterface()
   onet.init(fnet)
   ofb=orthancFileBrowser.orthancFileBrowser(onet)
   # odb is constructed as before but not used in this function.
   odb=orthancDatabaseBrowser.orthancDB(onet)

   qFilter=config.getFilter(xconfig)
   ds=db.selectRows(xconfig['project'],xconfig['schemaName'],xconfig['queryName'],qFilter)

   try:
      rows=ds['rows']
   except KeyError:
      print('No rows returned')
      return

   # Scene node classes that are uploaded, with the header printed per group.
   storedClasses=(('vtkMRMLScalarVolumeNode','Nodes'),
      ('vtkMRMLDoubleArrayNode','Nodes (double array)'),
      ('vtkMRMLTableNode','Nodes (table)'))

   for r in rows:
      print("Loading {}/{}".format(config.getPatientId(r,xconfig),config.getVisitId(r,xconfig)))
      rPath=fb.formatPathURL(xconfig['project'],config.getOutputDir(r,xconfig))
      rPath+='/'+config.getNodeName(r,xconfig,'CT')+'.nrrd'
      # The remote CT .nrrd marks a row as already converted.
      entryDone=fb.entryExists(rPath)
      if entryDone:
         if not xconfig.get('recalculate',False):
            print('Entry done')
            continue
         print('Forced recalculation')

      print('{} available:{}'.format(config.getNodeName(r,xconfig,'CT'),entryDone))
      # load the patient's series, then convert them to scene nodes
      patient=loadPatient(ofb,r,xconfig)
      addCT(r,patient,xconfig)
      addFrames(r,patient,xconfig)
      addDummyInputFunction(r,patient,xconfig)

      for className,header in storedClasses:
         nodes=slicer.mrmlScene.GetNodesByClass(className)
         print(header)
         for node in nodes:
            print('\t{}'.format(node.GetName()))
            storeNode(fb,r,xconfig,node)

      clearNodes(r,xconfig)
      # addCT and addFrames fill r['ct'] and r['spect']
      db.modifyRows('update',xconfig['project'],xconfig['schemaName'],xconfig['queryName'],[r])
      getData.updateStatus(db,r,setup,'convertToNRRD')

      
def clearNodes(row,xconfig):
   """Remove every scalar volume, double-array and table node from the scene.

   Args:
      row: unused; kept for interface compatibility.
      xconfig: unused; kept for interface compatibility.
   """
   doomed=[]
   # Gather all matching nodes first, then remove them.
   for className in ('vtkMRMLScalarVolumeNode','vtkMRMLDoubleArrayNode','vtkMRMLTableNode'):
      doomed.extend(slicer.mrmlScene.GetNodesByClass(className))
   for node in doomed:
      slicer.mrmlScene.RemoveNode(node)

def loadPatient(ofb,r,xconfig):
   """Download and parse the DICOM series of row r.

   Three Orthanc series are fetched with downloadAndUnzip and parsed with
   parseDicom. They use the codes 'nmMaster' (its frame_start/frame_stop are
   printed), 'nmCorrData' (NM frames) and 'ct'. The unzip directory is
   cleared after each series.

   Returns:
      {'CT': ct, 'NM': nm}. ct has keys data, origin, pixel_size and
      orientation. nm additionally has time and duration. Both orientations
      are passed through vtkInterface.completeOrientation.
   """
   # main() already adds the absolute pythonScripts path. This cwd-relative
   # entry is kept for other callers, but it is added only once instead of
   # growing sys.path on every processed row.
   relScripts='../pythonScripts'
   if relScripts not in sys.path:
      sys.path.append(relScripts)
   import parseDicom
   import vtkInterface

   pd=parseDicom.parseDicom()

   # NOTE(review): downloadAndUnzip returns "" when unzip fails, and the
   # readers below are then called with "". Confirm how parseDicom handles that.
   masterPath=downloadAndUnzip(ofb,r,"nmMaster",xconfig)
   pd.readMasterDirectory(masterPath,False)
   print('Time [{} .. {}]'.format(pd.frame_start,pd.frame_stop))
   clearUnzipDir(r,xconfig)

   nmPath=downloadAndUnzip(ofb,r,"nmCorrData",xconfig)
   (frame_data,frame_time,frame_duration,frame_origin,
      frame_pixel_size,frame_orientation)=pd.readNMDirectory(nmPath,False)
   print('Frame time {}'.format(frame_time))
   clearUnzipDir(r,xconfig)

   ctPath=downloadAndUnzip(ofb,r,"ct",xconfig)
   ct_data,ct_origin,ct_pixel_size,ct_orientation=pd.readCTDirectory(ctPath,False)
   print('CT pixel {}'.format(ct_pixel_size))
   clearUnzipDir(r,xconfig)

   ct_orientation=vtkInterface.completeOrientation(ct_orientation)
   frame_orientation=vtkInterface.completeOrientation(frame_orientation)

   ct={'data':ct_data,'origin':ct_origin,'pixel_size':ct_pixel_size,
      'orientation':ct_orientation}
   nm={'data':frame_data,'time':frame_time,'duration':frame_duration,
      'origin':frame_origin,'pixel_size':frame_pixel_size,
      'orientation':frame_orientation}

   # Printed before returning so it is actually reached.
   print('Done')
   return {'CT':ct,'NM':nm}

def clearUnzipDir(r,xconfig):
   """Make config.getLocalDir(r,xconfig) an existing, empty directory.

   Any previous content of the directory is deleted.

   Returns:
      The directory path.
   """
   import config

   zipDir=config.getLocalDir(r,xconfig)
   # Remove and always recreate, so the directory exists afterwards.
   # downloadAndUnzip extracts into it, and storeNode saves files into the
   # same config.getLocalDir path after loadPatient's final clear.
   if os.path.isdir(zipDir):
      shutil.rmtree(zipDir)
   os.makedirs(zipDir)
   return zipDir

def downloadAndUnzip(ofb,r,code,xconfig):
   """Fetch an Orthanc series as a zip and extract it into the unzip directory.

   The series id is read from r[code+'OrthancId']. The zip is saved as
   ~/<xconfig['tempDir']>/<config.getCode(r,xconfig)>_<code>.zip and then
   extracted into clearUnzipDir(r,xconfig) with stored paths junked (-j).

   Returns:
      The extraction directory, or "" if unzip exits with an error.
   """
   import config

   # tempDir is '/'-separated and relative to the home directory.
   tempDir=os.path.join(os.path.expanduser('~'),*xconfig['tempDir'].split('/'))
   if not os.path.isdir(tempDir):
      print('Creating {}'.format(tempDir))
      os.makedirs(tempDir,exist_ok=True)

   orthancId=r[code+'OrthancId']
   fileCode='{}_{}'.format(config.getCode(r,xconfig),code)
   zipFile=os.path.join(tempDir,fileCode+'.zip')
   ofb.getZip('series',orthancId,zipFile)
   zipDir=clearUnzipDir(r,xconfig)

   # The file listing unzip prints on stdout is not needed; discard it.
   try:
      subprocess.run(["unzip","-d",zipDir,"-xj",zipFile],check=True,
         stdout=subprocess.DEVNULL)
   except subprocess.CalledProcessError:
      print("unzip failed for {}".format(zipFile))
      return ""

   return zipDir

def addCT(r,patient,xconfig):
   """Add the CT volume of patient to the scene and record its file name.

   The CT array is converted with vtkInterface.numpyToVTK and added as a
   scalar volume named config.getNodeName(r,xconfig,'CT') (dataType 0).
   Sets r['ct'] to '<node name>.nrrd'.
   """
   import config
   import vtkInterface
   ctInfo=patient['CT']
   ctArray=ctInfo['data']
   image=vtkInterface.numpyToVTK(ctArray,ctArray.shape)
   name=config.getNodeName(r,xconfig,'CT')
   addNode(name,image,ctInfo['origin'],ctInfo['pixel_size'],
      ctInfo['orientation'],0)
   r['ct']=f'{name}.nrrd'


def addFrames(r,patient,xconfig):
   """Add each NM frame of patient to the scene as its own scalar volume.

   Frames are taken along the last axis of patient['NM']['data']
   (data[:,:,:,k]). Each frame is named config.getNodeName(r,xconfig,'NM',k)
   and shares the NM origin, pixel size and orientation (dataType 1).
   r['spect'] is set to '<name>.nrrd' of the last frame; it is left unset
   when there are no frames.
   """
   import vtkInterface
   import config
   nm=patient['NM']
   frames=nm['data']
   frameCount=frames.shape[3]
   print("NFrames: {}".format(frameCount))
   for frameIndex in range(frameCount):
      frame=frames[:,:,:,frameIndex]
      name=config.getNodeName(r,xconfig,'NM',frameIndex)
      image=vtkInterface.numpyToVTK(frame,frame.shape)
      addNode(name,image,nm['origin'],nm['pixel_size'],nm['orientation'],1)
      # overwritten each iteration; the final frame's name is what remains
      r['spect']=f'{name}.nrrd'


def addNode(nodeName,v,origin,pixel_size,orientation,dataType):
   """Wrap image data v in a new scalar volume node and add it to the scene.

   Geometry is moved from the image data into the node's IJK-to-RAS matrix.
   v's own origin and spacing are reset to 0 and 1, so v is modified in place.

   Args:
      nodeName: name given to the new vtkMRMLScalarVolumeNode.
      v: image data holding the voxels.
      origin: volume origin in LPS.
      pixel_size: voxel size per axis.
      orientation: direction cosines in LPS. Presumably flattened as
         consecutive 3-vectors, one per axis (indexed as 3*j+i below); confirm
         against vtkInterface.completeOrientation.
      dataType: 0 for CT, 1 for SPECT. Not used by this function.
   """
   # Import explicitly rather than relying on a `vtk` global provided by the
   # Slicer Python environment; this file never imports vtk at module level.
   import vtk
   newNode=slicer.vtkMRMLScalarVolumeNode()
   newNode.SetName(nodeName)
   v.SetOrigin([0,0,0])
   v.SetSpacing([1,1,1])
   ijkToRAS=vtk.vtkMatrix4x4()
   # LPS -> RAS: negate the first two components of every 3-vector.
   rasOrientation=[-x if (i%3<2) else x for i,x in enumerate(orientation)]
   rasOrigin=[-x if (i%3<2) else x for i,x in enumerate(origin)]

   for i in range(0,3):
      for j in range(0,3):
         # NOTE(review): element (i,j) is scaled by pixel_size[i] (the row).
         # If orientation stores one direction vector per axis j, pixel_size[j]
         # would be expected. The two agree for axis-aligned orientations;
         # confirm before changing.
         ijkToRAS.SetElement(i,j,pixel_size[i]*rasOrientation[3*j+i])

      ijkToRAS.SetElement(i,3,rasOrigin[i])

   newNode.SetIJKToRASMatrix(ijkToRAS)
   newNode.SetAndObserveImageData(v)
   slicer.mrmlScene.AddNode(newNode)

def addDummyInputFunction(r,patient,xconfig):
   """Store the NM frame timing in a scene node.

   The node is named config.getNodeName(r,xconfig,'Dummy'). It is a
   vtkMRMLDoubleArrayNode holding one (time, duration, 0) tuple per frame,
   taken from patient['NM']['time'] and ['duration']. An existing node with
   that name is reused. If this Slicer build has no vtkMRMLDoubleArrayNode
   class, addDummyTable creates a vtkMRMLTableNode instead.
   """
   import config

   nm=patient['NM']
   n=nm['data'].shape[3]

   dnsNodeName=config.getNodeName(r,xconfig,'Dummy')
   dns=slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode',dnsNodeName)
   if dns.GetNumberOfItems()==0:
      # Probe for the class on its own, so that an AttributeError raised
      # inside AddNode is not mistaken for "class unavailable".
      try:
         doubleArrayClass=slicer.vtkMRMLDoubleArrayNode
      except AttributeError:
         addDummyTable(dnsNodeName,n,nm)
         return
      dn=slicer.mrmlScene.AddNode(doubleArrayClass())
      # Name the node so storeNode saves it as '<dnsNodeName>.mcsv' and the
      # lookup above finds it on later calls; addDummyTable does the same.
      dn.SetName(dnsNodeName)
   else:
      dn=dns.GetItemAsObject(0)

   dn.SetSize(n)

   ft=nm['time']
   dt=nm['duration']
   for i in range(0,n):
      fx=ft[i]
      fy=dt[i]
      dn.SetValue(i,0,fx)
      dn.SetValue(i,1,fy)
      dn.SetValue(i,2,0)
      print('{} ({},{})'.format(i,fx,fy))



def addDummyTable(dnsNodeName,n,nm):
   """Store NM frame timing in a vtkMRMLTableNode named dnsNodeName.

   This is the fallback used by addDummyInputFunction. It creates the node, or
   reuses an existing node with that name after removing all its columns. It
   then adds two double columns, 'time' and 'duration', holding the first n
   values of nm['time'] and nm['duration'].
   """
   # Import explicitly rather than relying on a `vtk` global provided by the
   # Slicer Python environment; this file never imports vtk at module level.
   import vtk

   dns=slicer.mrmlScene.GetNodesByClassByName('vtkMRMLTableNode',dnsNodeName)
   if dns.GetNumberOfItems()==0:
      dn=slicer.mrmlScene.AddNode(slicer.vtkMRMLTableNode())
      dn.SetName(dnsNodeName)
   else:
      dn=dns.GetItemAsObject(0)
      dn.RemoveAllColumns()

   ft=nm['time']
   dt=nm['duration']
   tCol=vtk.vtkDoubleArray()
   dCol=vtk.vtkDoubleArray()
   for i in range(n):
      tCol.InsertNextValue(ft[i])
      dCol.InsertNextValue(dt[i])

   # Columns are named after being added; storeTable looks them up by name.
   tcol=dn.AddColumn(tCol)
   tcol.SetName('time')
   dcol=dn.AddColumn(dCol)
   dcol.SetName('duration')

def storeNode(fb,row,xconfig,node):
   """Save a scene node to a local file and upload that file to LabKey.

   The file name is node.GetName() plus a suffix chosen by node class:
   '.mcsv' for double-array and table nodes, '.h5' for transform and grid
   transform nodes, '.nrrd' for everything else. Table nodes are written with
   storeTable and all other nodes with slicer.util.saveNode. The file lands in
   config.getLocalDir(row,xconfig). It is then copied with
   fb.writeFileToFile to fb.buildPathURL(project, config.getPathList(row,xconfig)).
   """
   import config

   className=node.__class__.__name__
   suffixByClass={
      'vtkMRMLDoubleArrayNode':'.mcsv',
      'vtkMRMLTableNode':'.mcsv',
      'vtkMRMLTransformNode':'.h5',
      'vtkMRMLGridTransformNode':'.h5',
   }
   fileName=node.GetName()+suffixByClass.get(className,'.nrrd')
   localPath=os.path.join(config.getLocalDir(row,xconfig),fileName)

   if className=="vtkMRMLTableNode":
      storeTable(node,localPath)
   else:
      slicer.util.saveNode(node,localPath)
   print("Stored to: {}".format(localPath))

   labkeyPath=fb.buildPathURL(xconfig['project'],config.getPathList(row,xconfig))
   print("Remote: {}".format(labkeyPath))
   # NOTE(review): an earlier comment here said "checks if exists"; whether
   # writeFileToFile checks for an existing remote file is not visible here.
   fb.writeFileToFile(localPath,labkeyPath+'/'+fileName)


def storeTable(node,filename):
   """Write a timing table node in the old vtkMRMLDoubleArray (.mcsv) layout.

   The output has two header lines, '# measurement file <filename>' and
   '# no labels'. They are followed by one 'time,duration,0' line per row,
   read from the table's 'time' and 'duration' columns. Each data line is
   also printed.
   """
   table=node.GetTable()
   timeColumn=table.GetColumnByName('time')
   durationColumn=table.GetColumnByName('duration')
   count=timeColumn.GetNumberOfValues()
   print(f'Storing {count} values')
   with open(filename,'w') as out:
      out.write(f'# measurement file {filename}\n# no labels\n')
      for idx in range(count):
         line=f'{timeColumn.GetTuple1(idx)},{durationColumn.GetTuple1(idx)},0'
         print(line)
         out.write(line+'\n')

      






def readPatient(fb,localDir,project,patientId):
    """Download every file in a patient's remote LabKey directory.

    The remote directory is fb.formatPathURL(project, patientId). Each listed
    entry is copied into <localDir>/<patientId>, which is created if missing.
    The status flag returned by fb.listRemoteDir is ignored.

    Returns:
        List of local file paths, in listing order.
    """
    # single-element path list joined with '/'
    remoteDir=fb.formatPathURL(project,'/'.join([patientId]))
    targetDir=os.path.join(localDir,patientId)
    if not os.path.isdir(targetDir):
        os.makedirs(targetDir)

    _,remoteFiles=fb.listRemoteDir(remoteDir)
    copied=[]
    for remote in remoteFiles:
        print(remote)
        destination=os.path.join(targetDir,pathlib.Path(remote).name)
        fb.readFileToFile(remote,destination)
        copied.append(destination)
    print('Done')
    return copied

if __name__=='__main__':
    # Usage: <this script> <job-config.json>
    configPath=sys.argv[1]
    main(configPath)
    # Exit the interpreter once all rows have been processed.
    quit()