import slicer
import vtk
import os
from slicer.ScriptedLoadableModule import *
import ctk
import qt

class resample(ScriptedLoadableModule):
  """Module descriptor for the "Resample" scripted module.

  Fills in the title, category, contributors, help text and acknowledgement
  shown by Slicer for this module.
  Uses ScriptedLoadableModule base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """
  def __init__(self, parent):
    ScriptedLoadableModule.__init__(self, parent)
    helpText = """
    Resample to different shapes
    """
    acknowledgement = """
    This module was developed within the frame of the ARRS sponsored medical
    physics research programe to investigate quantitative measurements of cardiac
    function using sestamibi-like tracers
    """
    parent.title = "Resample"
    parent.categories = ["LabKey"]
    parent.dependencies = []
    parent.contributors = ["Andrej Studen (FMF/JSI)"]
    parent.helpText = helpText
    parent.acknowledgementText = acknowledgement
    # The base-class constructor has already stored parent as self.parent,
    # so the documentation link is appended to the same help text.
    self.parent.helpText += self.getDefaultModuleDocumentationLink()
    self.parent = parent

#
# resampleWidget
#
class resampleWidget(ScriptedLoadableModuleWidget):
  """GUI of the Resample module.

  Offers two line edits (name of the node to transform and name of the
  reference node) and two buttons. One resamples a scalar volume onto the
  reference grid; the other resamples a segmentation onto it.
  Uses ScriptedLoadableModuleWidget base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """

  def setup(self):
    ScriptedLoadableModuleWidget.setup(self)
    self.logic=resampleLogic(self)

    datasetCollapsibleButton = ctk.ctkCollapsibleButton()
    datasetCollapsibleButton.text = "Dataset"
    self.layout.addWidget(datasetCollapsibleButton)
    # Layout within the dummy collapsible button
    datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)

    self.transformNode=qt.QLineEdit("NodeToTransform")
    datasetFormLayout.addRow("TransformedNode:",self.transformNode)
    self.referenceNode=qt.QLineEdit("ReferenceNode")
    datasetFormLayout.addRow("ReferenceNode:",self.referenceNode)

    self.transformButton=qt.QPushButton("Transform")
    self.transformButton.clicked.connect(self.onTransformButtonClicked)
    datasetFormLayout.addRow("Volume:",self.transformButton)

    self.transformSegmentationButton=qt.QPushButton("Transform")
    self.transformSegmentationButton.clicked.connect(self.onTransformSegmentationButtonClicked)
    datasetFormLayout.addRow("Segmentation:",self.transformSegmentationButton)

  def _findReferenceNode(self):
    """Return the node named in the reference line edit, or None (with a message)."""
    refNode=slicer.util.getFirstNodeByName(self.referenceNode.text)
    if refNode is None:
      # format the message first: print() returns None in Python 3
      print("Reference node [{}] not found".format(self.referenceNode.text))
    return refNode

  def onTransformButtonClicked(self):
    """Resample the named volume onto the grid of the named reference node."""
    node=slicer.util.getFirstNodeByName(self.transformNode.text)
    if node is None:
      print("Transform node [{}] not found".format(self.transformNode.text))
      return
    refNode=self._findReferenceNode()
    if refNode is None:
      return

    self.logic.rebinNode(node,refNode)

  def onTransformSegmentationButtonClicked(self):
    """Resample every segment of the named segmentation onto the reference grid."""
    segName=self.transformNode.text
    segNode=None
    for s in slicer.util.getNodesByClass("vtkMRMLSegmentationNode"):
      print("SegmentationNode: {}".format(s.GetName()))
      if s.GetName()==segName:
        segNode=s
        break

    if segNode is None:
      print("Segmentation node [{}] not found".format(segName))
      return

    refNode=self._findReferenceNode()
    if refNode is None:
      return

    self.logic.rebinSegmentation(segNode,refNode)



class resampleLogic(ScriptedLoadableModuleLogic):
  """Computation backend of the Resample module.

  Resamples scalar volumes and segmentations onto the voxel grid of a
  reference volume. The interface is usable from other python code without
  an instance of the Widget. Diagnostic logs are written to a ``.resample``
  directory under the user's home directory, which is created on
  construction if it does not exist.
  Uses ScriptedLoadableModuleLogic base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """
  def __init__(self,parent):
      ScriptedLoadableModuleLogic.__init__(self, parent)
      try:
          fhome=os.environ["HOME"]
      except KeyError:
          # HOME is typically unset on Windows; build it from HOMEDRIVE+HOMEPATH
          fhome=os.environ['HOMEDRIVE']+os.environ['HOMEPATH']
      self.baseLog=os.path.join(fhome,".resample")
      if not os.path.isdir(self.baseLog):
          os.mkdir(self.baseLog)

  def printMe(self):
      """Print a readiness banner and the log directory."""
      print("resampleLogic ready")
      # format inside print(): print() returns None in Python 3
      print("Log: {}".format(self.baseLog))

  def cast(self,newImage,originalImage):
      """Return newImage with the scalar type of originalImage.

      If both images already share a scalar type, newImage itself is
      returned. Otherwise a converted copy is returned. It is produced by
      vtkImageShiftScale with scale 1 and shift 0, so values are unchanged
      apart from the type conversion.
      """
      newType=newImage.GetPointData().GetScalars().GetDataType()
      originalType=originalImage.GetPointData().GetScalars().GetDataType()
      if newType==originalType:
          return newImage
      shifter=vtk.vtkImageShiftScale()
      shifter.SetInputData(newImage)
      # GetDataType() returns a VTK type id (VTK_SHORT, VTK_FLOAT, ...), which is
      # exactly what SetOutputScalarType expects, so every scalar type is handled
      # (previously only short / unsigned short were converted).
      shifter.SetOutputScalarType(originalType)
      shifter.SetScale(1)
      shifter.SetShift(0)
      shifter.Update()
      return shifter.GetOutput()

  def _refPointToWorld(self,refNode,refImage,refIJKtoRAS,index):
      """Return world coordinates of point `index` of refImage.

      The point is mapped through refIJKtoRAS and then through any transform
      applied to refNode.
      """
      refIJK=list(refImage.GetPoint(index))
      refIJK.append(1)
      refPos=refIJKtoRAS.MultiplyPoint(refIJK)[0:3]
      fWorld=[0,0,0]
      refNode.TransformPointToWorld(refPos,fWorld)
      return fWorld

  def _addVolumeNode(self,name,refNode,refIJKtoRAS,image):
      """Create a scalar volume on refNode's geometry holding `image`, add it to the scene and return it."""
      newNode=slicer.vtkMRMLScalarVolumeNode()
      newNode.SetName(name)
      newNode.SetOrigin(refNode.GetOrigin())
      newNode.SetSpacing(refNode.GetSpacing())
      newNode.SetIJKToRASMatrix(refIJKtoRAS)
      newNode.SetAndObserveImageData(image)
      slicer.mrmlScene.AddNode(newNode)
      return newNode

  def rebinNode(self,node,refNode):
      """Resample volume `node` onto the voxel grid of `refNode`.

      Each reference point is mapped to world coordinates, honouring any
      transform on refNode. It is then mapped into node's IJK space, honouring
      node's transform. Finally node is sampled there with a degree-5
      B-spline (clamped borders).

      The result has node's scalar type. It is added to the scene as a new
      volume named ``<node name>_RS`` and returned. Progress is written to
      ``rebinNode.log`` in the log directory.
      """
      nodeName=node.GetName()
      with open(os.path.join(self.baseLog,"rebinNode.log"),"w") as log:
          log.write("rebinNode: volume: {} ref: {}\n".format(nodeName,refNode.GetName()))
          refImage=refNode.GetImageData()
          n=refImage.GetNumberOfPoints()
          refIJKtoRAS=vtk.vtkMatrix4x4()
          refNode.GetIJKToRASMatrix(refIJKtoRAS)

          nodeRAStoIJK=vtk.vtkMatrix4x4()
          node.GetRASToIJKMatrix(nodeRAStoIJK)

          # B-spline coefficients of the moving image; Evaluate() interpolates them
          coeff=vtk.vtkImageBSplineCoefficients()
          coeff.SetInputData(node.GetImageData())
          coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
          # author's note: spline degree chosen between 3 and 5
          coeff.SetSplineDegree(5)
          coeff.Update()

          # deep copy so that writing new scalars does not alter the reference image
          newImage=vtk.vtkImageData()
          newImage.DeepCopy(refImage)
          newImage=self.cast(newImage,node.GetImageData())
          newScalars=newImage.GetPointData().GetScalars()
          log.write("Iterating: {} points\n".format(n))

          for i in range(n):
              fWorld=self._refPointToWorld(refNode,refImage,refIJKtoRAS,i)

              # inverse path on the target node
              nodePos=[0,0,0]
              node.TransformPointFromWorld(fWorld,nodePos)
              nodePos.append(1)
              nodeIJK=nodeRAStoIJK.MultiplyPoint(nodePos)[0:3]

              v=coeff.Evaluate(nodeIJK)
              v0=newScalars.GetTuple1(i)
              newScalars.SetTuple1(i,v)
              if i%10000==0:
                  log.write("[{}]: {}->{}\n".format(i,v0,v))

          newNode=self._addVolumeNode(nodeName+"_RS",refNode,refIJKtoRAS,newImage)
          log.write("Adding node {}\n".format(newNode.GetName()))
      return newNode

  def inMask(self,binaryRep,fpos):
      """Return the mask value at world position `fpos`, or 0 outside the mask.

      binaryRep is a dict with 'node' (the segmentation node, whose transform
      is applied inversely) and 'mask' (the segment's binary labelmap).
      """
      local=[0,0,0]

      segNode=binaryRep['node']
      segNode.TransformPointFromWorld(fpos,local)

      mask=binaryRep['mask']
      maskWorldToImageMatrix=vtk.vtkMatrix4x4()
      mask.GetWorldToImageMatrix(maskWorldToImageMatrix)
      local.append(1)
      maskIJK=maskWorldToImageMatrix.MultiplyPoint(local)
      # maskIJK is in image index coordinates; FindPoint works in the image
      # data's origin/spacing frame, so convert before the lookup.
      maskIJK=maskIJK[0:3]  # drop homogeneous coordinate
      maskIJK=[r*s for r,s in zip(maskIJK,mask.GetSpacing())]
      maskIJK=[r+c for r,c in zip(maskIJK,mask.GetOrigin())]

      maskI=mask.FindPoint(maskIJK)
      if maskI<0:
          # FindPoint returns -1 for positions outside the labelmap extent
          return 0
      scalars=mask.GetPointData().GetScalars()
      if scalars is None:
          # empty labelmap: nothing is inside
          return 0
      return scalars.GetTuple1(maskI)

  def rebinSegment(self,refNode,binaryRep):
      """Rasterise one segment onto the voxel grid of refNode.

      binaryRep is a dict with 'node' (segmentation node), 'mask' (binary
      labelmap of the segment) and 'segId' (segment id, used in the output
      name). Each reference point receives the mask value at its world
      position (see inMask). The result is added to the scene as
      ``<segmentation name>_<segId>`` and returned.
      """
      refIJKtoRAS=vtk.vtkMatrix4x4()
      refNode.GetIJKToRASMatrix(refIJKtoRAS)
      refImage=refNode.GetImageData()

      # new image on the reference grid (deep copy, scalars overwritten below)
      newImage=vtk.vtkImageData()
      newImage.DeepCopy(refImage)
      newScalars=newImage.GetPointData().GetScalars()

      for j in range(newImage.GetNumberOfPoints()):
          fWorld=self._refPointToWorld(refNode,refImage,refIJKtoRAS,j)
          newScalars.SetTuple1(j,self.inMask(binaryRep,fWorld))

      segNode=binaryRep['node']
      return self._addVolumeNode(segNode.GetName()+'_'+binaryRep['segId'],
                                 refNode,refIJKtoRAS,newImage)

  def rebinSegmentation(self,segNode,refNode):
      """Rasterise every segment of segNode onto refNode's grid.

      One new volume per segment is added to the scene (see rebinSegment).
      Progress is written to ``rebinSegmentation.log`` in the log directory.
      """
      segmentation=segNode.GetSegmentation()
      with open(os.path.join(self.baseLog,"rebinSegmentation.log"),"w") as log:
          log.write("rebinNode: {} {}\n".format(segNode.GetName(),refNode.GetName()))

          for i in range(segmentation.GetNumberOfSegments()):
              segID=segmentation.GetNthSegmentID(i)
              log.write("Parsing segment {}\n".format(segmentation.GetNthSegment(i).GetName()))
              # rebinSegment reads 'segId' to name the output volume
              binaryRep={'node':segNode,
                         'mask':segNode.GetBinaryLabelmapRepresentation(segID),
                         'segId':segID}
              newNode=self.rebinSegment(refNode,binaryRep)
              log.write("Adding node: {}\n".format(newNode.GetName()))

  def rebinSegmentation1(self,segNode,refNode):
      """Alternative segmentation rasteriser that iterates over mask voxels.

      For every non-zero voxel of each segment's labelmap, its world position
      is mapped into refNode's IJK space. The nearest reference point is then
      set to 1. One new volume per segment, named
      ``<segmentation name>_<segment id>``, is added to the scene.
      """
      print("rebinNode: {} {}\n".format(segNode.GetName(),refNode.GetName()))
      refIJKtoRAS=vtk.vtkMatrix4x4()
      refNode.GetIJKToRASMatrix(refIJKtoRAS)
      refRAStoIJK=vtk.vtkMatrix4x4()
      refNode.GetRASToIJKMatrix(refRAStoIJK)

      segmentation=segNode.GetSegmentation()
      for i in range(segmentation.GetNumberOfSegments()):
          segID=segmentation.GetNthSegmentID(i)
          binaryRep={'node':segNode,
                     'mask':segNode.GetBinaryLabelmapRepresentation(segID),
                     'segId':segID}
          mask=binaryRep['mask']

          # new zero-filled image on the reference grid for this segment
          newImage=vtk.vtkImageData()
          newImage.DeepCopy(refNode.GetImageData())
          newScalars=newImage.GetPointData().GetScalars()
          refN=newImage.GetNumberOfPoints()
          for k in range(refN):
              newScalars.SetTuple1(k,0)

          maskN=mask.GetNumberOfPoints()
          maskScalars=mask.GetPointData().GetScalars()
          maskImageToWorldMatrix=vtk.vtkMatrix4x4()
          mask.GetImageToWorldMatrix(maskImageToWorldMatrix)

          for j in range(maskN):
              if maskScalars.GetTuple1(j)==0:
                  continue

              # point position -> image index (undo origin/spacing) -> world
              maskIJK=mask.GetPoint(j)
              maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
              maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
              maskIJK.append(1)
              maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)[0:3]
              fWorld=[0,0,0]
              # apply segmentation transformation
              segNode.TransformPointToWorld(maskPos,fWorld)

              # apply potential reference transformations
              refPos=[0,0,0]
              refNode.TransformPointFromWorld(fWorld,refPos)
              refPos.append(1)
              refIJK=refRAStoIJK.MultiplyPoint(refPos)[0:3]
              i1=newImage.FindPoint(refIJK)
              # FindPoint returns -1 outside the reference extent
              if 0<=i1<refN:
                  newScalars.SetTuple1(i1,1)

          self._addVolumeNode(segNode.GetName()+'_'+segID,refNode,refIJKtoRAS,newImage)




class resampleTest(ScriptedLoadableModuleTest):
  """
  Test case for the Resample module.

  Loads a reference volume, a moving volume and a transform from disk. It
  then resamples the transformed moving volume onto the reference grid and
  checks the output geometry.
  Uses ScriptedLoadableModuleTest base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """

  def setUp(self):
    """Clear the scene and load the test data.

    The data directory defaults to the original author's location. It can be
    overridden with the RESAMPLE_TEST_DATA environment variable.
    """
    slicer.mrmlScene.Clear(0)
    refNodeName="2SBMIRVolume19"
    nodeName="2SMobrVolume19"
    transformNodeName="2SMobr_DF"
    path=os.environ.get("RESAMPLE_TEST_DATA",
                        "c:\\Users\\studen\\labkeyCache\\dinamic_spect\\Patients\\@files")

    refPath=os.path.join(path,"2SBMIR",refNodeName+".nrrd")
    slicer.util.loadNodeFromFile(refPath,'VolumeFile')

    transformPath=os.path.join(path,"2SMobr",transformNodeName+".h5")
    slicer.util.loadNodeFromFile(transformPath,'TransformFile')

    nodePath=os.path.join(path,"2SMobr",nodeName+".nrrd")
    slicer.util.loadNodeFromFile(nodePath,'VolumeFile')

    self.node=slicer.util.getFirstNodeByName(nodeName)
    self.refNode=slicer.util.getFirstNodeByName(refNodeName)
    self.transformNode=slicer.util.getFirstNodeByName(transformNodeName)
    # fail with a clear message rather than an AttributeError below
    self.assertIsNotNone(self.node, "volume {} not loaded".format(nodeName))
    self.assertIsNotNone(self.refNode, "volume {} not loaded".format(refNodeName))
    self.assertIsNotNone(self.transformNode, "transform {} not loaded".format(transformNodeName))
    self.node.SetAndObserveTransformNodeID(self.transformNode.GetID())

    self.resampler=resampleLogic(None)

  def runTest(self):
    """Run as few or as many tests as needed here.
    """
    self.setUp()
    self.test_resample()

  def test_resample(self):
    """Resample the transformed volume and check it lies on the reference grid.

    The test fails if rebinNode returns nothing or if the output image
    dimensions differ from the reference image dimensions.
    """
    self.delayDisplay("Starting the test")
    newNode=self.resampler.rebinNode(self.node,self.refNode)

    self.assertIsNotNone(newNode)
    self.assertIsNotNone(newNode.GetImageData())
    self.assertEqual(newNode.GetImageData().GetDimensions(),
                     self.refNode.GetImageData().GetDimensions())

    self.delayDisplay('Test passed!')