import DICOMLib
import sys
import json
import numpy
import zipfile
import shutil
import os

def main(configFile=None):
   """Run the mean-heart-dose update described by a JSON configuration file.

   configFile: path to a JSON file. Its content is merged with the server
   browsers returned by connectDB and the result is passed to parseData
   together with getMeanHeartDose as the per-row handler.

   Raises ValueError when no configuration file is given (instead of the
   obscure TypeError from open(None)).
   """
   print('Imported!')
   if configFile is None:
      raise ValueError('main() requires the path to a JSON configuration file')
   with open(configFile) as f:
      config=json.load(f)

   config.update(connectDB(config))
   parseData(config,getMeanHeartDose)

def connectDB(setup):
   """Open LabKey and Orthanc connections for the server named in setup.

   setup['server'] selects the credentials file ~/.labkey/<server>.json,
   which initialises both the LabKey and the Orthanc interface. The
   interface libraries are loaded through nixWrapper from
   ~/software/src/nixsuite/wrapper.

   Returns a dict with the LabKey database/file browsers ('db', 'fb') and
   the Orthanc database/file browsers ('odb', 'ofb').
   """
   #needed for the reload of orthancFileBrowser below
   import importlib

   nixSuite=os.path.join(os.path.expanduser('~'),'software','src','nixsuite')
   wrapperDir=os.path.join(nixSuite,'wrapper')
   #avoid appending a duplicate sys.path entry on every call
   if wrapperDir not in sys.path:
      sys.path.append(wrapperDir)
   import nixWrapper
   nixWrapper.loadLibrary('labkeyInterface')
   import labkeyInterface
   import labkeyDatabaseBrowser
   import labkeyFileBrowser

   nixWrapper.loadLibrary('orthancInterface')
   import orthancInterface
   import orthancDatabaseBrowser
   import orthancFileBrowser
   #NOTE(review): presumably reloaded to pick up source edits in a
   #long-running Slicer session - confirm it is still needed
   importlib.reload(orthancFileBrowser)

   net=labkeyInterface.labkeyInterface()
   qfile='{}.json'.format(setup['server'])
   fconfig=os.path.join(os.path.expanduser('~'),'.labkey',qfile)
   net.init(fconfig)
   net.getCSRF()

   onet=orthancInterface.orthancInterface()
   onet.init(fconfig)

   return {"db":labkeyDatabaseBrowser.labkeyDB(net),
      "fb":labkeyFileBrowser.labkeyFileBrowser(net),
      "odb":orthancDatabaseBrowser.orthancDB(onet),
      "ofb":orthancFileBrowser.orthancFileBrowser(onet)}


#explicit template
def _updateRow(config,r):
   print(r)
   return False

def parseData(config,updateRow=_updateRow):
   """Apply updateRow to every row of a query and push changed rows back.

   config keys read: 'db' (database browser), 'project', 'schema', 'query',
   optional 'qFilter' (default []) and optional 'debug' (default False;
   when true only the first three rows are processed).

   updateRow(config,r) may modify the row dict r in place and returns True
   when r needs to be updated on the server, False otherwise.
   """
   db=config['db']
   target=(config['project'],config['schema'],config['query'])
   dataset=db.selectRows(*target,config.get('qFilter',[]))
   rows=dataset['rows']
   if config.get('debug',False):
      #shorten the list in debug mode
      rows=rows[:3]
   for row in rows:
      if updateRow(config,row):
         db.modifyRows('update',*target,[row])



def getMeanHeartDose(config,r):
   """Compute the mean heart dose for row r and store it as r['doseHeart'].

   Rows without an 'orthancStudyId' or with a truthy 'doseHeart' are left
   untouched (returns False). Otherwise the study is downloaded, loaded and
   checked. If checkNodes reports a problem, the message is stored in
   r['comments']; if not, r['doseHeart'] is set. In both cases True is
   returned, meaning r needs to be updated on the server.

   The downloaded DICOM directory is always removed, even when loading or
   the dose computation raises.
   """
   sid=r['orthancStudyId']
   if not sid:
      print('No study for {}'.format(r['ParticipantId']))
      return False
   if r['doseHeart']:
      #already computed, no need to update
      return False

   outDir=getDicomInstances(config,sid)
   try:
      nodes=loadVolumes(outDir)
      msg=checkNodes(config,nodes)
      if msg:
         r['comments']=msg
      else:
         r['doseHeart']=getMeanDose(nodes,'Heart')
   finally:
      #do not leave downloaded instances behind if processing failed
      clearDir(outDir)
   #needs updating
   return True

def loadVolumes(dataDir):
   """Import the DICOM files in dataDir into Slicer and classify the nodes.

   The files are imported into a temporary DICOM database and every patient
   found there is loaded into the scene.

   Returns a dict with
      'vol'  - loaded vtkMRMLScalarVolumeNode nodes,
      'dose' - those scalar volumes whose name contains 'RTDOSE',
      'seg'  - loaded vtkMRMLSegmentationNode nodes.
   """
   #identifiers returned by loadPatientByUID, resolved below with getNode
   loadedRefs=[]
   with DICOMLib.DICOMUtils.TemporaryDICOMDatabase() as tempDb:
      DICOMLib.DICOMUtils.importDicom(dataDir, tempDb)
      for uid in tempDb.patients():
         print(uid)
         loadedRefs.extend(DICOMLib.DICOMUtils.loadPatientByUID(uid))

   loaded=[slicer.util.getNode(pattern=ref) for ref in loadedRefs]
   volumes=[node for node in loaded if node.GetClassName()=='vtkMRMLScalarVolumeNode']
   doses=[node for node in volumes if 'RTDOSE' in node.GetName()]
   segmentations=[node for node in loaded if node.GetClassName()=='vtkMRMLSegmentationNode']
   print(f'vol:{len(volumes)} seg:{len(segmentations)} dose: {len(doses)}')

   return {'vol':volumes,'dose':doses,'seg':segmentations}


def checkNodes(config,nodes):
   """Describe problems with the loaded nodes, or return '' if they are usable.

   getMeanDose uses nodes['dose'][0] and nodes['seg'][0], so exactly one
   dose volume and exactly one segmentation are required. A dose count
   other than one is reported as 'DOSE[n]' and a segmentation count other
   than one as 'SEG[n]'; both are joined with '/'. Missing nodes (n == 0)
   are reported too, instead of letting getMeanDose fail with IndexError.

   config is not used; it is kept for a uniform signature.
   """
   problems=[]
   nD=len(nodes['dose'])
   if nD!=1:
      problems.append(f'DOSE[{nD}]')
   nS=len(nodes['seg'])
   if nS!=1:
      problems.append(f'SEG[{nS}]')
   return '/'.join(problems)

def getMeanDose(nodes,target):
   """Return the mean dose inside the segment named target.

   nodes: dict as returned by loadVolumes; the first segmentation node and
   the first dose volume are used.
   target: exact segment name to look for (e.g. 'Heart').

   Raises ValueError when the segmentation has no segment named target.
   """
   segNode=nodes['seg'][0]
   seg=segNode.GetSegmentation()
   targetSegmentIds=[s for s in seg.GetSegmentIDs() if seg.GetSegment(s).GetName()==target]
   if not targetSegmentIds:
      raise ValueError(f'No segment named {target} in {segNode.GetName()}')

   doseNode=nodes['dose'][0]
   doseArray=slicer.util.arrayFromVolume(doseNode)
   #doseNode is passed as reference geometry so that the labelmap array can
   #be used as a mask on doseArray below
   segmentArray=slicer.util.arrayFromSegmentBinaryLabelmap(segNode,targetSegmentIds[0],doseNode)
   doseVoxels=doseArray[segmentArray!=0]
   meanDose=numpy.mean(doseVoxels)
   print(meanDose)
   #float() avoids JSON complaining about numpy float32
   return float(meanDose)

def getDicomInstances(config,sid):
   """Download every DICOM instance of Orthanc study sid into a fresh folder.

   Files are written one by one to <baseDir>/<sid>/<instanceId>.dcm, where
   baseDir is config['baseDir'] (default ~/temp). Previous content of that
   folder is removed first. Returns the folder path.
   """
   odb=config['odb']
   ofb=config['ofb']
   instances=[]
   for s in odb.getStudyData(sid)['Series']:
      instances.extend(odb.getSeriesData(s)['Instances'])

   baseDir=config.get('baseDir',os.path.join(os.path.expanduser('~'),'temp'))
   outDir=os.path.join(baseDir,sid)
   clearDir(outDir)
   #makedirs also creates baseDir when it does not exist yet
   os.makedirs(outDir)
   for oid in instances:
      ofb.getInstance(oid,os.path.join(outDir,f'{oid}.dcm'))
   return outDir


def getDicomZip(config,sid):
   """Fetch study sid from Orthanc as a zip archive and extract it.

   The archive lives at <baseDir>/<sid>.zip (baseDir = config['baseDir'],
   default ~/temp). An existing archive is reused; otherwise it is downloaded.
   The archive is deleted after extraction. Returns the directory produced
   by extractZip.
   """
   fileBrowser=config['ofb']
   defaultBase=os.path.join(os.path.expanduser('~'),'temp')
   zipPath=os.path.join(config.get('baseDir',defaultBase),f'{sid}.zip')
   if not os.path.isfile(zipPath):
      fileBrowser.getZip('studies',sid,zipPath,'archive')
   print(f'Using {zipPath}')
   extracted=extractZip(config,zipPath)
   os.remove(zipPath)
   return extracted

def extractZip(config,fname):
   """Extract zip file fname flat into <baseDir>/<bname> and return that dir.

   bname is the basename of fname with a trailing '.zip' removed. baseDir is
   config['baseDir'] (default ~/temp, the same default getDicomZip uses).
   The output directory is recreated from scratch. Each file member is
   written as outNNN.dcm (NNN = zero-padded running index), which discards
   the archive's internal paths and so avoids *nix/Windows directory
   separator confusion. Directory entries of the archive are skipped.
   """
   bname=os.path.basename(fname)
   #strip only the suffix, not every '.zip' occurring in the name
   if bname.endswith('.zip'):
      bname=bname[:-len('.zip')]
   baseDir=config.get('baseDir',os.path.join(os.path.expanduser('~'),'temp'))
   outDir=os.path.join(baseDir,bname)

   #the context manager guarantees the archive is closed
   with zipfile.ZipFile(fname) as fzip:
      members=[m for m in fzip.infolist() if not m.is_dir()]
      clearDir(outDir)
      os.makedirs(outDir)
      for i,member in enumerate(members):
         out=os.path.join(outDir,f'out{i:03d}.dcm')
         with fzip.open(member) as zf, open(out,'wb') as f:
            shutil.copyfileobj(zf,f)

   return outDir

def clearDir(outDir):
   """Delete the directory tree outDir; do nothing if it is not a directory."""
   if not os.path.isdir(outDir):
      return
   shutil.rmtree(outDir)

if __name__=='__main__':
   #pick the optional config path up front instead of wrapping main() in
   #try/except IndexError, which also swallowed IndexErrors raised while
   #processing data and then re-ran main() without a configuration
   configFile=sys.argv[1] if len(sys.argv)>1 else None
   main(configFile)
   print('Succesful completion')
   quit()