Andrej Studen 5 роки тому
батько
коміт
9816c99107
1 змінених файлів з 86 додано та 40 видалено
  1. 86 40
      cardiacSPECT/cardiacSPECT.py

+ 86 - 40
cardiacSPECT/cardiacSPECT.py

@@ -104,12 +104,12 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
 
     patientLoadButton = qt.QPushButton("Load")
     patientLoadButton.toolTip="Load data from DICOM"
-    dataFormLayout.addRow("Load Patient",patientLoadButton)
+    dataFormLayout.addRow("Patient",patientLoadButton)
     patientLoadButton.clicked.connect(self.onPatientLoadButtonClicked)
 
     loadSegmentationButton = qt.QPushButton("Load")
     loadSegmentationButton.toolTip="Load segmentation from server"
-    dataFormLayout.addRow("Load Segmentation",loadSegmentationButton)
+    dataFormLayout.addRow("Segmentation",loadSegmentationButton)
     loadSegmentationButton.clicked.connect(self.onLoadSegmentationButtonClicked)
 
     saveVolumeButton = qt.QPushButton("Save")
@@ -127,7 +127,12 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
     dataFormLayout.addRow("Transformation",saveTransformationButton)
     saveTransformationButton.clicked.connect(self.onSaveTransformationButtonClicked)
 
-    transformNodeButton = qt.QPushButton("Transform Node")
+    saveInputFunctionButton = qt.QPushButton("Save")
+    saveInputFunctionButton.toolTip="Save InputFunction to NRRD"
+    dataFormLayout.addRow("InputFunction",saveInputFunctionButton)
+    saveInputFunctionButton.clicked.connect(self.onSaveInputFunctionButtonClicked)
+
+    transformNodeButton = qt.QPushButton("Transform Nodes")
     transformNodeButton.toolTip="Transform node with patient based transform"
     dataFormLayout.addRow("Transform Nodes",transformNodeButton)
     transformNodeButton.clicked.connect(self.onTransformNodeButtonClicked)
@@ -320,7 +325,7 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
             dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
 
             dn.SetSize(n)
-            dn.SetName(self.logic.getSegmentName(j))
+            dn.SetName(self.patientId.text+'_'+self.logic.getSegmentName(j))
 
             dt=0;
             t0=0;
@@ -371,6 +376,9 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
   def onSaveTransformationButtonClicked(self):
       self.logic.storeTransformation(self.patientId.text)
 
+  def onSaveInputFunctionButtonClicked(self):
+      self.logic.storeInputFunction(self.patientId.text)
+
   def onTransformNodeButtonClicked(self):
       self.logic.applyTransform(self.patientId.text, self.refPatientId.text,
           self.time_frame_select.minimum,self.time_frame_select.maximum)
@@ -560,6 +568,24 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
 
 
 
+
+  def getMaskPos(self,mask,i):
+      maskIJK=mask.GetPoint(i)
+      maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
+      maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
+
+      #this is now in extent spacing, whitch ImageToWorldMatrix understands
+
+      #to 4D vector for vtkMatrix4x4 handling
+      maskIJK.append(1)
+
+      #go to ras, global coordinates (in mm)
+      maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+      mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
+
+      maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)
+      return maskPos[0:3]
+
   def meanROI(self, volName1, i):
     s=0
 
@@ -581,7 +607,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
         print("Segment {} not found".format(segment))
         return s
 
-    print "Got mask for segment {}".format(segment)
+    print "Got mask for segment {}, npts {}".format(segment,mask.GetNumberOfPoints())
     #get mask at (x,y,z)
     #mask.GetPointData().GetScalars().GetTuple1(mask.FindPoint([x,y,z]))
 
@@ -589,42 +615,48 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
     dataNode=slicer.mrmlScene.GetFirstNodeByName(volName1)
     dataImage=dataNode.GetImageData()
     # use IJK2RAS to get global coordinates
-    ijkToRas = vtk.vtkMatrix4x4()
-    dataNode.GetIJKToRASMatrix(ijkToRas)
-
-    #iterate over volume pixelData
-    n=dataImage.GetNumberOfPoints()
-
-    extent=mask.GetExtent()
-    fM=vtk.vtkMatrix4x4()
-    mask.GetWorldToImageMatrix(fM)
+    dataRAStoIJK = vtk.vtkMatrix4x4()
+    dataNode.GetRASToIJKMatrix(dataRAStoIJK)
+
+    #allow for interpolation in segmentation pixels
+    coeff=vtk.vtkImageBSplineCoefficients()
+    coeff.SetInputData(dataImage)
+    coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
+    #between 3 and 5
+    coeff.SetSplineDegree(5)
+    coeff.Update()
+
+    maskImageToWorldMatrix=vtk.vtkMatrix4x4()
+    mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
     ns=0
-    for i in range(0,n):
 
-      #get global coordinates of point i
-      [ix,iy,iz]=dataImage.GetPoint(i)
-      position_ijk=[ix, iy, iz, 1]
-      #ras are global coordinates (in mm)
-      position_ras=ijkToRas.MultiplyPoint(position_ijk)
-      fpos=[int(np.round(x)) for x in fM.MultiplyPoint(position_ras)]
 
-      outOfRange=False
-      for k in range(0,3):
-          if fpos[k]<extent[2*k] or fpos[k]>extent[2*k+1]:
-              outOfRange=True
-              break;
+    maskN=mask.GetNumberOfPoints()
+    maskScalars=mask.GetPointData().GetScalars()
+    maskOrigin=[0,0,0]
+    maskOrigin=mask.GetOrigin()
 
-      if outOfRange:
+    for i in range(0,maskN):
+      #skip all points that are 0
+      if maskScalars.GetTuple1(i)==0:
           continue
 
-      #find point in mask with the same global coordinates
-      maskValue=mask.GetPointData().GetScalars().GetTuple1(mask.ComputePointId(fpos[0:3]))
+      #get global coordinates of point i
+      maskPos=self.getMaskPos(mask,i)
 
-      if maskValue == 0:
-          continue
+      #print("Evaluating at {}").format(maskPos)
+      #convert from global to local
+      dataPos=[0,0,0]
+      #account for potentially applied transform on dataNode
+      dataNode.TransformPointFromWorld(maskPos,dataPos)
+      dataPos.append(1)
+      dataIJK=dataRAStoIJK.MultiplyPoint(dataPos)
 
-      #use maskValue to project ROI data
-      s+=maskValue*dataImage.GetPointData().GetScalars().GetTuple1(i)
+      #drop the 4th index
+      dataIJK=dataIJK[0:3]
+
+      #interpolate
+      s+=coeff.Evaluate(dataIJK)
       ns+=1
 
     return s/ns
@@ -666,9 +698,14 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
       suffix=".nrrd"
       if node.__class__.__name__=="vtkMRMLDoubleArrayNode":
           suffix=".mcsv"
-      if node.__class__.__name__=="vtkMRMLTransformNode":
+      if (node.__class__.__name__=="vtkMRMLTransformNode" or
+        node.__class__.__name__=="vtkMRMLGridTransformNode"):
           suffix=".h5"
       fileName=nodeName+suffix
+
+      if not os.path.isdir(localPath):
+         os.mkdir(localPath)
+
       file=os.path.join(localPath,fileName)
       slicer.util.saveNode(node,file)
       print("Stored to: {}").format(file)
@@ -701,10 +738,15 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
       self.storeNodeRemote(relativePath,segNodeName)
 
   def storeInputFunction(self,patientId):
+      self.calculateInputFunction(patientId)
+      project="dinamic_spect/Patients"
+      relativePath=project+'/@files/'+patientId
       doubleArrayNodeName=patientId+'_Ventricle'
       self.storeNodeRemote(relativePath,doubleArrayNodeName)
 
   def storeTransformation(self,patientId):
+      project="dinamic_spect/Patients"
+      relativePath=project+'/@files/'+patientId
       transformNodeName=patientId+"_DF"
       self.storeNodeRemote(relativePath,transformNodeName)
 
@@ -736,6 +778,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
           refNode=slicer.util.getFirstNodeByName(refNodeName)
           if refNode!=None:
               self.resampler.rebinNode(node,refNode)
+          print("Completed transformation {}").format(it)
 
       nodeName=patientId+'CT'
       node=slicer.util.getFirstNodeByName(nodeName)
@@ -750,13 +793,14 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
 
 
 
-  def calculateInputFunction(self):
+  def calculateInputFunction(self,patientId):
        n=self.frame_data.shape[3]
 
-       dns = slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode','Ventricle')
+       dnsNodeName=patientId+'_Ventricle'
+       dns = slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode',dnsNodeName)
        if dns.GetNumberOfItems() == 0:
            dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
-           dn.SetName('Ventricle')
+           dn.SetName(dnsNodeName)
        else:
            dn = dns.GetItemAsObject(0)
 
@@ -774,7 +818,8 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
 
        juse=-1
        for j in range(0,segmentation.GetNumberOfSegments()):
-           segment=segNode.GetSegmentation().GetNthSegmentID(j)
+           segmentID=segNode.GetSegmentation().GetNthSegmentID(j)
+           segment=segNode.GetSegmentation().GetSegment(segmentID)
            if segment.GetName()=='Ventricle':
                juse=j
                break
@@ -785,10 +830,11 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
 
        dt=0;
        t0=0;
+       ft=self.frame_time
        for i in range(0,n):
-           vol="testVolume"+str(i)
+           vol=patientId+"Volume"+str(i)
            fx=ft[i]
-           fy=self.logic.meanROI(vol,juse)
+           fy=self.meanROI(vol,juse)
            dt=2*ft[i]-t0
            t0+=dt