from slicer.ScriptedLoadableModule import *
import slicerNetwork
import qt,vtk,ctk,slicer
import os
import resample
import exportDicom
import loadPatient


class CTRegistration(slicer.ScriptedLoadableModule.ScriptedLoadableModule):
  """Slicer module descriptor for CTRegistration.

  Registers the module under the "EMBRACE" category and supplies the
  title, contributor, help and acknowledgement text Slicer displays.
  """

  def __init__(self,parent):
      slicer.ScriptedLoadableModule.ScriptedLoadableModule.__init__(self, parent)
      self.className="CTRegistration"

      # all descriptive metadata lives on the parent module object
      info=self.parent
      info.title="CTRegistration"
      info.categories=["EMBRACE"]
      info.dependencies=[]
      info.contributors=["Andrej Studen (University of Ljubljana)"]

      # free-text description followed by the standard documentation link
      description="""
        This is an example of scripted loadable module bundled in an extension.
        It performs registration of CT (EBRT) and MRI (BRT)
        """
      info.helpText=description+self.getDefaultModuleDocumentationLink()

      info.acknowledgementText="""
        This extension developed within Medical Physics research programe of ARRS
        """

#
# CTRegistrationWidget



class CTRegistrationWidget(ScriptedLoadableModuleWidget):
  """GUI panel for CTRegistration: enter an EMBRACE patient ID, load the
  patient's data from LabKey and export the registration results.

  Uses ScriptedLoadableModuleWidget base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """

  def setup(self):
      """Create the logic, connect to LabKey and build the form controls."""
      ScriptedLoadableModuleWidget.setup(self)
      self.logic=CTRegistrationLogic(self)

      # LabKey connection settings are read from <home>/.labkey/onko-nix.json
      try:
          fhome=os.environ["HOME"]
      except KeyError:
          # in Windows, HOME is usually unset; the home directory is
          # HOMEDRIVE+HOMEPATH instead
          fhome=os.environ['HOMEDRIVE']+os.environ['HOMEPATH']
      cfgPath=os.path.join(fhome,".labkey","onko-nix.json")

      self.onkoNet=slicerNetwork.labkeyURIHandler()
      self.onkoNet.parseConfig(cfgPath)
      self.onkoNet.initRemote()

      # LabKey project passed to the logic on export
      self.project='EMBRACE/Studija'

      datasetCollapsibleButton = ctk.ctkCollapsibleButton()
      datasetCollapsibleButton.text = "Node data"
      self.layout.addWidget(datasetCollapsibleButton)
      # Layout within the collapsible button
      datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)

      self.patientId=qt.QLineEdit("LJU004")
      datasetFormLayout.addRow("EMBRACE ID:",self.patientId)

      loadDataButton=qt.QPushButton("Load")
      loadDataButton.clicked.connect(self.onLoadDataButtonClicked)
      datasetFormLayout.addRow("Data:",loadDataButton)

      self.exportButton=qt.QPushButton("Export")
      self.exportButton.clicked.connect(self.onExportButtonClicked)
      datasetFormLayout.addRow("Export:",self.exportButton)

      # NOTE(review): debugCheckBox is never read in this module — confirm
      # whether it is still needed
      self.debugCheckBox=qt.QCheckBox()
      self.debugCheckBox.setChecked(True)
      datasetFormLayout.addRow("Debug:",self.debugCheckBox)

  def onLoadDataButtonClicked(self):
      """Load CT, CT structure set and DMR for the entered patient ID."""
      self.logic.loadData(self.onkoNet,self.patientId.text)

  def onExportButtonClicked(self):
      """Export the registration results for the entered patient ID."""
      self.logic.export(self.onkoNet,self.project,self.patientId.text)



class CTRegistrationLogic(slicer.ScriptedLoadableModule.ScriptedLoadableModuleLogic):
  """Load a patient's CT, CT structure set (CTRS) and MRI (DMR), and export
  the CT-to-DMR registration results to LabKey.

  Importing, DICOM export and resampling are delegated to the loadPatient,
  exportDicom and resample helper logics.  loadData() must be called for a
  patient before export().
  """

  def __init__(self,parent):
      slicer.ScriptedLoadableModule.ScriptedLoadableModuleLogic.__init__(self, parent)
      self.exporter=exportDicom.exportDicomLogic(self)
      self.importer=loadPatient.loadPatientLogic(self)
      self.resampler=resample.resampleLogic(self)

  def setLocal(self,basePath):
      """Forward basePath to the importer's setLocal()."""
      self.importer.setLocal(basePath)

  def getLocalRegistrationPath(self,net,project,patientId):
      """Return <cache>/<project>/%40files/<patientId>/Registration.

      The local directory is created if missing, and so is the matching
      remote LabKey directory.
      """
      path=os.path.join(net.GetLocalCacheDirectory(),project,"%40files",
                        patientId,"Registration")
      if not os.path.isdir(path):
          # makedirs (not mkdir): the patient directory itself may not be
          # cached locally yet
          os.makedirs(path)

      relDir=net.GetRelativePathFromLocalPath(path)
      remoteDir=net.GetLabkeyPathFromRelativePath(relDir)

      if not net.isRemoteDir(remoteDir):
          net.mkdir(remoteDir)
      return path

  def _isSingle(self,items,noneMsg,manyMsg,patientId):
      """Return True if items has exactly one entry, else print why and return False."""
      if len(items)<1:
          print(noneMsg.format(patientId))
          return False
      if len(items)>1:
          print(manyMsg.format(patientId))
          return False
      return True

  def loadData(self,net,patientId):
      """Load CT, CTRS and DMR for patientId into self.ct, self.ctrs, self.dmr.

      Each must resolve to exactly one item; otherwise a message is printed
      and loading stops.  Loaded nodes are renamed <patientId>_CT, _CTRS and
      _DMR.  A DMR without a frame of reference UID gets a freshly generated one.
      """
      self.importer.setURIHandler(net)

      self.ct=self.importer.loadCT(patientId)
      if not self._isSingle(self.ct,
                            "No CT found for patient {}",
                            "Too many CT volumes found for patient {}",
                            patientId):
          return
      self.ct[0]['node'].SetName(patientId+"_CT")

      self.ctrs=self.importer.loadCTRS(patientId)
      if not self._isSingle(self.ctrs,
                            "No CT-segmentation found for patient {}",
                            "Multiple CT-segmentations found for patient {}",
                            patientId):
          return
      self.ctrs[0]['node'].SetName(patientId+"_CTRS")

      self.dmr=self.importer.loadDMR(patientId)
      if not self._isSingle(self.dmr,
                            "No DMR found for patient {}",
                            "Multiple DMR found for patient {}",
                            patientId):
          return
      dmrMetadata=self.dmr[0]['metadata']
      # exported CT and segments are placed in the DMR frame of reference,
      # so the DMR must have one
      if not dmrMetadata['frameOfReferenceInstanceUid']:
          refId=self.exporter.generateFrameOfReferenceUUID('volume')
          dmrMetadata['frameOfReferenceInstanceUid']=refId
      self.dmr[0]['node'].SetName(patientId+"_DMR")

  def exportFile(self,net,path):
      """Upload the local file at path to its corresponding LabKey location."""
      print("localPath: {}".format(path))
      relativePath=net.GetRelativePathFromLocalPath(path)
      print("relativePath: {}".format(relativePath))
      remotePath=net.GetLabkeyPathFromRelativePath(relativePath)
      print("remotePath: {}".format(remotePath))
      net.copyLocalFileToRemote(path,remotePath)

  def exportTransformation(self,net,project,patientId):
      """Save the <patientId>_T2_DF transform as .h5 in the Registration
      directory and upload it.

      Return the transform node, or None (after printing a message) if no
      such node is in the scene.
      """
      tNodeName=patientId+"_T2_DF"
      # NOTE(review): slicer.util.getFirstNodeByName matches on name prefix
      tNode=slicer.util.getFirstNodeByName(tNodeName)
      if tNode is None:
          print("No transformation {} found for patient {}".format(tNodeName,patientId))
          return None
      path=os.path.join(self.getLocalRegistrationPath(net,project,patientId),
                        tNodeName+".h5")
      slicer.util.saveNode(tNode,path)
      self.exportFile(net,path)
      return tNode

  def exportSegmentation(self,net,project,patientId,tNode):
      """Export each CTRS segment, resampled onto the DMR volume through tNode,
      as a separate series in the DMR frame of reference.

      A segment already resampled (a node named <CTRS name>_<segment ID>
      exists) is reused instead of being resampled again.
      """
      dmrNode=self.dmr[0]['node']
      dmrMetadata=self.dmr[0]['metadata']

      segNode=self.ctrs[0]['node']
      segMetadata=self.ctrs[0]['metadata']
      segMetadata['frameOfReferenceInstanceUid']=dmrMetadata['frameOfReferenceInstanceUid']
      nSeg=segNode.GetSegmentation().GetNumberOfSegments()

      for i in range(nSeg):
          segId=segNode.GetSegmentation().GetNthSegmentID(i)
          nodeName=segNode.GetName()+'_'+segId
          # NOTE(review): getFirstNodeByName matches on name prefix, so e.g.
          # "..._Segment_1" could also match "..._Segment_10" — verify
          node=slicer.util.getFirstNodeByName(nodeName)

          if node is None:
              segNode.SetAndObserveTransformNodeID(tNode.GetID())
              binaryRep={'node':segNode,
                         'mask':segNode.GetBinaryLabelmapRepresentation(segId),
                         'segId':segId}
              if binaryRep['mask'] is None:
                  segNode.CreateBinaryLabelmapRepresentation()
                  binaryRep['mask']=segNode.GetBinaryLabelmapRepresentation(segId)

              # the resampler is expected to create a node named nodeName
              self.resampler.rebinSegment(dmrNode,binaryRep)
              node=slicer.util.getFirstNodeByName(nodeName)
              if node is None:
                  print("Resampled segment {} not found, skipping".format(nodeName))
                  continue

          segMetadata['seriesInstanceUid']=self.exporter.generateSeriesUUID('segmentation')
          segMetadata['seriesNumber']=i
          segMetadata['patientId']=patientId

          self.exporter.exportNode(net,project,node,segMetadata)

  def exportCT(self,net,project,patientId,tNode):
      """Export the CT, resampled onto the DMR volume through tNode, in the
      DMR frame of reference.

      An existing <CT name>_TF node is reused instead of being resampled again.
      """
      dmrNode=self.dmr[0]['node']
      dmrMetadata=self.dmr[0]['metadata']

      ctNode=self.ct[0]['node']
      ctMetadata=self.ct[0]['metadata']
      ctMetadata['frameOfReferenceInstanceUid']=dmrMetadata['frameOfReferenceInstanceUid']
      # NOTE(review): the CT series UID uses the 'segmentation' tag, like the
      # segments do — confirm 'volume' was not intended
      ctMetadata['seriesInstanceUid']=self.exporter.generateSeriesUUID('segmentation')
      ctMetadata['patientId']=patientId

      ctName=ctNode.GetName()
      ctRBName=ctName+"_TF"
      ctRBNode=slicer.util.getFirstNodeByName(ctRBName)
      if ctRBNode is None:
          ctNode.SetAndObserveTransformNodeID(tNode.GetID())
          ctRBNode=self.resampler.rebinNode(ctNode,dmrNode)
          ctRBNode.SetName(ctRBName)
          # re-apply the original CT name after resampling
          ctNode.SetName(ctName)
      self.exporter.exportNode(net,project,ctRBNode,ctMetadata)

  def export(self,net,project,patientId):
      """Export the transform, the resampled segments and the resampled CT.

      Stops early if the transform node cannot be found.
      """
      tNode=self.exportTransformation(net,project,patientId)
      if tNode is None:
          return

      self.exportSegmentation(net,project,patientId,tNode)
      self.exportCT(net,project,patientId,tNode)