Bladeren bron

Speeding up the integral calculation, renaming of the functions for a better clarity

Andrej Studen 5 jaren geleden
bovenliggende
commit
ce66894338
2 gewijzigde bestanden met toevoegingen van 73 en 35 verwijderingen
  1. 72 34
      cardiacSPECT/cardiacSPECT.py
  2. 1 1
      cardiacSPECT/resample.py

+ 72 - 34
cardiacSPECT/cardiacSPECT.py

@@ -104,12 +104,12 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
 
     patientLoadButton = qt.QPushButton("Load")
     patientLoadButton.toolTip="Load data from DICOM"
-    dataFormLayout.addRow("Load Patient",patientLoadButton)
+    dataFormLayout.addRow("Patient",patientLoadButton)
     patientLoadButton.clicked.connect(self.onPatientLoadButtonClicked)
 
     loadSegmentationButton = qt.QPushButton("Load")
     loadSegmentationButton.toolTip="Load segmentation from server"
-    dataFormLayout.addRow("Load Segmentation",loadSegmentationButton)
+    dataFormLayout.addRow("Segmentation",loadSegmentationButton)
     loadSegmentationButton.clicked.connect(self.onLoadSegmentationButtonClicked)
 
     saveVolumeButton = qt.QPushButton("Save")
@@ -127,7 +127,12 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
     dataFormLayout.addRow("Transformation",saveTransformationButton)
     saveTransformationButton.clicked.connect(self.onSaveTransformationButtonClicked)
 
-    transformNodeButton = qt.QPushButton("Transform Node")
+    saveInputFunctionButton = qt.QPushButton("Save")
+    saveInputFunctionButton.toolTip="Save InputFunction to NRRD"
+    dataFormLayout.addRow("InputFunction",saveInputFunctionButton)
+    saveInputFunctionButton.clicked.connect(self.onSaveInputFunctionButtonClicked)
+
+    transformNodeButton = qt.QPushButton("Transform Nodes")
     transformNodeButton.toolTip="Transform node with patient based transform"
     dataFormLayout.addRow("Transform Nodes",transformNodeButton)
     transformNodeButton.clicked.connect(self.onTransformNodeButtonClicked)
@@ -320,7 +325,7 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
             dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
 
             dn.SetSize(n)
-            dn.SetName(self.logic.getSegmentName(j))
+            dn.SetName(self.patientId.text+'_'+self.logic.getSegmentName(j))
 
             dt=0;
             t0=0;
@@ -371,6 +376,9 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
   def onSaveTransformationButtonClicked(self):
       self.logic.storeTransformation(self.patientId.text)
 
+  def onSaveInputFunctionButtonClicked(self):
+      self.logic.storeInputFunction(self.patientId.text)
+
   def onTransformNodeButtonClicked(self):
       self.logic.applyTransform(self.patientId.text, self.refPatientId.text,
           self.time_frame_select.minimum,self.time_frame_select.maximum)
@@ -560,6 +568,24 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
 
 
 
+
+  def getMaskPos(self,mask,i):
+      maskIJK=mask.GetPoint(i)
+      maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
+      maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
+
+      #this is now in extent spacing, whitch ImageToWorldMatrix understands
+
+      #to 4D vector for vtkMatrix4x4 handling
+      maskIJK.append(1)
+
+      #go to ras, global coordinates (in mm)
+      maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+      mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
+
+      maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)
+      return maskPos[0:3]
+
   def meanROI(self, volName1, i):
     s=0
 
@@ -581,7 +607,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
         print("Segment {} not found".format(segment))
         return s
 
-    print "Got mask for segment {}".format(segment)
+    print "Got mask for segment {}, npts {}".format(segment,mask.GetNumberOfPoints())
     #get mask at (x,y,z)
     #mask.GetPointData().GetScalars().GetTuple1(mask.FindPoint([x,y,z]))
 
@@ -589,42 +615,48 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
     dataNode=slicer.mrmlScene.GetFirstNodeByName(volName1)
     dataImage=dataNode.GetImageData()
     # use IJK2RAS to get global coordinates
-    ijkToRas = vtk.vtkMatrix4x4()
-    dataNode.GetIJKToRASMatrix(ijkToRas)
-
-    #iterate over volume pixelData
-    n=dataImage.GetNumberOfPoints()
-
-    extent=mask.GetExtent()
-    fM=vtk.vtkMatrix4x4()
-    mask.GetWorldToImageMatrix(fM)
+    dataRAStoIJK = vtk.vtkMatrix4x4()
+    dataNode.GetRASToIJKMatrix(dataRAStoIJK)
+
+    #allow for interpolation in segmentation pixels
+    coeff=vtk.vtkImageBSplineCoefficients()
+    coeff.SetInputData(dataImage)
+    coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
+    #between 3 and 5
+    coeff.SetSplineDegree(5)
+    coeff.Update()
+
+    maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+    mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
     ns=0
-    for i in range(0,n):
 
-      #get global coordinates of point i
-      [ix,iy,iz]=dataImage.GetPoint(i)
-      position_ijk=[ix, iy, iz, 1]
-      #ras are global coordinates (in mm)
-      position_ras=ijkToRas.MultiplyPoint(position_ijk)
-      fpos=[int(np.round(x)) for x in fM.MultiplyPoint(position_ras)]
 
-      outOfRange=False
-      for k in range(0,3):
-          if fpos[k]<extent[2*k] or fpos[k]>extent[2*k+1]:
-              outOfRange=True
-              break;
+    maskN=mask.GetNumberOfPoints()
+    maskScalars=mask.GetPointData().GetScalars()
+    maskOrigin=[0,0,0]
+    maskOrigin=mask.GetOrigin()
 
-      if outOfRange:
+    for i in range(0,maskN):
+      #skip all points that are 0
+      if maskScalars.GetTuple1(i)==0:
           continue
 
-      #find point in mask with the same global coordinates
-      maskValue=mask.GetPointData().GetScalars().GetTuple1(mask.ComputePointId(fpos[0:3]))
+      #get global coordinates of point i
+      maskPos=self.getMaskPos(mask,i)
 
-      if maskValue == 0:
-          continue
+      #print("Evaluating at {}").format(maskPos)
+      #convert from global to local
+      dataPos=[0,0,0]
+      #account for potentially applied transform on dataNode
+      dataNode.TransformPointFromWorld(maskPos,dataPos)
+      dataPos.append(1)
+      dataIJK=dataRAStoIJK.MultiplyPoint(dataPos)
+
+      #drop the 4th index
+      dataIJK=dataIJK[0:3]
 
-      #use maskValue to project ROI data
-      s+=maskValue*dataImage.GetPointData().GetScalars().GetTuple1(i)
+      #interpolate
+      s+=coeff.Evaluate(dataIJK)
       ns+=1
 
     return s/ns
@@ -666,7 +698,8 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
       suffix=".nrrd"
       if node.__class__.__name__=="vtkMRMLDoubleArrayNode":
           suffix=".mcsv"
-      if node.__class__.__name__=="vtkMRMLTransformNode":
+      if (node.__class__.__name__=="vtkMRMLTransformNode" or
+        node.__class__.__name__=="vtkMRMLGridTransformNode"):
           suffix=".h5"
       fileName=nodeName+suffix
       file=os.path.join(localPath,fileName)
@@ -701,10 +734,14 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
       self.storeNodeRemote(relativePath,segNodeName)
 
   def storeInputFunction(self,patientId):
+      project="dinamic_spect/Patients"
+      relativePath=project+'/@files/'+patientId
       doubleArrayNodeName=patientId+'_Ventricle'
       self.storeNodeRemote(relativePath,doubleArrayNodeName)
 
   def storeTransformation(self,patientId):
+      project="dinamic_spect/Patients"
+      relativePath=project+'/@files/'+patientId
       transformNodeName=patientId+"_DF"
       self.storeNodeRemote(relativePath,transformNodeName)
 
@@ -736,6 +773,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
           refNode=slicer.util.getFirstNodeByName(refNodeName)
           if refNode!=None:
               self.resampler.rebinNode(node,refNode)
+          print("Completed transformation {}").format(it)
 
       nodeName=patientId+'CT'
       node=slicer.util.getFirstNodeByName(nodeName)

+ 1 - 1
cardiacSPECT/resample.py

@@ -108,7 +108,7 @@ class resampleLogic(ScriptedLoadableModuleLogic):
             v=coeff.Evaluate(nodeIJK)
             v0=newScalars.GetTuple1(i)
             newScalars.SetTuple1(i,v)
-            print("[{}]: {}->{}").format(i,v0,v)
+            #print("[{}]: {}->{}").format(i,v0,v)
 
       node.SetName(nodeName+"_BU")