|
- import slicer
- import vtk
- import os
- from slicer.ScriptedLoadableModule import *
- import ctk
- import qt
class resample(ScriptedLoadableModule):
    """Module descriptor for the "Resample" scripted module.

    Registers title, category, contributors and help/acknowledgement texts
    with the Slicer module framework.

    Uses ScriptedLoadableModule base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    # Text shown in the module's help panel (the documentation link is
    # appended at construction time).
    _HELP = """
Resample to different shapes
"""

    # Funding / acknowledgement text shown in the module's help panel.
    _ACKNOWLEDGEMENT = """
This module was developed within the frame of the ARRS sponsored medical
physics research programe to investigate quantitative measurements of cardiac
function using sestamibi-like tracers
"""

    def __init__(self, parent):
        ScriptedLoadableModule.__init__(self, parent)
        info = parent
        info.title = "Resample"
        info.categories = ["LabKey"]
        info.dependencies = []
        info.contributors = ["Andrej Studen (FMF/JSI)"]
        info.helpText = self._HELP
        info.acknowledgementText = self._ACKNOWLEDGEMENT
        self.parent.helpText += self.getDefaultModuleDocumentationLink()
        self.parent = info
- #
- # resampleWidget
- #
class resampleWidget(ScriptedLoadableModuleWidget):
    """GUI for the resample module.

    Two line edits hold the name of the node to resample and the name of the
    reference node whose voxel grid is used for the output; two buttons
    trigger resampling of a scalar volume or of a segmentation node.

    Uses ScriptedLoadableModuleWidget base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def setup(self):
        """Build the widget: a "Dataset" collapsible section with inputs and buttons."""
        ScriptedLoadableModuleWidget.setup(self)
        self.logic = resampleLogic(self)

        datasetCollapsibleButton = ctk.ctkCollapsibleButton()
        datasetCollapsibleButton.text = "Dataset"
        self.layout.addWidget(datasetCollapsibleButton)
        # Layout within the collapsible button
        datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)

        # Name of the node (volume or segmentation) to be resampled.
        self.transformNode = qt.QLineEdit("NodeToTransform")
        datasetFormLayout.addRow("TransformedNode:", self.transformNode)
        # Name of the node whose geometry defines the output grid.
        self.referenceNode = qt.QLineEdit("ReferenceNode")
        datasetFormLayout.addRow("ReferenceNode:", self.referenceNode)

        self.transformButton = qt.QPushButton("Transform")
        self.transformButton.clicked.connect(self.onTransformButtonClicked)
        datasetFormLayout.addRow("Volume:", self.transformButton)

        self.transformSegmentationButton = qt.QPushButton("Transform")
        self.transformSegmentationButton.clicked.connect(
            self.onTransformSegmentationButtonClicked)
        datasetFormLayout.addRow("Segmentation:", self.transformSegmentationButton)

    def _getReferenceNode(self):
        """Return the node named in the reference field, or None after printing a message."""
        refNode = slicer.util.getFirstNodeByName(self.referenceNode.text)
        if refNode is None:
            # format() must be applied to the string, not to print()'s result
            print("Reference node [{}] not found".format(self.referenceNode.text))
        return refNode

    def onTransformButtonClicked(self):
        """Resample the volume named in the first field onto the reference grid."""
        node = slicer.util.getFirstNodeByName(self.transformNode.text)
        if node is None:
            print("Transform node [{}] not found".format(self.transformNode.text))
            return
        refNode = self._getReferenceNode()
        if refNode is None:
            return
        self.logic.rebinNode(node, refNode)

    def onTransformSegmentationButtonClicked(self):
        """Resample every segment of the named segmentation node onto the reference grid."""
        segNode = None
        for s in slicer.util.getNodesByClass("vtkMRMLSegmentationNode"):
            print("SegmentationNode: {}".format(s.GetName()))
            if s.GetName() == self.transformNode.text:
                segNode = s
                break
        if segNode is None:
            print("Segmentation node [{}] not found".format(self.transformNode.text))
            return
        refNode = self._getReferenceNode()
        if refNode is None:
            return
        self.logic.rebinSegmentation(segNode, refNode)
class resampleLogic(ScriptedLoadableModuleLogic):
    """Resampling computations, usable from other Python code without the widget.

    The ``rebin*`` methods evaluate a source node (scalar volume or
    segmentation) at the voxels of a reference volume node, taking into
    account transforms applied to either node. Each result is added to the
    MRML scene as a new ``vtkMRMLScalarVolumeNode`` on the reference geometry.
    Log files are written to ``<home>/.resample``.

    Uses ScriptedLoadableModuleLogic base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent):
        ScriptedLoadableModuleLogic.__init__(self, parent)
        try:
            fhome = os.environ["HOME"]
        except KeyError:
            # in Windows, the home directory is given by HOMEDRIVE + HOMEPATH
            fhome = os.environ['HOMEDRIVE'] + os.environ['HOMEPATH']
        self.baseLog = os.path.join(fhome, ".resample")
        if not os.path.isdir(self.baseLog):
            os.makedirs(self.baseLog)

    def printMe(self):
        """Print a readiness message and the log directory."""
        print("resampleLogic ready")
        print("Log: {}".format(self.baseLog))

    def cast(self, newImage, originalImage):
        """Return ``newImage`` converted to the scalar type of ``originalImage``.

        If the types already match, ``newImage`` is returned unchanged;
        otherwise a vtkImageShiftScale with identity shift/scale produces a
        new image of the original scalar type (any VTK scalar type).
        """
        outputType = originalImage.GetPointData().GetScalars().GetDataType()
        if newImage.GetPointData().GetScalars().GetDataType() == outputType:
            return newImage
        shifter = vtk.vtkImageShiftScale()
        shifter.SetInputData(newImage)
        # Only the scalar type changes; values pass through unscaled.
        shifter.SetOutputScalarType(outputType)
        shifter.SetScale(1)
        shifter.SetShift(0)
        shifter.Update()
        return shifter.GetOutput()

    def _refVoxelToWorld(self, refNode, refImage, refIJKtoRAS, index):
        """Return world coordinates of point ``index`` of the reference image.

        NOTE(review): the image-data point coordinates are treated as IJK,
        i.e. this assumes the vtkImageData has zero origin and unit spacing
        with geometry held by the node's IJKToRAS matrix -- confirm.
        """
        refIJK = list(refImage.GetPoint(index))
        refIJK.append(1)
        refPos = refIJKtoRAS.MultiplyPoint(refIJK)[0:3]
        fWorld = [0, 0, 0]
        # apply transforms the reference node may be under
        refNode.TransformPointToWorld(refPos, fWorld)
        return fWorld

    def _addVolumeNode(self, name, refNode, refIJKtoRAS, image):
        """Create a scalar volume node with the reference geometry, add it to the scene and return it."""
        newNode = slicer.vtkMRMLScalarVolumeNode()
        newNode.SetName(name)
        newNode.SetOrigin(refNode.GetOrigin())
        newNode.SetSpacing(refNode.GetSpacing())
        newNode.SetIJKToRASMatrix(refIJKtoRAS)
        newNode.SetAndObserveImageData(image)
        slicer.mrmlScene.AddNode(newNode)
        return newNode

    def rebinNode(self, node, refNode):
        """Resample scalar volume ``node`` onto the voxel grid of ``refNode``.

        Values are obtained by degree-5 B-spline interpolation of ``node``'s
        image with clamped borders. The output has the scalar type of
        ``node``'s image.

        Returns the new volume node, named ``<node name>_RS``, already added
        to the scene. Progress is logged to ``rebinNode.log`` in the log dir.
        """
        with open(os.path.join(self.baseLog, "rebinNode.log"), "w") as log:
            log.write("rebinNode: volume: {} ref: {}\n".format(node.GetName(), refNode.GetName()))
            refImage = refNode.GetImageData()
            n = refImage.GetNumberOfPoints()
            refIJKtoRAS = vtk.vtkMatrix4x4()
            refNode.GetIJKToRASMatrix(refIJKtoRAS)
            nodeRAStoIJK = vtk.vtkMatrix4x4()
            node.GetRASToIJKMatrix(nodeRAStoIJK)

            coeff = vtk.vtkImageBSplineCoefficients()
            coeff.SetInputData(node.GetImageData())
            coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
            # spline degree between 3 and 5
            coeff.SetSplineDegree(5)
            coeff.Update()

            # Deep copy so the reference image itself is never overwritten
            # (a shallow copy would share the scalar array).
            newImage = vtk.vtkImageData()
            newImage.DeepCopy(refImage)
            newImage = self.cast(newImage, node.GetImageData())
            newScalars = newImage.GetPointData().GetScalars()

            log.write("Iterating: {} points\n".format(n))
            for i in range(n):
                fWorld = self._refVoxelToWorld(refNode, refImage, refIJKtoRAS, i)
                # undo the transforms of the source node, then go RAS -> IJK
                nodePos = [0, 0, 0]
                node.TransformPointFromWorld(fWorld, nodePos)
                nodePos.append(1)
                nodeIJK = nodeRAStoIJK.MultiplyPoint(nodePos)[0:3]
                v = coeff.Evaluate(nodeIJK)
                v0 = newScalars.GetTuple1(i)
                newScalars.SetTuple1(i, v)
                if i % 10000 == 0:
                    log.write("[{}]: {}->{}\n".format(i, v0, v))

            newNode = self._addVolumeNode(node.GetName() + "_RS", refNode, refIJKtoRAS, newImage)
            log.write("Adding node {}\n".format(newNode.GetName()))
        return newNode

    def inMask(self, binaryRep, fpos):
        """Return the labelmap value of a segment at world position ``fpos``.

        ``binaryRep`` is a dict with keys ``'node'`` (the segmentation node)
        and ``'mask'`` (the segment's binary labelmap, an oriented image).
        Returns 0 when the point falls outside the labelmap or the labelmap
        has no scalars.
        """
        local = [0, 0, 0]
        segNode = binaryRep['node']
        segNode.TransformPointFromWorld(fpos, local)
        mask = binaryRep['mask']
        maskWorldToImageMatrix = vtk.vtkMatrix4x4()
        mask.GetWorldToImageMatrix(maskWorldToImageMatrix)
        local.append(1)
        maskIJK = maskWorldToImageMatrix.MultiplyPoint(local)[0:3]  # drop homogeneous coord
        # FindPoint works in vtkImageData point coordinates (origin + ijk*spacing),
        # not in IJK, so convert before looking up the point.
        maskIJK = [r * s + c for r, s, c in zip(maskIJK, mask.GetSpacing(), mask.GetOrigin())]
        maskI = mask.FindPoint(maskIJK)
        if maskI < 0:
            # outside the labelmap extent; GetTuple1(-1) would read out of range
            return 0
        scalars = mask.GetPointData().GetScalars()
        if scalars is None:
            return 0
        return scalars.GetTuple1(maskI)

    def rebinSegment(self, refNode, binaryRep):
        """Sample one segment's labelmap on the voxel grid of ``refNode``.

        ``binaryRep`` must contain ``'node'``, ``'mask'`` (see inMask) and
        ``'segId'`` (used in the output name). Returns the new volume node,
        named ``<segmentation name>_<segId>``, already added to the scene.
        """
        refIJKtoRAS = vtk.vtkMatrix4x4()
        refNode.GetIJKToRASMatrix(refIJKtoRAS)
        refImage = refNode.GetImageData()
        # new image for each segment, with the reference geometry
        newImage = vtk.vtkImageData()
        newImage.DeepCopy(refImage)
        n = newImage.GetNumberOfPoints()
        newScalars = newImage.GetPointData().GetScalars()
        for j in range(n):
            fWorld = self._refVoxelToWorld(refNode, refImage, refIJKtoRAS, j)
            newScalars.SetTuple1(j, self.inMask(binaryRep, fWorld))
        name = binaryRep['node'].GetName() + '_' + binaryRep['segId']
        return self._addVolumeNode(name, refNode, refIJKtoRAS, newImage)

    def rebinSegmentation(self, segNode, refNode):
        """Resample every segment of ``segNode`` onto the grid of ``refNode``.

        Returns the list of created volume nodes (one per segment). Progress
        is logged to ``rebinSegmentation.log`` in the log dir.
        """
        newNodes = []
        with open(os.path.join(self.baseLog, "rebinSegmentation.log"), "w") as log:
            log.write("rebinNode: {} {}\n".format(segNode.GetName(), refNode.GetName()))
            segmentation = segNode.GetSegmentation()
            for i in range(segmentation.GetNumberOfSegments()):
                segID = segmentation.GetNthSegmentID(i)
                log.write("Parsing segment {}\n".format(segmentation.GetNthSegment(i).GetName()))
                # NOTE(review): single-argument GetBinaryLabelmapRepresentation
                # returning the labelmap; newer Slicer releases may require the
                # output-argument form -- verify against the Slicer version used.
                binaryRep = {'node': segNode,
                             'mask': segNode.GetBinaryLabelmapRepresentation(segID),
                             'segId': segID}
                newNode = self.rebinSegment(refNode, binaryRep)
                log.write("Adding node: {}\n".format(newNode.GetName()))
                newNodes.append(newNode)
        return newNodes

    def rebinSegmentation1(self, segNode, refNode, maxSegments=1):
        """Alternative segmentation resampling that scatters mask voxels onto the reference grid.

        Every non-zero voxel of a segment's labelmap is mapped to the nearest
        reference point, which is set to 1; all other reference points are 0.

        maxSegments: process at most this many segments (default 1, the
            original debugging limit); ``None`` processes all segments.

        Returns the list of created volume nodes, named
        ``<segmentation name>_<segment ID>``.
        """
        print("rebinNode: {} {}\n".format(segNode.GetName(), refNode.GetName()))
        refIJKtoRAS = vtk.vtkMatrix4x4()
        refNode.GetIJKToRASMatrix(refIJKtoRAS)
        refRAStoIJK = vtk.vtkMatrix4x4()
        refNode.GetRASToIJKMatrix(refRAStoIJK)
        segmentation = segNode.GetSegmentation()
        nSeg = segmentation.GetNumberOfSegments()
        if maxSegments is not None:
            # min() also avoids querying segment 0 of an empty segmentation
            nSeg = min(nSeg, maxSegments)
        newNodes = []
        for i in range(nSeg):
            segID = segmentation.GetNthSegmentID(i)
            mask = segNode.GetBinaryLabelmapRepresentation(segID)
            # new image for each segment, initialised to zero
            newImage = vtk.vtkImageData()
            newImage.DeepCopy(refNode.GetImageData())
            newScalars = newImage.GetPointData().GetScalars()
            refN = newImage.GetNumberOfPoints()
            newScalars.FillComponent(0, 0)
            maskScalars = mask.GetPointData().GetScalars()
            maskImageToWorldMatrix = vtk.vtkMatrix4x4()
            mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
            maskN = mask.GetNumberOfPoints() if maskScalars is not None else 0
            for j in range(maskN):
                if maskScalars.GetTuple1(j) == 0:
                    continue
                # image-data point coordinates -> IJK
                maskIJK = [(r - c) / s for r, c, s in
                           zip(mask.GetPoint(j), mask.GetOrigin(), mask.GetSpacing())]
                maskIJK.append(1)
                maskPos = maskImageToWorldMatrix.MultiplyPoint(maskIJK)[0:3]
                fWorld = [0, 0, 0]
                # apply segmentation transformation
                segNode.TransformPointToWorld(maskPos, fWorld)
                refPos = [0, 0, 0]
                # apply potential reference transformations
                refNode.TransformPointFromWorld(fWorld, refPos)
                refPos.append(1)
                refIJK = refRAStoIJK.MultiplyPoint(refPos)[0:3]
                i1 = newImage.FindPoint(refIJK)
                if 0 <= i1 < refN:
                    newScalars.SetTuple1(i1, 1)
            name = segNode.GetName() + '_' + segID
            newNodes.append(self._addVolumeNode(name, refNode, refIJKtoRAS, newImage))
        return newNodes
class resampleTest(ScriptedLoadableModuleTest):
    """
    Test case for the resample module: loads a reference volume, a moving
    volume and a transform applied to the moving volume, then resamples the
    moving volume onto the reference grid.
    Uses ScriptedLoadableModuleTest base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    # Default location of the test data. It can be overridden with the
    # RESAMPLE_TEST_DATA environment variable so the test runs on other machines.
    defaultDataPath = "c:\\Users\\studen\\labkeyCache\\dinamic_spect\\Patients\\@files"

    def setUp(self):
        """Clear the scene and load the reference volume, moving volume and transform."""
        slicer.mrmlScene.Clear(0)
        refNodeName = "2SBMIRVolume19"
        nodeName = "2SMobrVolume19"
        transformNodeName = "2SMobr_DF"
        path = os.environ.get("RESAMPLE_TEST_DATA", self.defaultDataPath)
        refPath = os.path.join(path, "2SBMIR", refNodeName + ".nrrd")
        slicer.util.loadNodeFromFile(refPath, 'VolumeFile')
        transformPath = os.path.join(path, "2SMobr", transformNodeName + ".h5")
        slicer.util.loadNodeFromFile(transformPath, 'TransformFile')
        nodePath = os.path.join(path, "2SMobr", nodeName + ".nrrd")
        slicer.util.loadNodeFromFile(nodePath, 'VolumeFile')
        self.node = slicer.util.getFirstNodeByName(nodeName)
        self.refNode = slicer.util.getFirstNodeByName(refNodeName)
        self.transformNode = slicer.util.getFirstNodeByName(transformNodeName)
        # Fail with a clear message instead of an AttributeError on None.
        self.assertIsNotNone(self.node, "Could not load {}".format(nodePath))
        self.assertIsNotNone(self.refNode, "Could not load {}".format(refPath))
        self.assertIsNotNone(self.transformNode, "Could not load {}".format(transformPath))
        self.node.SetAndObserveTransformNodeID(self.transformNode.GetID())
        self.resampler = resampleLogic(None)

    def runTest(self):
        """Run as few or as many tests as needed here.
        """
        self.setUp()
        self.test_resample()

    def test_resample(self):
        """Resample the transformed volume onto the reference grid.

        Checks that a node is produced and that its image has the reference
        image's dimensions.
        """
        self.delayDisplay("Starting the test")
        newNode = self.resampler.rebinNode(self.node, self.refNode)
        self.assertIsNotNone(newNode)
        self.assertEqual(newNode.GetImageData().GetDimensions(),
                         self.refNode.GetImageData().GetDimensions())
        self.delayDisplay('Test passed!')
|