Kaynağa Gözat

Adding resample to dynamicSPECT collection

Andrej 4 yıl önce
ebeveyn
işleme
e6c6a4a2aa
1 değiştirilmiş dosya ile 442 ekleme ve 0 silme
  1. 442 0
      resample/resample.py

+ 442 - 0
resample/resample.py

@@ -0,0 +1,442 @@
+import slicer
+import vtk
+import os
+from slicer.ScriptedLoadableModule import *
+import ctk
+import qt
+
+class resample(ScriptedLoadableModule):
+  """Uses ScriptedLoadableModule base class, available at:
+  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
+  """
+  def __init__(self, parent):
+    ScriptedLoadableModule.__init__(self, parent)
+    parent.title = "Resample"
+    parent.categories = ["LabKey"]
+    parent.dependencies = []
+    parent.contributors = ["Andrej Studen (FMF/JSI)"] # replace with "Firstname Lastname (Org)"
+    parent.helpText = """
+    Resample to different shapes
+    """
+    parent.acknowledgementText = """
+    This module was developed within the frame of the ARRS sponsored medical
+    physics research programe to investigate quantitative measurements of cardiac
+    function using sestamibi-like tracers
+    """ # replace with organization, grant and thanks.
+    self.parent.helpText += self.getDefaultModuleDocumentationLink()
+    self.parent = parent
+
+#
+# resampleWidget
+#
+class resampleWidget(ScriptedLoadableModuleWidget):
+  """Uses ScriptedLoadableModuleWidget base class, available at:
+  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
+  """
+
+  def setup(self):
+    ScriptedLoadableModuleWidget.setup(self)
+    self.logic=resampleLogic(self)
+
+    datasetCollapsibleButton = ctk.ctkCollapsibleButton()
+    datasetCollapsibleButton.text = "Dataset"
+    self.layout.addWidget(datasetCollapsibleButton)
+    # Layout within the dummy collapsible button
+    datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)
+
+    self.transformNode=qt.QLineEdit("NodeToTransform")
+    datasetFormLayout.addRow("TransformedNode:",self.transformNode)
+    self.referenceNode=qt.QLineEdit("ReferenceNode")
+    datasetFormLayout.addRow("ReferenceNode:",self.referenceNode)
+
+    self.transformButton=qt.QPushButton("Transform")
+    self.transformButton.clicked.connect(self.onTransformButtonClicked)
+    datasetFormLayout.addRow("Volume:",self.transformButton)
+
+    self.transformSegmentationButton=qt.QPushButton("Transform")
+    self.transformSegmentationButton.clicked.connect(self.onTransformSegmentationButtonClicked)
+    datasetFormLayout.addRow("Segmentation:",self.transformSegmentationButton)
+
+
+  def onTransformButtonClicked(self):
+    node=slicer.util.getFirstNodeByName(self.transformNode.text)
+    if node==None:
+        print("Transform node [{}] not found").format(self.transformNode.text)
+        return
+    refNode=slicer.util.getFirstNodeByName(self.referenceNode.text)
+    if refNode==None:
+        print("Reference node [{}] not found").format(self.referenceNode.text)
+        return
+
+    self.logic.rebinNode(node,refNode)
+
+  def onTransformSegmentationButtonClicked(self):
+
+      segNodes=slicer.util.getNodesByClass("vtkMRMLSegmentationNode")
+      segNode=None
+      for s in segNodes:
+          print ("SegmentationNode: {}").format(s.GetName())
+          if s.GetName()==self.transformNode.text:
+              segNode=s
+              break
+
+      if segNode==None:
+          print("Segmentation node [{}] not found").format(self.transformNode.text)
+          return
+
+      refNode=slicer.util.getFirstNodeByName(self.referenceNode.text)
+      if refNode==None:
+          print("Reference node [{}] not found").format(self.referenceNode.text)
+          return
+
+      self.logic.rebinSegmentation(segNode,refNode)
+
+
+
+class resampleLogic(ScriptedLoadableModuleLogic):
+  """This class should implement all the actual
+  computation done by your module.  The interface
+  should be such that other python code can import
+  this class and make use of the functionality without
+  requiring an instance of the Widget.
+  Uses ScriptedLoadableModuleLogic base class, available at:
+  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
+  """
+  def __init__(self,parent):
+      ScriptedLoadableModuleLogic.__init__(self, parent)
+      try:
+          fhome=os.environ["HOME"]
+      except:
+             #in windows, the variable is called HOMEPATH
+          fhome=os.environ['HOMEDRIVE']+os.environ['HOMEPATH']
+      self.baseLog=os.path.join(fhome,".resample")
+      if not os.path.isdir(self.baseLog):
+          os.mkdir(self.baseLog)
+
+  def printMe(self):
+      print("resampleLogic ready")
+      print("Log: {}").format(self.baseLog)
+
+
+  def cast(self,newImage,originalImage):
+      if newImage.GetPointData().GetScalars().GetDataType()==originalImage.GetPointData().GetScalars().GetDataType():
+          return newImage
+      outputType=originalImage.GetPointData().GetScalars().__class__.__name__
+      shifter=vtk.vtkImageShiftScale()
+      shifter.SetInputData(newImage)
+      if outputType=="vtkUnsignedShortArray":
+          shifter.SetOutputScalarTypeToUnsignedShort()
+      if outputType=="vtkShortArray":
+          shifter.SetOutputScalarTypeToShort()
+      shifter.SetScale(1)
+      shifter.SetShift(0)
+      shifter.Update()
+      return shifter.GetOutput()
+
+
+
+  def rebinNode(self,node,refNode):
+      #refNodeName="2SBMIRVolume19"
+      #nodeName="2SMobrVolume19"
+      #node=slicer.util.getFirstNodeByName(nodeName)
+      #refNode=slicer.util.getFirstNodeByName(refNodeName)
+      #transformNodeName="2SMobr_DF"
+      #transformNode=slicer.util.getFirstNodeByName(transformNodeName)
+      #node.SetAndObserveTransformNodeID(transformNode.GetID())
+
+      log=open(os.path.join(self.baseLog,"rebinNode.log"),"w")
+
+
+      log.write(("rebinNode: volume: {} ref: {}\n").format(node.GetName(),refNode.GetName()))
+      refImage=refNode.GetImageData()
+      n=refImage.GetNumberOfPoints()
+      refIJKtoRAS=vtk.vtkMatrix4x4()
+      refNode.GetIJKToRASMatrix(refIJKtoRAS)
+
+      nodeRAStoIJK=vtk.vtkMatrix4x4()
+      node.GetRASToIJKMatrix(nodeRAStoIJK)
+      nodeName=node.GetName()
+
+      coeff=vtk.vtkImageBSplineCoefficients()
+      coeff.SetInputData(node.GetImageData())
+      coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
+      #between 3 and 5
+      coeff.SetSplineDegree(5)
+      coeff.Update()
+      #interpolation ready to use
+
+      #this is tough. COpy only links (ie. shallow copy)
+      newImage=vtk.vtkImageData()
+      newImage.DeepCopy(refNode.GetImageData())
+      newImage=self.cast(newImage,node.GetImageData())
+      newScalars=newImage.GetPointData().GetScalars()
+      #doesn't set the scalars
+      log.write(("Iterating: {} points\n").format(n))
+
+
+      for i in range(0,n):
+            refIJK=refImage.GetPoint(i)
+            refIJK=list(refIJK)
+            refIJK.append(1)
+            #shift to world coordinates to match global points
+            refPos=refIJKtoRAS.MultiplyPoint(refIJK)
+            refPos=refPos[0:3]
+            fWorld=[0,0,0]
+            #apply potential transformations
+            refNode.TransformPointToWorld(refPos,fWorld)
+
+            #now do the opposite on the target node; reuse fPos
+            nodePos=[0,0,0]
+            node.TransformPointFromWorld(fWorld,nodePos)
+            nodePos.append(1)
+            nodeIJK=nodeRAStoIJK.MultiplyPoint(nodePos)
+
+            #here we should apply some sort of interpolation
+            nodeIJK=nodeIJK[0:3]
+            v=coeff.Evaluate(nodeIJK)
+            v0=newScalars.GetTuple1(i)
+            newScalars.SetTuple1(i,v)
+            if i%10000==0:
+                log.write(("[{}]: {}->{}\n").format(i,v0,v))
+
+      #node.SetName(nodeName+"_BU")
+
+
+      newNode=slicer.vtkMRMLScalarVolumeNode()
+      newNode.SetName(nodeName+"_RS")
+      newNode.SetOrigin(refNode.GetOrigin())
+      newNode.SetSpacing(refNode.GetSpacing())
+      newNode.SetIJKToRASMatrix(refIJKtoRAS)
+
+      newNode.SetAndObserveImageData(newImage)
+      slicer.mrmlScene.AddNode(newNode)
+      log.write(("Adding node {}\n").format(newNode.GetName()))
+      log.close()
+      return newNode
+
+
+  def inMask(self,binaryRep,fpos):
+
+       local=[0,0,0]
+
+       segNode=binaryRep['node']
+       segNode.TransformPointFromWorld(fpos,local)
+
+       mask=binaryRep['mask']
+       maskWorldToImageMatrix=vtk.vtkMatrix4x4()
+       mask.GetWorldToImageMatrix(maskWorldToImageMatrix)
+       local.append(1)
+       maskIJK=maskWorldToImageMatrix.MultiplyPoint(local)
+       #mask IJK is in image coordinates. However, binaryLabelMap is a truncated
+       #version of vtkImageData, so more work is required
+       maskIJK=maskIJK[0:3]#skip last (dummy) coordinate
+       maskIJK=[r*s for r,s in zip(maskIJK,mask.GetSpacing())]
+       maskIJK=[r+c for r,c in zip(maskIJK,mask.GetOrigin())]
+
+       maskI=mask.FindPoint(maskIJK)
+       try:
+           return mask.GetPointData().GetScalars().GetTuple1(maskI)
+       except:
+           return 0
+
+  def rebinSegment(self,refNode,binaryRep):
+
+       refIJKtoRAS=vtk.vtkMatrix4x4()
+       refNode.GetIJKToRASMatrix(refIJKtoRAS)
+       refImage=refNode.GetImageData()
+
+       #create new node for each segment
+       newImage=vtk.vtkImageData()
+       newImage.DeepCopy(refNode.GetImageData())
+       n=newImage.GetNumberOfPoints()
+       newScalars=newImage.GetPointData().GetScalars()
+
+       segNode=binaryRep['node']
+       mask=binaryRep['mask']
+
+       for j in range(0,n):
+           refIJK=refImage.GetPoint(j)
+           refIJK=list(refIJK)
+           refIJK.append(1)
+
+           #shift to world coordinates to match global points
+           refPos=refIJKtoRAS.MultiplyPoint(refIJK)
+           refPos=refPos[0:3]
+           fWorld=[0,0,0]
+
+           #apply potential transformations
+           refNode.TransformPointToWorld(refPos,fWorld)
+
+           v=self.inMask(binaryRep,fWorld)
+           #print("[{}]  Setting ({}) to: {}\n").format(j,fWorld,v)
+           newScalars.SetTuple1(j,v)
+
+       newNode=slicer.vtkMRMLScalarVolumeNode()
+       newNode.SetName(segNode.GetName()+'_'+binaryRep['segId'])
+       newNode.SetOrigin(refNode.GetOrigin())
+       newNode.SetSpacing(refNode.GetSpacing())
+       newNode.SetIJKToRASMatrix(refIJKtoRAS)
+
+       newNode.SetAndObserveImageData(newImage)
+       slicer.mrmlScene.AddNode(newNode)
+       return newNode
+
+
+  def rebinSegmentation(self,segNode,refNode):
+
+       log=open(os.path.join(self.baseLog,"rebinSegmentation.log"),"w")
+
+       log.write(("rebinNode: {} {}\n").format(segNode.GetName(),refNode.GetName()))
+
+       nSeg=segNode.GetSegmentation().GetNumberOfSegments()
+       ## DEBUG:
+       #nSeg=1
+       #n=1000
+
+       for i in range(0,nSeg):
+           #segID
+           segID=segNode.GetSegmentation().GetNthSegmentID(i)
+           log.write(("Parsing segment {}").format(segNode.GetSegmentation.GetNthSegment(i).GetName()))
+           binaryRep={'node':segNode,
+                      'mask':segNode.GetBinaryLabelmapRepresentation(segID)}
+           newNode=self.rebinSegment(refNode,binaryRep)
+           log.write(("Adding node: {}").format(newNode.GetName()))
+
+       log.close()
+
+  def rebinSegmentation1(self,segNode,refNode):
+
+
+     logfile="C:\\Users\\studen\\labkeyCache\\log\\resample.log"
+     print("rebinNode: {} {}\n").format(segNode.GetName(),refNode.GetName())
+     refImage=refNode.GetImageData()
+     refIJKtoRAS=vtk.vtkMatrix4x4()
+     refNode.GetIJKToRASMatrix(refIJKtoRAS)
+     refRAStoIJK=vtk.vtkMatrix4x4()
+     refNode.GetRASToIJKMatrix(refRAStoIJK)
+
+     nSeg=segNode.GetSegmentation().GetNumberOfSegments()
+     ## DEBUG:
+     nSeg=1
+
+     for i in range(0,nSeg):
+         #segID
+         segID=segNode.GetSegmentation().GetNthSegmentID(i)
+         binaryRep={'node':segNode,
+                    'mask': segNode.GetBinaryLabelmapRepresentation(segID)}
+
+         mask=binaryRep['mask']
+         #create new node for each segment
+         newImage=vtk.vtkImageData()
+         newImage.DeepCopy(refNode.GetImageData())
+         newScalars=newImage.GetPointData().GetScalars()
+         refN=newImage.GetNumberOfPoints()
+         for k in range(0,refN):
+             newScalars.SetTuple1(k,0)
+
+         maskN=binaryRep['mask'].GetNumberOfPoints()
+         maskScalars=mask.GetPointData().GetScalars()
+         maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+         binaryRep['mask'].GetImageToWorldMatrix(maskImageToWorldMatrix)
+
+         for j in range(0,maskN):
+
+             if maskScalars.GetTuple1(j)==0:
+                 continue
+
+             maskIJK=mask.GetPoint(j)
+             maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
+             maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
+             maskIJK.append(1)
+             maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)
+             maskPos=maskPos[0:3]
+             fWorld=[0,0,0]
+             #apply segmentation transformation
+             segNode.TransformPointToWorld(maskPos,fWorld)
+
+             refPos=[0,0,0]
+
+             #apply potential reference transformations
+             refNode.TransformPointFromWorld(fWorld,refPos)
+             refPos.append(1)
+             refIJK=refRAStoIJK.MultiplyPoint(refPos)
+             refIJK=refIJK[0:3]
+             i1=newImage.FindPoint(refIJK)
+             if i1<0:
+                 continue
+             if i1<refN:
+                 newScalars.SetTuple1(i1,1)
+
+         newNode=slicer.vtkMRMLScalarVolumeNode()
+         newNode.SetName(segNode.GetName()+'_'+segNode.GetSegmentation().GetNthSegmentID(i))
+         newNode.SetOrigin(refNode.GetOrigin())
+         newNode.SetSpacing(refNode.GetSpacing())
+         newNode.SetIJKToRASMatrix(refIJKtoRAS)
+
+         newNode.SetAndObserveImageData(newImage)
+         slicer.mrmlScene.AddNode(newNode)
+
+
+
+
+class resampleTest(ScriptedLoadableModuleTest):
+  """
+  This is the test case for your scripted module.
+  Uses ScriptedLoadableModuleTest base class, available at:
+  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
+  """
+
+  def setUp(self):
+    """ Do whatever is needed to reset the state - typically a scene clear will be enough.
+    """
+    slicer.mrmlScene.Clear(0)
+    refNodeName="2SBMIRVolume19"
+    nodeName="2SMobrVolume19"
+    transformNodeName="2SMobr_DF"
+    path="c:\\Users\\studen\\labkeyCache\\dinamic_spect\\Patients\\@files"
+
+    refPath=os.path.join(path,"2SBMIR")
+    refPath=os.path.join(refPath,refNodeName+".nrrd")
+    slicer.util.loadNodeFromFile(refPath,'VolumeFile')
+
+    transformPath=os.path.join(path,"2SMobr")
+    transformPath=os.path.join(transformPath,transformNodeName+".h5")
+    slicer.util.loadNodeFromFile(transformPath,'TransformFile')
+
+    nodePath=os.path.join(path,"2SMobr")
+    nodePath=os.path.join(nodePath,nodeName+".nrrd")
+
+    slicer.util.loadNodeFromFile(nodePath,'VolumeFile')
+
+    self.node=slicer.util.getFirstNodeByName(nodeName)
+    self.refNode=slicer.util.getFirstNodeByName(refNodeName)
+    self.transformNode=slicer.util.getFirstNodeByName(transformNodeName)
+    self.node.SetAndObserveTransformNodeID(self.transformNode.GetID())
+
+    self.resampler=resampleLogic(None)
+
+  def runTest(self):
+    """Run as few or as many tests as needed here.
+    """
+    self.setUp()
+    self.test_resample()
+
+  def test_resample(self):
+    """ Ideally you should have several levels of tests.  At the lowest level
+    tests should exercise the functionality of the logic with different inputs
+    (both valid and invalid).  At higher levels your tests should emulate the
+    way the user would interact with your code and confirm that it still works
+    the way you intended.
+    One of the most important features of the tests is that it should alert other
+    developers when their changes will have an impact on the behavior of your
+    module.  For example, if a developer removes a feature that you depend on,
+    your test should break so they know that the feature is needed.
+    """
+
+    self.delayDisplay("Starting the test")
+    #
+    # first, get some data
+    #
+    self.resampler.rebinNode(self.node,self.refNode)
+
+    self.delayDisplay('Test passed!')