import os

import ctk
import qt
import slicer
import vtk
from slicer.ScriptedLoadableModule import *


class resample(ScriptedLoadableModule):
    """Module descriptor registering the "Resample" module with Slicer.

    Uses ScriptedLoadableModule base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent):
        """Fill in the module metadata (title, category, help texts) on `parent`."""
        ScriptedLoadableModule.__init__(self, parent)
        self.parent = parent
        parent.title = "Resample"
        parent.categories = ["LabKey"]
        parent.dependencies = []
        parent.contributors = ["Andrej Studen (FMF/JSI)"]
        parent.helpText = """
    Resample to different shapes
    """
        parent.acknowledgementText = """
    This module was developed within the frame of the ARRS sponsored medical
    physics research programe to investigate quantitative measurements of cardiac
    function using sestamibi-like tracers
    """
        parent.helpText += self.getDefaultModuleDocumentationLink()


#
# resampleWidget
#

class resampleWidget(ScriptedLoadableModuleWidget):
    """GUI of the module: two name fields and two buttons.

    The user types the name of the node to resample and the name of the
    reference node; the buttons resample either a volume or a segmentation
    onto the grid of the reference node.

    Uses ScriptedLoadableModuleWidget base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def setup(self):
        """Build the widget layout and connect the button callbacks."""
        ScriptedLoadableModuleWidget.setup(self)
        self.logic = resampleLogic(self)

        datasetCollapsibleButton = ctk.ctkCollapsibleButton()
        datasetCollapsibleButton.text = "Dataset"
        self.layout.addWidget(datasetCollapsibleButton)
        # Layout within the collapsible button
        datasetFormLayout = qt.QFormLayout(datasetCollapsibleButton)

        self.transformNode = qt.QLineEdit("NodeToTransform")
        datasetFormLayout.addRow("TransformedNode:", self.transformNode)

        self.referenceNode = qt.QLineEdit("ReferenceNode")
        datasetFormLayout.addRow("ReferenceNode:", self.referenceNode)

        self.transformButton = qt.QPushButton("Transform")
        self.transformButton.clicked.connect(self.onTransformButtonClicked)
        datasetFormLayout.addRow("Volume:", self.transformButton)

        self.transformSegmentationButton = qt.QPushButton("Transform")
        self.transformSegmentationButton.clicked.connect(
            self.onTransformSegmentationButtonClicked)
        datasetFormLayout.addRow("Segmentation:",
                                 self.transformSegmentationButton)

    def _getReferenceNode(self):
        """Return the node named in the reference field, or None (after reporting)."""
        refNode = slicer.util.getFirstNodeByName(self.referenceNode.text)
        if refNode is None:
            # format() must be applied to the string, not to print()'s return value
            print("Reference node [{}] not found".format(self.referenceNode.text))
        return refNode

    def onTransformButtonClicked(self):
        """Resample the named volume onto the grid of the named reference node."""
        node = slicer.util.getFirstNodeByName(self.transformNode.text)
        if node is None:
            print("Transform node [{}] not found".format(self.transformNode.text))
            return
        refNode = self._getReferenceNode()
        if refNode is None:
            return
        self.logic.rebinNode(node, refNode)

    def onTransformSegmentationButtonClicked(self):
        """Resample the named segmentation onto the grid of the reference node."""
        segNode = None
        # search only among segmentation nodes so that a volume sharing the
        # same name is not picked up by mistake
        for s in slicer.util.getNodesByClass("vtkMRMLSegmentationNode"):
            print("SegmentationNode: {}".format(s.GetName()))
            if s.GetName() == self.transformNode.text:
                segNode = s
                break
        if segNode is None:
            print("Segmentation node [{}] not found".format(self.transformNode.text))
            return
        refNode = self._getReferenceNode()
        if refNode is None:
            return
        self.logic.rebinSegmentation(segNode, refNode)


class resampleLogic(ScriptedLoadableModuleLogic):
    """This class should implement all the actual computation done by your
    module. The interface should be such that other python code can import
    this class and make use of the functionality without requiring an
    instance of the Widget.

    Uses ScriptedLoadableModuleLogic base class, available at:
    https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
    """

    def __init__(self, parent):
        """Create the log directory `.resample` in the user's home directory."""
        ScriptedLoadableModuleLogic.__init__(self, parent)
        try:
            fhome = os.environ["HOME"]
        except KeyError:
            # in windows, the home directory is given by HOMEDRIVE and HOMEPATH
            fhome = os.environ['HOMEDRIVE'] + os.environ['HOMEPATH']
        self.baseLog = os.path.join(fhome, ".resample")
        if not os.path.isdir(self.baseLog):
            os.mkdir(self.baseLog)

    def printMe(self):
        """Print a short status message including the log directory."""
        print("resampleLogic ready")
        print("Log: {}".format(self.baseLog))

    def cast(self, newImage, originalImage):
        """Return `newImage` converted to the scalar type of `originalImage`.

        `newImage` itself is returned when the scalar types already match;
        otherwise a converted image produced by vtkImageShiftScale (scale 1,
        shift 0, i.e. a pure type conversion) is returned.
        """
        targetType = originalImage.GetPointData().GetScalars().GetDataType()
        if newImage.GetPointData().GetScalars().GetDataType() == targetType:
            return newImage

        shifter = vtk.vtkImageShiftScale()
        shifter.SetInputData(newImage)
        # use the VTK type id directly so that every scalar type is handled,
        # not only short and unsigned short
        shifter.SetOutputScalarType(targetType)
        shifter.SetScale(1)
        shifter.SetShift(0)
        shifter.Update()
        return shifter.GetOutput()

    def _referencePointToWorld(self, refNode, refImage, refIJKtoRAS, index):
        """Return world coordinates of point `index` of the reference image."""
        refIJK = list(refImage.GetPoint(index))
        refIJK.append(1)
        # shift to RAS coordinates of the reference node
        refPos = refIJKtoRAS.MultiplyPoint(refIJK)[0:3]
        fWorld = [0, 0, 0]
        # apply potential transformations attached to the reference node
        refNode.TransformPointToWorld(refPos, fWorld)
        return fWorld

    def _addVolumeNode(self, name, refNode, refIJKtoRAS, image):
        """Add a scalar volume holding `image` with the reference geometry to the scene."""
        newNode = slicer.vtkMRMLScalarVolumeNode()
        newNode.SetName(name)
        newNode.SetOrigin(refNode.GetOrigin())
        newNode.SetSpacing(refNode.GetSpacing())
        newNode.SetIJKToRASMatrix(refIJKtoRAS)
        newNode.SetAndObserveImageData(image)
        slicer.mrmlScene.AddNode(newNode)
        return newNode

    def rebinNode(self, node, refNode):
        """Resample volume `node` onto the voxel grid of `refNode`.

        Every voxel of the reference grid is mapped to world coordinates,
        then into the IJK space of `node`, where the value is obtained with
        B-spline interpolation. The result is added to the scene as a new
        scalar volume named "<node name>_RS" and returned. Progress is
        written to `rebinNode.log` in the log directory.
        """
        with open(os.path.join(self.baseLog, "rebinNode.log"), "w") as log:
            log.write(("rebinNode: volume: {} ref: {}\n").format(
                node.GetName(), refNode.GetName()))

            refImage = refNode.GetImageData()
            n = refImage.GetNumberOfPoints()

            refIJKtoRAS = vtk.vtkMatrix4x4()
            refNode.GetIJKToRASMatrix(refIJKtoRAS)

            nodeRAStoIJK = vtk.vtkMatrix4x4()
            node.GetRASToIJKMatrix(nodeRAStoIJK)

            coeff = vtk.vtkImageBSplineCoefficients()
            coeff.SetInputData(node.GetImageData())
            coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
            # between 3 and 5
            coeff.SetSplineDegree(5)
            coeff.Update()
            # interpolation ready to use

            # deep copy the reference image to get an independent buffer with
            # the reference dimensions, then switch it to the scalar type of
            # the resampled volume so interpolated values are not truncated
            newImage = vtk.vtkImageData()
            newImage.DeepCopy(refImage)
            newImage = self.cast(newImage, node.GetImageData())
            newScalars = newImage.GetPointData().GetScalars()

            log.write(("Iterating: {} points\n").format(n))
            for i in range(0, n):
                fWorld = self._referencePointToWorld(
                    refNode, refImage, refIJKtoRAS, i)

                # now do the opposite on the target node
                nodePos = [0, 0, 0]
                node.TransformPointFromWorld(fWorld, nodePos)
                nodePos.append(1)
                nodeIJK = nodeRAStoIJK.MultiplyPoint(nodePos)[0:3]

                v = coeff.Evaluate(nodeIJK)
                v0 = newScalars.GetTuple1(i)
                newScalars.SetTuple1(i, v)
                if i % 10000 == 0:
                    log.write(("[{}]: {}->{}\n").format(i, v0, v))

            newNode = self._addVolumeNode(
                node.GetName() + "_RS", refNode, refIJKtoRAS, newImage)
            log.write(("Adding node {}\n").format(newNode.GetName()))
            return newNode

    def inMask(self, binaryRep, fpos):
        """Return the labelmap value at world position `fpos`.

        `binaryRep` is a dict with the segmentation node under 'node' and
        the binary labelmap of a single segment under 'mask'. Returns 0 when
        the position is outside the labelmap or the value can not be read.
        """
        local = [0, 0, 0]
        segNode = binaryRep['node']
        segNode.TransformPointFromWorld(fpos, local)

        mask = binaryRep['mask']
        maskWorldToImageMatrix = vtk.vtkMatrix4x4()
        mask.GetWorldToImageMatrix(maskWorldToImageMatrix)
        local.append(1)
        # mask IJK is in image coordinates. However, binaryLabelMap is a
        # truncated version of vtkImageData, so more work is required
        maskIJK = maskWorldToImageMatrix.MultiplyPoint(local)[0:3]
        maskIJK = [r * s for r, s in zip(maskIJK, mask.GetSpacing())]
        maskIJK = [r + c for r, c in zip(maskIJK, mask.GetOrigin())]

        maskI = mask.FindPoint(maskIJK)
        if maskI < 0:
            # FindPoint signals a point outside of the image with a negative id
            return 0
        try:
            return mask.GetPointData().GetScalars().GetTuple1(maskI)
        except (AttributeError, TypeError, ValueError):
            # no scalars available on the mask: treat as outside
            return 0

    def rebinSegment(self, refNode, binaryRep):
        """Sample one segment on the grid of `refNode`.

        `binaryRep` must hold 'node' (segmentation node), 'mask' (labelmap
        of the segment) and 'segId' (segment id, used in the output name).
        A scalar volume named "<segmentation name>_<segId>" holding the mask
        value at every reference voxel is added to the scene and returned.
        """
        refIJKtoRAS = vtk.vtkMatrix4x4()
        refNode.GetIJKToRASMatrix(refIJKtoRAS)
        refImage = refNode.GetImageData()

        # create new image for the segment with the reference dimensions
        newImage = vtk.vtkImageData()
        newImage.DeepCopy(refImage)
        newScalars = newImage.GetPointData().GetScalars()

        for j in range(0, newImage.GetNumberOfPoints()):
            fWorld = self._referencePointToWorld(
                refNode, refImage, refIJKtoRAS, j)
            newScalars.SetTuple1(j, self.inMask(binaryRep, fWorld))

        name = binaryRep['node'].GetName() + '_' + binaryRep['segId']
        return self._addVolumeNode(name, refNode, refIJKtoRAS, newImage)

    def rebinSegmentation(self, segNode, refNode):
        """Resample every segment of `segNode` onto the grid of `refNode`.

        Progress is written to `rebinSegmentation.log` in the log directory.
        """
        log = open(os.path.join(self.baseLog, "rebinSegmentation.log"), "w")
        log.write(("rebinNode: {} {}\n").format(
            segNode.GetName(), refNode.GetName()))
        nSeg = segNode.GetSegmentation().GetNumberOfSegments()

        for i in range(0, nSeg):
            segID = segNode.GetSegmentation().GetNthSegmentID(i)
            log.write(("Parsing segment {}").format(
                segNode.GetSegmentation().GetNthSegment(i).GetName()))
            # 'segId' is required by rebinSegment to name the output volume
            binaryRep = {'node': segNode,
                         'segId': segID,
                         'mask': segNode.GetBinaryLabelmapRepresentation(segID)}
            # the rest of this loop body (rebinSegment call and logging)
            # follows directly after this point in the file
newNode=self.rebinSegment(refNode,binaryRep) log.write(("Adding node: {}").format(newNode.GetName())) log.close() def rebinSegmentation1(self,segNode,refNode): logfile="C:\\Users\\studen\\labkeyCache\\log\\resample.log" print("rebinNode: {} {}\n").format(segNode.GetName(),refNode.GetName()) refImage=refNode.GetImageData() refIJKtoRAS=vtk.vtkMatrix4x4() refNode.GetIJKToRASMatrix(refIJKtoRAS) refRAStoIJK=vtk.vtkMatrix4x4() refNode.GetRASToIJKMatrix(refRAStoIJK) nSeg=segNode.GetSegmentation().GetNumberOfSegments() ## DEBUG: nSeg=1 for i in range(0,nSeg): #segID segID=segNode.GetSegmentation().GetNthSegmentID(i) binaryRep={'node':segNode, 'mask': segNode.GetBinaryLabelmapRepresentation(segID)} mask=binaryRep['mask'] #create new node for each segment newImage=vtk.vtkImageData() newImage.DeepCopy(refNode.GetImageData()) newScalars=newImage.GetPointData().GetScalars() refN=newImage.GetNumberOfPoints() for k in range(0,refN): newScalars.SetTuple1(k,0) maskN=binaryRep['mask'].GetNumberOfPoints() maskScalars=mask.GetPointData().GetScalars() maskImageToWorldMatrix=vtk.vtkMatrix4x4() binaryRep['mask'].GetImageToWorldMatrix(maskImageToWorldMatrix) for j in range(0,maskN): if maskScalars.GetTuple1(j)==0: continue maskIJK=mask.GetPoint(j) maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())] maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())] maskIJK.append(1) maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK) maskPos=maskPos[0:3] fWorld=[0,0,0] #apply segmentation transformation segNode.TransformPointToWorld(maskPos,fWorld) refPos=[0,0,0] #apply potential reference transformations refNode.TransformPointFromWorld(fWorld,refPos) refPos.append(1) refIJK=refRAStoIJK.MultiplyPoint(refPos) refIJK=refIJK[0:3] i1=newImage.FindPoint(refIJK) if i1<0: continue if i1