# parseDICOM.py

  1. import DICOMLib
  2. import sys
  3. import json
  4. import numpy
  5. import zipfile
  6. import shutil
  7. import os
  8. import signal
  9. from contextlib import contextmanager
  10. def main(configFile=None):
  11. print('Imported!')
  12. with open(configFile) as f:
  13. config=json.load(f)
  14. config.update(connectDB(config))
  15. parseData(config,getMeanHeartDose)
  16. class TimeoutException(Exception): pass
  17. @contextmanager
  18. def time_limit(seconds):
  19. def signal_handler(signum, frame):
  20. raise TimeoutException("Timed out!")
  21. signal.signal(signal.SIGALRM, signal_handler)
  22. signal.alarm(seconds)
  23. try:
  24. yield
  25. finally:
  26. signal.alarm(0)
  27. def connectDB(setup):
  28. nixSuite=os.path.join(os.path.expanduser('~'),'software','src','nixsuite')
  29. sys.path.append(os.path.join(nixSuite,'wrapper'))
  30. import nixWrapper
  31. nixWrapper.loadLibrary('labkeyInterface')
  32. import labkeyInterface
  33. import labkeyDatabaseBrowser
  34. import labkeyFileBrowser
  35. nixWrapper.loadLibrary('orthancInterface')
  36. import orthancInterface
  37. import orthancDatabaseBrowser
  38. import orthancFileBrowser
  39. importlib.reload(orthancFileBrowser)
  40. net=labkeyInterface.labkeyInterface()
  41. qfile='{}.json'.format(setup['server'])
  42. fconfig=os.path.join(os.path.expanduser('~'),'.labkey',qfile)
  43. net.init(fconfig)
  44. net.getCSRF()
  45. onet=orthancInterface.orthancInterface()
  46. onet.init(fconfig)
  47. return {"db":labkeyDatabaseBrowser.labkeyDB(net),
  48. "fb":labkeyFileBrowser.labkeyFileBrowser(net),
  49. "odb":orthancDatabaseBrowser.orthancDB(onet),
  50. "ofb":orthancFileBrowser.orthancFileBrowser(onet)}
  51. #explicit template
  52. def _updateRow(config,r):
  53. print(r)
  54. return False
  55. def parseData(config,updateRow=_updateRow):
  56. #iterates over data appliying updateRow function to every row
  57. #updateRow is an implementation of a generic function
  58. #with arguments
  59. # def updateRow(config,r)
  60. #returning True if row needs to be updated on the server
  61. #and False otherwise
  62. #update values are stored in the r dict
  63. db=config['db']
  64. qFilter=config.get('qFilter',[])
  65. debug=config.get('debug',False)
  66. #get dataset
  67. ds=db.selectRows(config['project'],config['schema'],config['query'],qFilter)
  68. rows=ds['rows']
  69. #shorten list in debug mode
  70. if debug:
  71. rows=rows[0:3]
  72. for r in rows:
  73. #this could be a generic function either as parameter of config or an argument to parseData
  74. update=updateRow(config,r)
  75. #print(r)
  76. if not update:
  77. continue
  78. db.modifyRows('update',config['project'],config['schema'],config['query'],[r])
  79. def getMeanHeartDose(config,r):
  80. #calculates mean heart dose
  81. #stores it as doseHeart to row r
  82. #return True if r needs to be updated on the server
  83. #and False if r is unchanged
  84. pid=r['ParticipantId']
  85. sid=r['orthancStudyId']
  86. if not sid:
  87. print(f'No study for {pid}')
  88. return False
  89. if r['comments']=='TIMEOUT':
  90. print(f'Skipping {pid} - timeout')
  91. return False
  92. timeout=config.get('timeout',1200)
  93. doseHeart=r['doseHeart']
  94. if doseHeart:
  95. #no need to update
  96. return False
  97. #outDir=getDicomZip(config,sid)
  98. outDir=getDicomInstances(config,sid)
  99. nodes=loadVolumes(outDir)
  100. msg=checkNodes(config,nodes)
  101. if len(msg)>0:
  102. r['comments']=msg
  103. clearDir(outDir)
  104. removeNodes(nodes)
  105. return True
  106. print(f'Running with timeout={timeout} s')
  107. r['doseHeart']=getMeanDoseRibbon(nodes,'Heart')
  108. clearDir(outDir)
  109. #needs updating
  110. #remove all created nodes from scene
  111. removeNodes(nodes)
  112. return True
  113. def loadVolumes(dataDir):
  114. nodeNames=[]
  115. with DICOMLib.DICOMUtils.TemporaryDICOMDatabase() as db:
  116. DICOMLib.DICOMUtils.importDicom(dataDir, db)
  117. patientUIDs = db.patients()
  118. for patientUID in patientUIDs:
  119. print(patientUID)
  120. nodeNames.extend(DICOMLib.DICOMUtils.loadPatientByUID(patientUID))
  121. #print(nodes)
  122. nodes=[slicer.util.getNode(pattern=n) for n in nodeNames]
  123. volumeNodes=[n for n in nodes if n.GetClassName()=='vtkMRMLScalarVolumeNode']
  124. doseNodes=[n for n in volumeNodes if n.GetName().find('RTDOSE')>-1]
  125. segmentationNodes=[n for n in nodes if n.GetClassName()=='vtkMRMLSegmentationNode']
  126. nv=len(volumeNodes)
  127. ns=len(segmentationNodes)
  128. nd=len(doseNodes)
  129. print(f'vol:{nv} seg:{ns} dose: {nd}')
  130. return {'all':nodes,'vol':volumeNodes,'dose':doseNodes,'seg':segmentationNodes}
  131. def checkNodes(config,nodes):
  132. msg=''
  133. nD=len(nodes['dose'])
  134. if nD>1:
  135. msg+=f'DOSE[{nD}]'
  136. nS=len(nodes['seg'])
  137. if nS>1:
  138. if len(msg)>0:
  139. msg+='/'
  140. msg+=f'SEG[{nS}]'
  141. return msg
  142. def removeNodes(nodes):
  143. for node in nodes['all']:
  144. slicer.mrmlScene.RemoveNode(node)
  145. def getMeanDose(nodes,target):
  146. segNode=nodes['seg'][0]
  147. seg=segNode.GetSegmentation()
  148. segmentIds=seg.GetSegmentIDs()
  149. targetSegmentIds=[s for s in segmentIds if seg.GetSegment(s).GetName()==target]
  150. doseNode=nodes['dose'][0]
  151. doseArray = slicer.util.arrayFromVolume(doseNode)
  152. print('Dose array shape: {}'.format(doseArray.shape))
  153. export=slicer.util.arrayFromSegmentBinaryLabelmap
  154. segmentArray = export(segNode, targetSegmentIds[0], doseNode)
  155. print('Segment array shape: {}'.format(segmentArray.shape))
  156. doseVoxels = doseArray[segmentArray != 0]
  157. meanDose=float(numpy.mean(doseVoxels))
  158. print(f'Dose {meanDose}')
  159. #add a float() to avoid JSON complaining about float32 converion
  160. return meanDose
  161. def getMeanDoseRibbon(nodes,target):
  162. segNode=nodes['seg'][0]
  163. seg=segNode.GetSegmentation()
  164. seg.CreateRepresentation('Ribbon model')
  165. seg.SetSourceRepresentationName('Ribbon model')
  166. segmentIds=seg.GetSegmentIDs()
  167. targetSegmentIds=[s for s in segmentIds if seg.GetSegment(s).GetName()==target]
  168. doseNode=nodes['dose'][0]
  169. doseArray = slicer.util.arrayFromVolume(doseNode)
  170. print('Dose array shape: {}'.format(doseArray.shape))
  171. export=slicer.util.arrayFromSegmentBinaryLabelmap
  172. segmentArray = export(segNode, targetSegmentIds[0], doseNode)
  173. print('Segment array shape: {}'.format(segmentArray.shape))
  174. doseVoxels = doseArray[segmentArray != 0]
  175. meanDose=float(numpy.mean(doseVoxels))
  176. print(f'Dose {meanDose}')
  177. #add a float() to avoid JSON complaining about float32 converion
  178. return meanDose
  179. def getMeanDoseSlicerRT(nodes,target):
  180. #use SlicerRT BatchProcessing example to extract LabelMaps
  181. #alternatove to getMeanDose since it was taking forever to build label map
  182. #no success, still takes forever
  183. segNode=nodes['seg'][0]
  184. seg=segNode.GetSegmentation()
  185. segmentIds=seg.GetSegmentIDs()
  186. targetSegmentIds=[s for s in segmentIds if seg.GetSegment(s).GetName()==target]
  187. doseNode=nodes['dose'][0]
  188. labelMapNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
  189. export=slicer.vtkSlicerSegmentationsModuleLogic.ExportSegmentsToLabelmapNode
  190. export(segNode, targetSegmentIds, labelMapNode, doseNode)
  191. doseArray = slicer.util.arrayFromVolume(doseNode)
  192. print('SlicerRT: Dose array shape: {}'.format(doseArray.shape))
  193. segmentArray = slicer.util.arrayFromVolume(labelMapNode)
  194. print('SlicerRT: Segment array shape: {}'.format(segmentArray.shape))
  195. doseVoxels = doseArray[segmentArray != 0]
  196. #add a float() to avoid JSON complaining about float32 converion
  197. meanDose= float(numpy.mean(doseVoxels))
  198. print(f'SlicerRT: Dose {meanDose} Gy')
  199. slicer.mrmlScene.removeNode(labelMapNode)
  200. return meanDose
  201. def getDicomInstances(config,sid):
  202. odb=config['odb']
  203. ofb=config['ofb']
  204. sd=odb.getStudyData(sid)
  205. series=sd['Series']
  206. instances=[]
  207. validModalities=['RTSTRUCT','RTDOSE']
  208. for s in series:
  209. sed=odb.getSeriesData(s)
  210. if sed['MainDicomTags']['Modality'] not in validModalities:
  211. continue
  212. instances.extend(sed['Instances'])
  213. #download instances one by one
  214. baseDir=config.get('baseDir',os.path.join(os.path.expanduser('~'),'temp'))
  215. outDir=os.path.join(baseDir,sid)
  216. clearDir(outDir)
  217. os.mkdir(outDir)
  218. for oid in instances:
  219. local=os.path.join(outDir,f'{oid}.dcm')
  220. ofb.getInstance(oid,local)
  221. return outDir
  222. def getDicomZip(config,sid):
  223. ofb=config['ofb']
  224. baseDir=config.get('baseDir',os.path.join(os.path.expanduser('~'),'temp'))
  225. fname=f'{sid}.zip'
  226. path=os.path.join(baseDir,fname)
  227. if not os.path.isfile(path):
  228. ofb.getZip('studies',sid,path,'archive')
  229. print(f'Using {path}')
  230. #unzip path
  231. outDir=extractZip(config,path)
  232. os.remove(path)
  233. return outDir
  234. def extractZip(config,fname):
  235. #flattens the zip files in the baseDir/bname directory
  236. #where bname is the basename of the file without the .zip suffix
  237. fzip=zipfile.ZipFile(fname)
  238. names=fzip.namelist()
  239. bname=os.path.basename(fname)
  240. bname=bname.replace('.zip','')
  241. baseDir=config['baseDir']
  242. outDir=os.path.join(baseDir,bname)
  243. #clean
  244. clearDir(outDir)
  245. os.mkdir(outDir)
  246. outnames=[os.path.join(outDir,f'out{i:03d}.dcm') for i in range(len(names))]
  247. #extracts and renames (avoids *nix and win directory separator confusion)
  248. for (member,out) in zip(names,outnames):
  249. with fzip.open(member) as zf, open(out, 'wb') as f:
  250. shutil.copyfileobj(zf, f)
  251. return outDir
  252. def clearDir(outDir):
  253. if os.path.isdir(outDir):
  254. shutil.rmtree(outDir)
  255. if __name__=='__main__':
  256. try:
  257. main(sys.argv[1])
  258. except IndexError:
  259. main()
  260. print('Succesful completion')
  261. quit()