Andrej Studen 5 лет назад
Родитель
Сommit
94765b79c1
2 измененных файлов с 258 добавлено и 3 удалено
  1. 1 1
      cardiacSPECT/cardiacSPECT.py
  2. 257 2
      cardiacSPECT/resample.py

+ 1 - 1
cardiacSPECT/cardiacSPECT.py

@@ -672,7 +672,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
       file=os.path.join(localPath,fileName)
       slicer.util.saveNode(node,file)
       print("Stored to: {}").format(file)
-      f=open(file)
+      f=open(file,"rb")
       remoteFile=labkeyPath+'/'+fileName
       self.pd.net.put(remoteFile,f.read())
 

+ 257 - 2
cardiacSPECT/resample.py

@@ -2,6 +2,8 @@ import slicer
 import vtk
 import os
 from slicer.ScriptedLoadableModule import *
+import ctk
+import qt
 
 class resample(ScriptedLoadableModule):
   """Uses ScriptedLoadableModule base class, available at:
@@ -36,6 +38,60 @@ class resampleWidget(ScriptedLoadableModuleWidget):
     ScriptedLoadableModuleWidget.setup(self)
     self.logic=resampleLogic(self)
 
+    datasetCollapsibleButton = ctk.ctkCollapsibleButton()
+    datasetCollapsibleButton.text = "Dataset"
+    self.layout.addWidget(datasetCollapsibleButton)
+    # Layout within the dummy collapsible button
+    datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)
+
+    self.transformNode=qt.QLineEdit("NodeToTransform")
+    datasetFormLayout.addRow("TransformedNode:",self.transformNode)
+    self.referenceNode=qt.QLineEdit("ReferenceNode")
+    datasetFormLayout.addRow("ReferenceNode:",self.referenceNode)
+
+    self.transformButton=qt.QPushButton("Transform")
+    self.transformButton.clicked.connect(self.onTransformButtonClicked)
+    datasetFormLayout.addRow("Volume:",self.transformButton)
+
+    self.transformSegmentationButton=qt.QPushButton("Transform")
+    self.transformSegmentationButton.clicked.connect(self.onTransformSegmentationButtonClicked)
+    datasetFormLayout.addRow("Segmentation:",self.transformSegmentationButton)
+
+
+  def onTransformButtonClicked(self):
+    node=slicer.util.getFirstNodeByName(self.transformNode.text)
+    if node==None:
+        print("Transform node [{}] not found").format(self.transformNode.text)
+        return
+    refNode=slicer.util.getFirstNodeByName(self.referenceNode.text)
+    if refNode==None:
+        print("Reference node [{}] not found").format(self.referenceNode.text)
+        return
+
+    self.logic.rebinNode(node,refNode)
+
+  def onTransformSegmentationButtonClicked(self):
+
+      segNodes=slicer.util.getNodesByClass("vtkMRMLSegmentationNode")
+      segNode=None
+      for s in segNodes:
+          print ("SegmentationNode: {}").format(s.GetName())
+          if s.GetName()==self.transformNode.text:
+              segNode=s
+              break
+
+      if segNode==None:
+          print("Segmentation node [{}] not found").format(self.transformNode.text)
+          return
+
+      refNode=slicer.util.getFirstNodeByName(self.referenceNode.text)
+      if refNode==None:
+          print("Reference node [{}] not found").format(self.referenceNode.text)
+          return
+
+      self.logic.rebinSegmentation(segNode,refNode)
+
+
 
 class resampleLogic(ScriptedLoadableModuleLogic):
   """This class should implement all the actual
@@ -48,9 +104,36 @@ class resampleLogic(ScriptedLoadableModuleLogic):
   """
   def __init__(self,parent):
       ScriptedLoadableModuleLogic.__init__(self, parent)
+      try:
+          fhome=os.environ["HOME"]
+      except:
+             #in windows, the variable is called HOMEPATH
+          fhome=os.environ['HOMEDRIVE']+os.environ['HOMEPATH']
+      self.baseLog=os.path.join(fhome,".resample")
+      if not os.path.isdir(self.baseLog):
+          os.mkdir(self.baseLog)
 
   def printMe(self):
       print("resampleLogic ready")
+      print("Log: {}").format(self.baseLog)
+
+
+  def cast(self,newImage,originalImage):
+      if newImage.GetPointData().GetScalars().GetDataType()==originalImage.GetPointData().GetScalars().GetDataType():
+          return newImage
+      outputType=originalImage.GetPointData().GetScalars().__class__.__name__
+      shifter=vtk.vtkImageShiftScale()
+      shifter.SetInputData(newImage)
+      if outputType=="vtkUnsignedShortArray":
+          shifter.SetOutputScalarTypeToUnsignedShort()
+      if outputType=="vtkShortArray":
+          shifter.SetOutputScalarTypeToShort()
+      shifter.SetScale(1)
+      shifter.SetShift(0)
+      shifter.Update()
+      return shifter.GetOutput()
+
+
 
   def rebinNode(self,node,refNode):
       #refNodeName="2SBMIRVolume19"
@@ -61,7 +144,10 @@ class resampleLogic(ScriptedLoadableModuleLogic):
       #transformNode=slicer.util.getFirstNodeByName(transformNodeName)
       #node.SetAndObserveTransformNodeID(transformNode.GetID())
 
-      print("rebinNode: {} {}").format(node,refNode)
+      log=open(os.path.join(self.baseLog,"rebinNode.log"),"w")
+
+
+      log.write(("rebinNode: volume: {} ref: {}\n").format(node.GetName(),refNode.GetName()))
       refImage=refNode.GetImageData()
       n=refImage.GetNumberOfPoints()
       refIJKtoRAS=vtk.vtkMatrix4x4()
@@ -82,8 +168,10 @@ class resampleLogic(ScriptedLoadableModuleLogic):
       #this is tought. COpy only links (ie. shallow copy)
       newImage=vtk.vtkImageData()
       newImage.DeepCopy(refNode.GetImageData())
+      newImage=self.cast(newImage,node.GetImageData())
       newScalars=newImage.GetPointData().GetScalars()
       #doesn't set the scalars
+      log.write(("Iterating: {} points\n").format(n))
 
 
       for i in range(0,n):
@@ -108,7 +196,8 @@ class resampleLogic(ScriptedLoadableModuleLogic):
             v=coeff.Evaluate(nodeIJK)
             v0=newScalars.GetTuple1(i)
             newScalars.SetTuple1(i,v)
-            print("[{}]: {}->{}").format(i,v0,v)
+            if i%10000==0:
+                log.write(("[{}]: {}->{}\n").format(i,v0,v))
 
       node.SetName(nodeName+"_BU")
 
@@ -121,6 +210,172 @@ class resampleLogic(ScriptedLoadableModuleLogic):
 
       newNode.SetAndObserveImageData(newImage)
       slicer.mrmlScene.AddNode(newNode)
+      log.write(("Adding node {}\n").format(newNode.GetName()))
+      log.close()
+      return newNode
+
+
+  def inMask(self,binaryRep,fpos):
+
+       local=[0,0,0]
+
+       segNode=binaryRep['node']
+       segNode.TransformPointFromWorld(fpos,local)
+
+       mask=binaryRep['mask']
+       maskWorldToImageMatrix=vtk.vtkMatrix4x4()
+       mask.GetWorldToImageMatrix(maskWorldToImageMatrix)
+       local.append(1)
+       maskIJK=maskWorldToImageMatrix.MultiplyPoint(local)
+       #mask IJK is in image coordinates. However, binaryLabelMap is a truncated
+       #version of vtkImageData, so more work is required
+       maskIJK=maskIJK[0:3]#skip last (dummy) coordinate
+       maskIJK=[r*s for r,s in zip(maskIJK,mask.GetSpacing())]
+       maskIJK=[r+c for r,c in zip(maskIJK,mask.GetOrigin())]
+
+       maskI=mask.FindPoint(maskIJK)
+       try:
+           return mask.GetPointData().GetScalars().GetTuple1(maskI)
+       except:
+           return 0
+
+  def rebinSegment(self,refNode,binaryRep):
+
+       refIJKtoRAS=vtk.vtkMatrix4x4()
+       refNode.GetIJKToRASMatrix(refIJKtoRAS)
+       refImage=refNode.GetImageData()
+
+       #create new node for each segment
+       newImage=vtk.vtkImageData()
+       newImage.DeepCopy(refNode.GetImageData())
+       n=newImage.GetNumberOfPoints()
+       newScalars=newImage.GetPointData().GetScalars()
+
+       segNode=binaryRep['node']
+       mask=binaryRep['mask']
+
+       for j in range(0,n):
+           refIJK=refImage.GetPoint(j)
+           refIJK=list(refIJK)
+           refIJK.append(1)
+
+           #shift to world coordinates to match global points
+           refPos=refIJKtoRAS.MultiplyPoint(refIJK)
+           refPos=refPos[0:3]
+           fWorld=[0,0,0]
+
+           #apply potential transformations
+           refNode.TransformPointToWorld(refPos,fWorld)
+
+           v=self.inMask(binaryRep,fWorld)
+           #print("[{}]  Setting ({}) to: {}\n").format(j,fWorld,v)
+           newScalars.SetTuple1(j,v)
+
+       newNode=slicer.vtkMRMLScalarVolumeNode()
+       newNode.SetName(segNode.GetName()+'_'+binaryRep['segId'])
+       newNode.SetOrigin(refNode.GetOrigin())
+       newNode.SetSpacing(refNode.GetSpacing())
+       newNode.SetIJKToRASMatrix(refIJKtoRAS)
+
+       newNode.SetAndObserveImageData(newImage)
+       slicer.mrmlScene.AddNode(newNode)
+       return newNode
+
+
+  def rebinSegmentation(self,segNode,refNode):
+
+       log=open(os.path.join(self.baseLog,"rebinSegmentation.log"),"w")
+
+       log.write(("rebinNode: {} {}\n").format(segNode.GetName(),refNode.GetName()))
+
+       nSeg=segNode.GetSegmentation().GetNumberOfSegments()
+       ## DEBUG:
+       #nSeg=1
+       #n=1000
+
+       for i in range(0,nSeg):
+           #segID
+           segID=segNode.GetSegmentation().GetNthSegmentID(i)
+           log.write(("Parsing segment {}").format(segNode.GetSegmentation.GetNthSegment(i).GetName()))
+           binaryRep={'node':segNode,
+                      'mask':segNode.GetBinaryLabelmapRepresentation(segID)}
+           newNode=self.rebinSegment(refNode,binaryRep)
+           log.write(("Adding node: {}").format(newNode.GetName()))
+
+       log.close()
+
+  def rebinSegmentation1(self,segNode,refNode):
+
+
+     logfile="C:\\Users\\studen\\labkeyCache\\log\\resample.log"
+     print("rebinNode: {} {}\n").format(segNode.GetName(),refNode.GetName())
+     refImage=refNode.GetImageData()
+     refIJKtoRAS=vtk.vtkMatrix4x4()
+     refNode.GetIJKToRASMatrix(refIJKtoRAS)
+     refRAStoIJK=vtk.vtkMatrix4x4()
+     refNode.GetRASToIJKMatrix(refRAStoIJK)
+
+     nSeg=segNode.GetSegmentation().GetNumberOfSegments()
+     ## DEBUG:
+     nSeg=1
+
+     for i in range(0,nSeg):
+         #segID
+         segID=segNode.GetSegmentation().GetNthSegmentID(i)
+         binaryRep={'node':segNode,
+                    'mask': segNode.GetBinaryLabelmapRepresentation(segID)}
+
+         mask=binaryRep['mask']
+         #create new node for each segment
+         newImage=vtk.vtkImageData()
+         newImage.DeepCopy(refNode.GetImageData())
+         newScalars=newImage.GetPointData().GetScalars()
+         refN=newImage.GetNumberOfPoints()
+         for k in range(0,refN):
+             newScalars.SetTuple1(k,0)
+
+         maskN=binaryRep['mask'].GetNumberOfPoints()
+         maskScalars=mask.GetPointData().GetScalars()
+         maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+         binaryRep['mask'].GetImageToWorldMatrix(maskImageToWorldMatrix)
+
+         for j in range(0,maskN):
+
+             if maskScalars.GetTuple1(j)==0:
+                 continue
+
+             maskIJK=mask.GetPoint(j)
+             maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
+             maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
+             maskIJK.append(1)
+             maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)
+             maskPos=maskPos[0:3]
+             fWorld=[0,0,0]
+             #apply segmentation transformation
+             segNode.TransformPointToWorld(maskPos,fWorld)
+
+             refPos=[0,0,0]
+
+             #apply potential reference transformations
+             refNode.TransformPointFromWorld(fWorld,refPos)
+             refPos.append(1)
+             refIJK=refRAStoIJK.MultiplyPoint(refPos)
+             refIJK=refIJK[0:3]
+             i1=newImage.FindPoint(refIJK)
+             if i1<0:
+                 continue
+             if i1<refN:
+                 newScalars.SetTuple1(i1,1)
+
+         newNode=slicer.vtkMRMLScalarVolumeNode()
+         newNode.SetName(segNode.GetName()+'_'+segNode.GetSegmentation().GetNthSegmentID(i))
+         newNode.SetOrigin(refNode.GetOrigin())
+         newNode.SetSpacing(refNode.GetSpacing())
+         newNode.SetIJKToRASMatrix(refIJKtoRAS)
+
+         newNode.SetAndObserveImageData(newImage)
+         slicer.mrmlScene.AddNode(newNode)
+