# cardiacSPECT.py (34 KB)


  1. import os
  2. import sys
  3. import unittest
  4. import vtk, qt, ctk, slicer
  5. from slicer.ScriptedLoadableModule import *
  6. import logging
  7. import slicer
  8. import numpy as np
  9. import json
  10. import re
  11. #
  12. # cardiacSPECT
  13. #
  14. class cardiacSPECT(ScriptedLoadableModule):
  15. """Uses ScriptedLoadableModule base class, available at:
  16. https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  17. """
  18. def __init__(self, parent):
  19. ScriptedLoadableModule.__init__(self, parent)
  20. parent.title = "Cardiac SPECT"
  21. parent.categories = ["dynamicSPECT"]
  22. parent.dependencies = []
  23. parent.contributors = ["Andrej Studen (FMF/JSI)"] # replace with "Firstname Lastname (Org)"
  24. parent.helpText = """
  25. Load dynamic cardiac SPECT data to Slicer
  26. """
  27. parent.acknowledgementText = """
  28. This module was developed within the frame of the ARRS sponsored medical
  29. physics research programe to investigate quantitative measurements of cardiac
  30. function using sestamibi-like tracers
  31. """ # replace with organization, grant and thanks.
  32. self.parent.helpText += self.getDefaultModuleDocumentationLink()
  33. self.parent = parent
  34. #
  35. # cardiacSPECTWidget
  36. #
  37. class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
  38. """Uses ScriptedLoadableModuleWidget base class, available at:
  39. https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  40. """
  41. def setup(self):
  42. ScriptedLoadableModuleWidget.setup(self)
  43. #load config
  44. configFile=os.path.join(os.path.expanduser('~'),\
  45. '.cardiacSPECT','cardiacSPECT.json')
  46. with open(configFile) as f:
  47. self.cfg=json.load(f)
  48. #need labkey browser here
  49. sconfig=os.path.join(os.path.expanduser('~'),'.labkey','setup.json')
  50. with open(sconfig) as f:
  51. setup=json.load(f)
  52. sys.path.append(setup['paths']['labkeyInterface'])
  53. import labkeyInterface
  54. import labkeyDatabaseBrowser
  55. import labkeyFileBrowser
  56. self.net=labkeyInterface.labkeyInterface()
  57. fconfig=os.path.join(os.path.expanduser('~'),'.labkey','network.json')
  58. self.net.init(fconfig)
  59. self.db=labkeyDatabaseBrowser.labkeyDB(self.net)
  60. ds=self.db.selectRows(self.cfg['project'],self.cfg['schemaName'],\
  61. self.cfg['queryName'],[])
  62. patients=list(set([row['aliasID'] for row in ds['rows']]))
  63. # Instantiate and connect widgets ...
  64. dataButton = ctk.ctkCollapsibleButton()
  65. dataButton.text = "Data"
  66. self.layout.addWidget(dataButton)
  67. # Layout within the sample collapsible button
  68. dataFormLayout = qt.QFormLayout(dataButton)
  69. self.patientId=qt.QComboBox()
  70. self.patientId.insertItems(0,patients)
  71. dataFormLayout.addRow('Patient ID', self.patientId)
  72. self.refPatientId=qt.QComboBox();
  73. self.refPatientId.insertItems(0,patients)
  74. dataFormLayout.addRow('Reference Patient ID', self.refPatientId)
  75. patientLoadButton = qt.QPushButton("Load")
  76. patientLoadButton.toolTip="Load data from DICOM"
  77. dataFormLayout.addRow("Patient",patientLoadButton)
  78. patientLoadButton.clicked.connect(self.onPatientLoadButtonClicked)
  79. patientLoadNRRDButton = qt.QPushButton("Load NRRD")
  80. patientLoadNRRDButton.toolTip="Load data from NRRD"
  81. dataFormLayout.addRow("Patient",patientLoadNRRDButton)
  82. patientLoadNRRDButton.clicked.connect(self.onPatientLoadNRRDButtonClicked)
  83. loadSegmentationButton = qt.QPushButton("Load")
  84. loadSegmentationButton.toolTip="Load segmentation from server"
  85. dataFormLayout.addRow("Segmentation",loadSegmentationButton)
  86. loadSegmentationButton.clicked.connect(self.onLoadSegmentationButtonClicked)
  87. self.modelParameter=qt.QLineEdit('k1');
  88. dataFormLayout.addRow('Model Parameter', self.modelParameter)
  89. loadModelButton = qt.QPushButton("Load")
  90. loadModelButton.toolTip="Load model parameters from server"
  91. dataFormLayout.addRow("Model",loadModelButton)
  92. loadModelButton.clicked.connect(self.onLoadModelButtonClicked)
  93. saveVolumeButton = qt.QPushButton("Save")
  94. saveVolumeButton.toolTip="Save volume to NRRD"
  95. dataFormLayout.addRow("Volume",saveVolumeButton)
  96. saveVolumeButton.clicked.connect(self.onSaveVolumeButtonClicked)
  97. saveSegmentationButton = qt.QPushButton("Save")
  98. saveSegmentationButton.toolTip="Save segmentation to NRRD"
  99. dataFormLayout.addRow("Segmentation",saveSegmentationButton)
  100. saveSegmentationButton.clicked.connect(self.onSaveSegmentationButtonClicked)
  101. saveTransformationButton = qt.QPushButton("Save")
  102. saveTransformationButton.toolTip="Save transformation to NRRD"
  103. dataFormLayout.addRow("Transformation",saveTransformationButton)
  104. saveTransformationButton.clicked.connect(self.onSaveTransformationButtonClicked)
  105. saveInputFunctionButton = qt.QPushButton("Save")
  106. saveInputFunctionButton.toolTip="Save InputFunction to NRRD"
  107. dataFormLayout.addRow("InputFunction",saveInputFunctionButton)
  108. saveInputFunctionButton.clicked.connect(self.onSaveInputFunctionButtonClicked)
  109. transformNodeButton = qt.QPushButton("Transform Nodes")
  110. transformNodeButton.toolTip="Transform node with patient based transform"
  111. dataFormLayout.addRow("Transform Nodes",transformNodeButton)
  112. transformNodeButton.clicked.connect(self.onTransformNodeButtonClicked)
  113. # Add vertical spacer
  114. self.layout.addStretch(1)
  115. #addFrameButton=qt.QPushButton("Add Frame")
  116. #addFrameButton.toolTip="Add frame to VTK"
  117. #dataFormLayout.addWidget(addFrameButton)
  118. #addFrameButton.connect('clicked(bool)',self.onAddFrameButtonClicked)
  119. #addCTButton=qt.QPushButton("Add CT")
  120. #addCTButton.toolTip="Add CT to VTK"
  121. #dataFormLayout.addWidget(addCTButton)
  122. #addCTButton.connect('clicked(bool)',self.onAddCTButtonClicked)
  123. #
  124. # Parameters Area
  125. #
  126. parametersCollapsibleButton = ctk.ctkCollapsibleButton()
  127. parametersCollapsibleButton.text = "Parameters"
  128. self.layout.addWidget(parametersCollapsibleButton)
  129. # Layout within the dummy collapsible button
  130. parametersFormLayout = qt.QFormLayout(parametersCollapsibleButton)
  131. #
  132. # check box to trigger taking screen shots for later use in tutorials
  133. #
  134. hbox1=qt.QHBoxLayout()
  135. frameLabel = qt.QLabel()
  136. frameLabel.setText("Select frame")
  137. hbox1.addWidget(frameLabel)
  138. self.time_frame_select=qt.QSlider(qt.Qt.Horizontal)
  139. self.time_frame_select.valueChanged.connect(self.onTimeFrameSelect)
  140. #self.time_frame_select.connect('valueChanged()', self.onTimeFrameSelect)
  141. self.time_frame_select.setMinimum(0)
  142. self.time_frame_select.setMaximum(0)
  143. self.time_frame_select.setValue(0)
  144. self.time_frame_select.setTickPosition(qt.QSlider.TicksBelow)
  145. self.time_frame_select.setTickInterval(5)
  146. self.time_frame_select.toolTip = "Select the time frame"
  147. hbox1.addWidget(self.time_frame_select)
  148. parametersFormLayout.addRow(hbox1)
  149. hbox2 = qt.QHBoxLayout()
  150. meanROILabel = qt.QLabel()
  151. meanROILabel.setText("MeanROI")
  152. hbox2.addWidget(meanROILabel)
  153. self.meanROIVolume = qt.QLineEdit()
  154. self.meanROIVolume.setText("testVolume15")
  155. hbox2.addWidget(self.meanROIVolume)
  156. self.meanROISegment = qt.QLineEdit()
  157. self.meanROISegment.setText("Segment_1")
  158. hbox2.addWidget(self.meanROISegment)
  159. computeMeanROI = qt.QPushButton("Compute mean ROI")
  160. computeMeanROI.connect('clicked(bool)',self.onComputeMeanROIClicked)
  161. hbox2.addWidget(computeMeanROI)
  162. self.meanROIResult = qt.QLineEdit()
  163. self.meanROIResult.setText("0")
  164. hbox2.addWidget(self.meanROIResult)
  165. parametersFormLayout.addRow(hbox2)
  166. #row 3
  167. hbox3 = qt.QHBoxLayout()
  168. drawTimePlot=qt.QPushButton("Draw ROI time plot")
  169. drawTimePlot.connect('clicked(bool)',self.onDrawTimePlotClicked)
  170. hbox3.addWidget(drawTimePlot)
  171. parametersFormLayout.addRow(hbox3)
  172. #dataFormLayout.addWidget(hbox)
  173. #row 4
  174. hbox4 = qt.QHBoxLayout()
  175. countSegments=qt.QPushButton("Count segmentation segments")
  176. countSegments.connect('clicked(bool)',self.onCountSegmentsClicked)
  177. hbox4.addWidget(countSegments)
  178. self.countSegmentsDisplay=qt.QLineEdit()
  179. self.countSegmentsDisplay.setText("0")
  180. hbox4.addWidget(self.countSegmentsDisplay)
  181. parametersFormLayout.addRow(hbox4)
  182. #
  183. # Apply Button
  184. #
  185. self.applyButton = qt.QPushButton("Apply")
  186. self.applyButton.toolTip = "Run the algorithm."
  187. self.applyButton.enabled = False
  188. parametersFormLayout.addRow(self.applyButton)
  189. # connections
  190. self.applyButton.connect('clicked(bool)', self.onApplyButton)
  191. #self.inputSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.onSelect)
  192. #self.outputSelector.connect("currentNodeChanged(vtkMRMLNode*)", self.onSelect)
  193. # Add vertical spacer
  194. self.layout.addStretch(1)
  195. self.resetPosition=1
  196. #add logic aware of all GUI elements on page
  197. self.logic=cardiacSPECTLogic(self.cfg)
  198. def cleanup(self):
  199. pass
  200. def onApplyButton(self):
  201. pass
  202. #logic = cardiacSPECTLogic()
  203. #imageThreshold = self.imageThresholdSliderWidget.value
  204. def onBrowseButtonClicked(self):
  205. startDir=self.dataPath.text
  206. inputDir=qt.QFileDialog.getExistingDirectory(None,
  207. 'Select DICOM directory',startDir)
  208. self.dataPath.setText("file://"+inputDir)
  209. def onRemoteBrowseButtonClicked(self):
  210. self.selectRemote.show()
  211. def onDataLoadButtonClicked(self):
  212. self.logic.loadData(self)
  213. def onRemotePathTextChanged(self,str):
  214. self.dataPath.setText('labkey://'+str)
  215. def onTimeFrameSelect(self):
  216. it=self.time_frame_select.value
  217. selectionNode = slicer.app.applicationLogic().GetSelectionNode()
  218. print("Propagating CT volume")
  219. nodeName=self.patientId.currentText+'CT'
  220. node=slicer.mrmlScene.GetFirstNodeByName(nodeName)
  221. selectionNode.SetReferenceActiveVolumeID(node.GetID())
  222. if self.resetPosition==1:
  223. self.resetPosition=0
  224. slicer.app.applicationLogic().PropagateVolumeSelection(1)
  225. else:
  226. slicer.app.applicationLogic().PropagateVolumeSelection(0)
  227. print("Propagating SPECT volume")
  228. nodeName=self.patientId.currentText+'Volume'+str(it)
  229. node=slicer.mrmlScene.GetFirstNodeByName(nodeName)
  230. selectionNode.SetSecondaryVolumeID(node.GetID())
  231. slicer.app.applicationLogic().PropagateForegroundVolumeSelection(0)
  232. node.GetDisplayNode().SetAndObserveColorNodeID('vtkMRMLColorTableNodeRed')
  233. lm = slicer.app.layoutManager()
  234. sID=['Red','Yellow','Green']
  235. for s in sID:
  236. sliceLogic = lm.sliceWidget(s).sliceLogic()
  237. compositeNode = sliceLogic.GetSliceCompositeNode()
  238. compositeNode.SetForegroundOpacity(0.5)
  239. #make sure the viewer is matched to the volume
  240. print("Done")
  241. #to access sliceLogic (slice control) use
  242. #lcol=slicer.app.layoutManager().mrmlSliceLogics() (vtkCollection)
  243. #vtkMRMLSliceLogic are named by colors (Red,Green,Blue)
  244. def onComputeMeanROIClicked(self):
  245. s=self.logic.meanROI(self.meanROIVolume.text,self.meanROISegment.text)
  246. self.meanROIResult.setText(str(s))
  247. def onDrawTimePlotClicked(self):
  248. n=self.time_frame_select.maximum+1
  249. ft=self.logic.frame_time
  250. #find number of segments
  251. ns = self.logic.countSegments()
  252. #add the chart node
  253. cn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartNode())
  254. for j in range(0,ns):
  255. #add node for data
  256. dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
  257. dn.SetSize(n)
  258. dn.SetName(self.patientId.currentText+'_'+self.logic.getSegmentName(j))
  259. dt=0;
  260. t0=0;
  261. for i in range(0,n):
  262. vol=self.patientId.currentText+"Volume"+str(i)
  263. fx=ft[i]
  264. fy=self.logic.meanROI(vol,j)
  265. dt=2*ft[i]-t0
  266. t0+=dt
  267. dn.SetValue(i, 0, fx)
  268. dn.SetValue(i, 1, fy/dt)
  269. dn.SetValue(i, 2, 0)
  270. print("[{0} at {1:.2f}:{2:.2f}]".format(vol,fx,fy))
  271. #fish the number of the segment
  272. cn.AddArray(self.logic.getSegmentName(j), dn.GetID())
  273. cn.SetProperty('default', 'title', 'ROI time plot')
  274. cn.SetProperty('default', 'xAxisLabel', 'time [ms]')
  275. cn.SetProperty('default', 'yAxisLabel', 'Activity (arb)')
  276. #update the chart node
  277. cvns = slicer.mrmlScene.GetNodesByClass('vtkMRMLChartViewNode')
  278. if cvns.GetNumberOfItems() == 0:
  279. cvn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartViewNode())
  280. else:
  281. cvn = cvns.GetItemAsObject(0)
  282. cvn.SetChartNodeID(cn.GetID())
  283. def onCountSegmentsClicked(self):
  284. self.countSegmentsDisplay.setText(self.logic.countSegments)
  285. def onPatientLoadButtonClicked(self):
  286. self.logic.loadPatient(self.patientId.currentText)
  287. self.time_frame_select.setMaximum(self.logic.frame_data.shape[3]-1)
  288. def onPatientLoadNRRDButtonClicked(self):
  289. self.logic.loadPatientNRRD(self.patientId.currentText)
  290. self.time_frame_select.setMaximum(len(self.logic.frame_time))
  291. def onLoadSegmentationButtonClicked(self):
  292. self.logic.loadSegmentation(self.patientId.currentText)
  293. def onLoadModelButtonClicked(self):
  294. self.logic.loadModelVolume(self.patientId.currentText,self.modelParameter.text)
  295. def onSaveVolumeButtonClicked(self):
  296. self.logic.storeVolumeNodes(self.patientId.currentText,
  297. self.time_frame_select.minimum,self.time_frame_select.maximum)
  298. def onSaveSegmentationButtonClicked(self):
  299. self.logic.storeSegmentation(self.patientId.currentText)
  300. def onSaveTransformationButtonClicked(self):
  301. self.logic.storeTransformation(self.patientId.currentText)
  302. def onSaveInputFunctionButtonClicked(self):
  303. self.logic.storeInputFunction(self.patientId.currentText)
  304. def onTransformNodeButtonClicked(self):
  305. self.logic.applyTransform(self.patientId.currentText, self.refPatientId.currentText,
  306. self.time_frame_select.minimum,self.time_frame_select.maximum)
  307. #def onAddFrameButtonClicked(self):
  308. # it=int(self.time_frame_select.text)
  309. # self.logic.addFrame(it)
  310. # def onAddCTButtonClicked(self):
  311. # self.logic.addCT()
  312. #
  313. #
  314. # cardiacSPECTLogic
  315. #
  316. class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
  317. """This class should implement all the actual
  318. computation done by your module. The interface
  319. should be such that other python code can import
  320. this class and make use of the functionality without
  321. requiring an instance of the Widget.
  322. Uses ScriptedLoadableModuleLogic base class, available at:
  323. https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  324. """
  325. def __init__(self,config):
  326. ScriptedLoadableModuleLogic.__init__(self)
  327. sconfig=os.path.join(os.path.expanduser('~'),'.labkey','setup.json')
  328. with open(sconfig) as f:
  329. setup=json.load(f)
  330. sys.path.append(setup['paths']['labkeyInterface'])
  331. import labkeyInterface
  332. import labkeyDatabaseBrowser
  333. import labkeyFileBrowser
  334. self.net=labkeyInterface.labkeyInterface()
  335. fconfig=os.path.join(os.path.expanduser('~'),'.labkey','network.json')
  336. self.net.init(fconfig)
  337. self.db=labkeyDatabaseBrowser.labkeyDB(self.net)
  338. self.fb=labkeyFileBrowser.labkeyFileBrowser(self.net)
  339. sys.path.append(setup['paths']['parseDicom'])
  340. import parseDicom
  341. self.pd=parseDicom.parseDicom
  342. self.pd.setFileBrowser(self.fb)
  343. self.tempPath=os.path.join(os.path.expanduser('~'),\
  344. 'temp','cardiacSPECT')
  345. if not os.path.isdir(self.tempPath):
  346. os.makedirs(self.tempPath)
  347. self.pd.setTempBase(self.tempPath)
  348. sys.path.append(setup['paths']['resample'])
  349. import resample
  350. self.resampler=resample.resampleLogic(None)
  351. sys.path.append(setup['paths']['vtkInterface'])
  352. self.cfg=config
  353. def loadData(self,widget):
  354. import vtkInterface
  355. #calculate inputDir from data on form
  356. inputDir=str(widget.dataPath.text)
  357. self.pd.readMasterDirectory(inputDir)
  358. self.frame_data, self.frame_time, self.frame_duration, self.frame_origin, \
  359. self.frame_pixel_size, \
  360. self.frame_orientation=self.pd.readNMDirectory(inputDir)
  361. self.ct_data,self.ct_origin,self.ct_pixel_size, \
  362. self.ct_orientation=self.pd.readCTDirectory(inputDir)
  363. self.ct_orientation=vtkInterface.completeOrientation(self.ct_orientation)
  364. self.frame_orientation=vtkInterface.completeOrientation(self.frame_orientation)
  365. self.addCT('test')
  366. self.addFrames('test')
  367. widget.time_frame_select.setMaximum(self.frame_data.shape[3]-1)
  368. #additional message via qt
  369. qt.QMessageBox.information(
  370. slicer.util.mainWindow(),
  371. 'Slicer Python','Data loaded')
  372. def getFilespecPath(self,r):
  373. if self.cfg["remote"]==True:
  374. return self.fb.formatPathURL(self.cfg['transfer']['project'],\
  375. '/'.join([self.cfg['dicomDir'],r['Study'],r['Series']]))
  376. else:
  377. path=os.path.join(self.cfg["dicomPath"],r["Study"],r["Series"])
  378. return path
  379. def loadPatient(self,patientId):
  380. import vtkInterface
  381. print("Loading {}".format(patientId))
  382. idFilter={'variable':'aliasID','value':patientId,'oper':'eq'}
  383. ds=self.db.selectRows(self.cfg['project'],self.cfg['schemaName'],
  384. self.cfg['queryName'],[idFilter])
  385. visit=ds['rows'][0]
  386. print(visit)
  387. idFilter={'variable':'PatientId','value':visit['aliasID'],'oper':'eq'}
  388. dicoms=self.db.selectRows(self.cfg['transfer']['project'],
  389. self.cfg['transfer']['schemaName'],
  390. self.cfg['transfer']['queryName'],
  391. [idFilter])
  392. for r in dicoms['rows']:
  393. if abs(r['SequenceNum']-float(visit['nmMaster']))<0.1:
  394. masterPath=self.getFilespecPath(r)
  395. #masterPath="labkey://"+path
  396. if abs(r['SequenceNum']-float(visit['nmCorrData']))<0.1:
  397. nmPath=self.getFilespecPath(r)
  398. #nmPath="labkey://"+path
  399. if abs(r['SequenceNum']-float(visit['ctData']))<0.1:
  400. ctPath=self.getFilespecPath(r)
  401. #ctPath="labkey://"+path
  402. self.pd.readMasterDirectory(masterPath)
  403. self.frame_data, self.frame_time, self.frame_duration,self.frame_origin, \
  404. self.frame_pixel_size, self.frame_orientation=\
  405. self.pd.readNMDirectory(nmPath)
  406. self.ct_data,self.ct_origin,self.ct_pixel_size, \
  407. self.ct_orientation=self.pd.readCTDirectory(ctPath)
  408. self.ct_orientation=vtkInterface.completeOrientation(self.ct_orientation)
  409. self.frame_orientation=vtkInterface.completeOrientation(self.frame_orientation)
  410. self.addCT(patientId)
  411. self.addFrames(patientId)
  412. def loadPatientNRRD(self,patientId):
  413. print("Loading NRRD {}".format(patientId))
  414. self.loadDummyInputFunction(patientId)
  415. dnsNode=slicer.util.getFirstNodeByName(patientId+'_Dummy')
  416. if dnsNode==None:
  417. print("Could not find dummy double array node")
  418. return
  419. n=dnsNode.GetSize();
  420. self.frame_time=np.zeros(n);
  421. self.frame_duration=np.zeros(n);
  422. a=vtk.reference(1)
  423. for i in range(0,n):
  424. self.loadVolume(patientId,i)
  425. self.frame_time[i]=dnsNode.GetValue(i,0,a)
  426. self.frame_duration[i]=dnsNode.GetValue(i,1,a)
  427. self.loadCTVolume(patientId)
  428. self.loadSegmentation(patientId)
  429. def loadDummyInputFunction(self,patientId):
  430. self.loadNode(patientId,patientId+'_Dummy','DoubleArrayFile','.mcsv')
  431. def loadVolume(self,patientId,i):
  432. self.loadNode(patientId,patientId+'Volume'+str(i),'VolumeFile')
  433. def loadCTVolume(self,patientId):
  434. self.loadNode(patientId,patientId+'CT','VolumeFile')
  435. def loadModelVolume(self,patientId,name):
  436. node=self.loadNode(patientId,name,'VolumeFile')
  437. if node:
  438. node.SetName(patientId+'_'+name)
  439. def loadSegmentation(self,patientId):
  440. self.loadNode(patientId,'Heart','SegmentationFile')
  441. def loadNode(self,patientId,fName,type,suffix='.nrrd'):
  442. remotePath=self.fb.formatPathURL(self.cfg['project'],
  443. '/'.join([patientId]))
  444. labkeyFile=remotePath+'/'+fName+suffix
  445. localFile=os.path.join(self.tempPath,fName+suffix)
  446. self.fb.readFileToFile(labkeyFile,localFile)
  447. print("Remote: {}".format(labkeyFile))
  448. try:
  449. node=slicer.util.loadNodeFromFile(localFile,type)
  450. except RuntimeError:
  451. return None
  452. os.remove(localFile)
  453. return node
  454. def addNode(self,nodeName,v, lpsOrigin, pixel_size, \
  455. lpsOrientation, dataType):
  456. # if dataType=0 it is CT data, which gets propagated to background an is
  457. #used to fit the view field dimensions
  458. # if dataType=1, it is SPECT data, which gets propagated to foreground
  459. #and is not fit; keeping window set from CT
  460. #nodeName='testVolume'+str(it)
  461. newNode=slicer.vtkMRMLScalarVolumeNode()
  462. newNode.SetName(nodeName)
  463. #pixel_size=[0,0,0]
  464. #pixel_size=v.GetSpacing()
  465. #print(pixel_size)
  466. #origin=[0,0,0]
  467. #origin=v.GetOrigin()
  468. v.SetOrigin([0,0,0])
  469. v.SetSpacing([1,1,1])
  470. ijkToRAS = vtk.vtkMatrix4x4()
  471. #think how to do this with image orientation
  472. rasOrientation=[-lpsOrientation[i] if (i%3 < 2) else lpsOrientation[i]
  473. for i in range(0,len(lpsOrientation))]
  474. rasOrigin=[-lpsOrigin[i] if (i%3<2) else lpsOrigin[i] for i in range(0,len(lpsOrigin))]
  475. for i in range(0,3):
  476. for j in range(0,3):
  477. ijkToRAS.SetElement(i,j,pixel_size[i]*rasOrientation[3*j+i])
  478. ijkToRAS.SetElement(i,3,rasOrigin[i])
  479. newNode.SetIJKToRASMatrix(ijkToRAS)
  480. newNode.SetAndObserveImageData(v)
  481. slicer.mrmlScene.AddNode(newNode)
  482. def addFrames(self,patientId):
  483. import vtkInterface
  484. #convert data from numpy.array to vtkImageData
  485. #use time point it
  486. print("NFrames: {}".format(self.frame_data.shape[3]))
  487. for it in range(0,self.frame_data.shape[3]):
  488. frame_data=self.frame_data[:,:,:,it];
  489. nodeName=patientId+'Volume'+str(it)
  490. self.addNode(nodeName,
  491. vtkInterface.numpyToVTK(frame_data,frame_data.shape),
  492. self.frame_origin,
  493. self.frame_pixel_size,
  494. self.frame_orientation,1)
  495. def addCT(self,patientId):
  496. import vtkInterface
  497. nodeName=patientId+'CT'
  498. self.addNode(nodeName,
  499. #vi.numpyToVTK3D(self.ct_data,
  500. # self.ct_origin,self.ct_pixel_size),
  501. vtkInterface.numpyToVTK(self.ct_data,self.ct_data.shape),
  502. self.ct_origin,self.ct_pixel_size,
  503. self.ct_orientation,0)
  504. def rFromI(i,volumeNode):
  505. ijkToRas = vtk.vtkMatrix4x4()
  506. volumeNode.GetIJKToRASMatrix(ijkToRas)
  507. vImage=volumeNode.GetImageData()
  508. i1=list(vImage.GetPoint(i))
  509. i1=i1.append(1)
  510. #ras are global coordinates (in mm)
  511. position_ras=ijkToRas.MultiplyPoint(i1)
  512. return position_ras[0:3]
  513. def IfromR(pos,volumeNode):
  514. fM=vtk.vtkMatrix4x4()
  515. volumeNode.GetRASToIJKMatrix(fM)
  516. fM.MultiplyPoint(pos)
  517. vImage=volumeNode.GetImageData()
  518. #nearest neighbor
  519. return vImage.FindPoint(pos[0:3])
  520. def getMaskPos(self,mask,i):
  521. maskIJK=mask.GetPoint(i)
  522. maskIJK=[r-c for r,c in zip(maskIJK,mask.GetOrigin())]
  523. maskIJK=[r/s for r,s in zip(maskIJK,mask.GetSpacing())]
  524. #this is now in extent spacing, whitch ImageToWorldMatrix understands
  525. #to 4D vector for vtkMatrix4x4 handling
  526. maskIJK.append(1)
  527. #go to ras, global coordinates (in mm)
  528. maskImageToWorldMatrix=vtk.vtkMatrix4x4()
  529. mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
  530. maskPos=maskImageToWorldMatrix.MultiplyPoint(maskIJK)
  531. return maskPos[0:3]
  532. def meanROI(self, volName1, i):
  533. s=0
  534. #get the segmentation mask
  535. fNode=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode").GetItemAsObject(0)
  536. print("Found segmentation node: {}".format(fNode.GetName()))
  537. segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
  538. #no python bindings for vtkSegmentation
  539. #if segNode.GetSegmentation().GetNumberOfSegments()==0 :
  540. # print("No segments available")
  541. # return 0
  542. #edit here to change for more segments
  543. segment=segNode.GetSegmentation().GetNthSegmentID(int(i))
  544. print('Computing for segment {}'.format(segment))
  545. mask=slicer.vtkOrientedImageData()
  546. segNode.GetBinaryLabelmapRepresentation(segment,mask)
  547. if mask==None:
  548. print("Segment {} not found".format(segment))
  549. return s
  550. print("Got mask for segment {}, npts {}".format(segment,mask.GetNumberOfPoints()))
  551. #get mask at (x,y,z)
  552. #mask.GetPointData().GetScalars().GetTuple1(mask.FindPoint([x,y,z]))
  553. #get the image data
  554. dataNode=slicer.mrmlScene.GetFirstNodeByName(volName1)
  555. dataImage=dataNode.GetImageData()
  556. # use IJK2RAS to get global coordinates
  557. dataRAStoIJK = vtk.vtkMatrix4x4()
  558. dataNode.GetRASToIJKMatrix(dataRAStoIJK)
  559. #allow for interpolation in segmentation pixels
  560. coeff=vtk.vtkImageBSplineCoefficients()
  561. coeff.SetInputData(dataImage)
  562. coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
  563. #between 3 and 5
  564. coeff.SetSplineDegree(5)
  565. coeff.Update()
  566. maskImageToWorldMatrix=vtk.vtkMatrix4x4()
  567. mask.GetImageToWorldMatrix(maskImageToWorldMatrix)
  568. ns=0
  569. maskN=mask.GetNumberOfPoints()
  570. maskScalars=mask.GetPointData().GetScalars()
  571. maskOrigin=[0,0,0]
  572. maskOrigin=mask.GetOrigin()
  573. for i in range(0,maskN):
  574. #skip all points that are 0
  575. if maskScalars.GetTuple1(i)==0:
  576. continue
  577. #get global coordinates of point i
  578. maskPos=self.getMaskPos(mask,i)
  579. #print("Evaluating at {}").format(maskPos)
  580. #convert from global to local
  581. dataPos=[0,0,0]
  582. #account for potentially applied transform on dataNode
  583. dataNode.TransformPointFromWorld(maskPos,dataPos)
  584. dataPos.append(1)
  585. dataIJK=dataRAStoIJK.MultiplyPoint(dataPos)
  586. #drop the 4th index
  587. dataIJK=dataIJK[0:3]
  588. #interpolate
  589. s+=coeff.Evaluate(dataIJK)
  590. ns+=1
  591. return s/ns
  592. def countSegments(self):
  593. segNodeList=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
  594. if segNodeList.GetNumberOfItems()==0:
  595. return 0
  596. fNode=segNodeList.GetItemAsObject(0)
  597. segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
  598. if fNode==None:
  599. return 0
  600. return segNode.GetSegmentation().GetNumberOfSegments()
  601. def getSegmentName(self,i):
  602. segNodeList=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
  603. if segNodeList.GetNumberOfItems()==0:
  604. return "NONE"
  605. fNode=segNodeList.GetItemAsObject(0)
  606. segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
  607. if fNode==None:
  608. return "NONE"
  609. return segNode.GetSegmentation().GetSegment(segNode.GetSegmentation().GetNthSegmentID(i)).GetName()
  610. def storeNodeRemote(self,pathList,nodeName):
  611. node=slicer.mrmlScene.GetFirstNodeByName(nodeName)
  612. if node==None:
  613. print("Node {} not found".format(nodeName))
  614. return
  615. suffix=".nrrd"
  616. if node.__class__.__name__=="vtkMRMLDoubleArrayNode":
  617. suffix=".mcsv"
  618. if (node.__class__.__name__=="vtkMRMLTransformNode" or
  619. node.__class__.__name__=="vtkMRMLGridTransformNode"):
  620. suffix=".h5"
  621. fileName=nodeName+suffix
  622. localPath=os.path.join(self.tempPath,fileName)
  623. slicer.util.saveNode(node,localPath)
  624. print("Stored to: {}".format(localPath))
  625. if self.cfg["remote"]:
  626. labkeyPath=self.fb.buildPathURL(self.cfg['project'],pathList)
  627. print ("Remote: {}".format(labkeyPath))
  628. #checks if exists
  629. remoteFile=labkeyPath+'/'+fileName
  630. self.fb.writeFileToFile(localPath,remoteFile)
  631. def storeVolumeNodes(self,patientId,n1,n2):
  632. #n1=self.time_frame.minimum;
  633. #n2=self.time_frame.maximum
  634. pathList=[patientId]
  635. print("Store CT")
  636. nodeName=patientId+'CT'
  637. self.storeNodeRemote(pathList,nodeName)
  638. #prefer resampled
  639. testNode=slicer.util.getFirstNodeByName(nodeName+"_RS")
  640. if testNode:
  641. nodeName=nodeName+"_RS"
  642. self.storeNodeRemote(pathList,nodeName)
  643. print("Storing NM from {} to {}".format(n1,n2))
  644. n=n2-n1+1
  645. for i in range(n):
  646. it=i+n1
  647. nodeName=patientId+'Volume'+str(it)
  648. self.storeNodeRemote(pathList,nodeName)
  649. #add resampled
  650. testNode=slicer.util.getFirstNodeByName(nodeName+"_RS")
  651. if testNode:
  652. nodeName=nodeName+"_RS"
  653. self.storeNodeRemote(pathList,nodeName)
  654. self.storeDummyInputFunction(patientId)
  655. def storeSegmentation(self,patientId):
  656. pathList=[patientId]
  657. segNodeName="Heart"
  658. self.storeNodeRemote(pathList,segNodeName)
  659. def storeInputFunction(self,patientId):
  660. self.calculateInputFunction(patientId)
  661. relativePath=[patientId]
  662. doubleArrayNodeName=patientId+'_Ventricle'
  663. self.storeNodeRemote(relativePath,doubleArrayNodeName)
  664. def storeDummyInputFunction(self,patientId):
  665. self.calculateDummyInputFunction(patientId)
  666. relativePath=[patientId]
  667. doubleArrayNodeName=patientId+'_Dummy'
  668. self.storeNodeRemote(relativePath,doubleArrayNodeName)
  669. def storeTransformation(self,patientId):
  670. relativePath=[patientId]
  671. transformNodeName=patientId+"_DF"
  672. self.storeNodeRemote(relativePath,transformNodeName)
  673. def applyTransform(self, patientId,refPatientId,n1,n2):
  674. if patientId == refPatientId:
  675. print("Transform [{}] and reference [{}] are the same".format(patientId, refPatientId))
  676. return
  677. transformNodeName=patientId+"_DF"
  678. transformNode=slicer.util.getFirstNodeByName(transformNodeName)
  679. if transformNode==None:
  680. print("Transform node [{}] not found".format(transformNodeName))
  681. return
  682. n=n2-n1+1
  683. for i in range(n):
  684. it=i+n1
  685. nodeName=patientId+'Volume'+str(it)
  686. node=slicer.util.getFirstNodeByName(nodeName)
  687. if node==None:
  688. continue
  689. node.SetAndObserveTransformNodeID(transformNode.GetID())
  690. refNodeName=refPatientId+'Volume'+str(it)
  691. refNode=slicer.util.getFirstNodeByName(refNodeName)
  692. if refNode!=None:
  693. self.resampler.rebinNode(node,refNode)
  694. print("Completed transformation {}".format(it))
  695. #unset transformation
  696. node.SetAndObserveTransformNodeID('NONE')
  697. nodeName=patientId+'CT'
  698. node=slicer.util.getFirstNodeByName(nodeName)
  699. if not node==None:
  700. node.SetAndObserveTransformNodeID(transformNode.GetID())
  701. refNodeName=refPatientId+'CT'
  702. refNode=slicer.util.getFirstNodeByName(refNodeName)
  703. if refNode!=None:
  704. self.resampler.rebinNode(node,refNode)
  705. node.SetAndObserveTransformNodeID('NONE')
  706. def calculateInputFunction(self,patientId):
  707. debug=True
  708. n=len(self.frame_time)
  709. dnsNodeName=patientId+'_Ventricle'
  710. dns = slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode',dnsNodeName)
  711. if dns.GetNumberOfItems() == 0:
  712. dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
  713. dn.SetName(dnsNodeName)
  714. else:
  715. dn = dns.GetItemAsObject(0)
  716. dn.SetSize(n)
  717. fNodes=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
  718. if fNodes.GetNumberOfItems() == 0:
  719. return
  720. fNode=fNodes.GetItemAsObject(0)
  721. segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
  722. segmentation=segNode.GetSegmentation()
  723. juse=-1
  724. for j in range(0,segmentation.GetNumberOfSegments()):
  725. segmentID=segNode.GetSegmentation().GetNthSegmentID(j)
  726. segment=segNode.GetSegmentation().GetSegment(segmentID)
  727. if segment.GetName()=='Ventricle':
  728. juse=j
  729. break
  730. if juse<0:
  731. print('Failed to find Ventricle segment')
  732. return
  733. dt=0;
  734. t0=0;
  735. ft=self.frame_time
  736. dt=self.frame_duration
  737. for i in range(0,n):
  738. vol=patientId+"Volume"+str(i)
  739. fx=ft[i]
  740. fy=self.meanROI(vol,juse)
  741. if debug:
  742. print('{}: t0={} tp={} dt={}'.format(i,t0,fx,dt))
  743. t0+=dt[i]
  744. dn.SetValue(i, 0, fx)
  745. dn.SetValue(i, 1, fy/dt[i])
  746. dn.SetValue(i, 2, 0)
  747. print("[{0} at {1:.2f}:{2:.2f}]".format(vol,fx,fy))
  748. def calculateDummyInputFunction(self,patientId):
  749. n=self.frame_data.shape[3]
  750. dnsNodeName=patientId+'_Dummy'
  751. dns = slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode',dnsNodeName)
  752. if dns.GetNumberOfItems() == 0:
  753. dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
  754. dn.SetName(dnsNodeName)
  755. else:
  756. dn = dns.GetItemAsObject(0)
  757. dn.SetSize(n)
  758. ft=self.frame_time
  759. dt=self.frame_duration
  760. for i in range(0,n):
  761. fx=ft[i]
  762. fy=dt[i]
  763. dn.SetValue(i, 0, fx)
  764. dn.SetValue(i, 1, fy)
  765. dn.SetValue(i, 2, 0)
  766. class cardiacSPECTTest(ScriptedLoadableModuleTest):
  767. """
  768. This is the test case for your scripted module.
  769. Uses ScriptedLoadableModuleTest base class, available at:
  770. https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  771. """
  772. def setUp(self):
  773. """ Do whatever is needed to reset the state - typically a scene clear will be enough.
  774. """
  775. slicer.mrmlScene.Clear(0)
  776. def runTest(self):
  777. """Run as few or as many tests as needed here.
  778. """
  779. self.setUp()
  780. self.test_cardiacSPECT1()
  781. def test_cardiacSPECT1(self):
  782. """ Ideally you should have several levels of tests. At the lowest level
  783. tests should exercise the functionality of the logic with different inputs
  784. (both valid and invalid). At higher levels your tests should emulate the
  785. way the user would interact with your code and confirm that it still works
  786. the way you intended.
  787. One of the most important features of the tests is that it should alert other
  788. developers when their changes will have an impact on the behavior of your
  789. module. For example, if a developer removes a feature that you depend on,
  790. your test should break so they know that the feature is needed.
  791. """
  792. self.delayDisplay("Starting the test")
  793. #
  794. # first, get some data
  795. #
  796. self.delayDisplay('Test passed!')