# runSegmentationDM.py (8.1 KB)


  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
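# Importing this module only defines functions; processing runs only when the file is executed as a script (see the __main__ guard at the bottom).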
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. #return tempFile
  43. sfx=getSuffix(tempFile)
  44. return re.sub(sfx,'_Segm'+sfx,tempFile)
  45. def addVersion(tempFile,version):
  46. sfx=getSuffix(tempFile)
  47. return re.sub(sfx,'_'+version+sfx,tempFile)
  48. def normalizeCT(ctFile,maskFile):
  49. im=SimpleITK.ReadImage(ctFile)
  50. mask=SimpleITK.ReadImage(maskFile)
  51. nm=SimpleITK.GetArrayFromImage(im)
  52. nmask=SimpleITK.GetArrayFromImage(mask)
  53. mu=numpy.mean(nm[nmask>0])
  54. st=numpy.std(nm[nmask>0])
  55. nm[nmask>0]=(nm[nmask>0]-mu)/st
  56. nm[nmask==0]=0
  57. im1=SimpleITK.GetImageFromArray(nm)
  58. im1.SetOrigin(im.GetOrigin())
  59. im1.SetSpacing(im.GetSpacing())
  60. im1.SetDirection(im.GetDirection())
  61. SimpleITK.WriteImage(im1,ctFile)
  62. def runDeepMedic(setup,pars):
  63. args=[]
  64. #args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
  65. args.append(setup['paths']['deepMedicRun'])
  66. args.append('-model')
  67. args.append(pars['deepmedic']['config']['model']['out'])
  68. args.append('-test')
  69. args.append(pars['deepmedic']['config']['test']['out'])
  70. args.append('-dev')
  71. args.append('cpu')
  72. print(args)
  73. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  74. def runDeepMedicDocker(setup,pars):
  75. args=[]
  76. args.extend(['docker-compose','-f',pars['deepmedic']['segmentationdmYAML'],'up'])
  77. print(args)
  78. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  79. def getSegmentationFile(pars):
  80. #this is how deep medic stores files
  81. return getSegmImagePath(\
  82. os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
  83. pars['images']['segmentation']['tempFile'])
  84. )
  85. def runSegmentation(fb,row,pars,setup):
  86. #download to temp file (could be a fixed name)
  87. project=pars['project']
  88. images=pars['images']
  89. participantField=pars['participantField']
  90. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  91. getPatientLabel(row,participantField)+'/'+\
  92. getVisitLabel(row))
  93. #download
  94. fullFile={key:os.path.join(pars['tempBase'],images[key]['tempFile']) for key in images}
  95. for im in images:
  96. if 'queryField' in images[im]:
  97. fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],fullFile[im])
  98. print('Loaded {}'.format(fullFile[im]))
  99. #normalize
  100. normalizeCT(fullFile['CT'],fullFile['patientmask'])
  101. #update templates to know which files to process
  102. #run deep medic
  103. #runDeepMedicDocker(setup,pars)
  104. runDeepMedic(setup,pars)
  105. #processed file is
  106. segFile=getSegmentationFile(pars)
  107. #SimpleITK.WriteImage(outImg,segFile)
  108. return segFile
  109. def main(parameterFile):
  110. fhome=os.path.expanduser('~')
  111. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  112. setup=json.load(f)
  113. sys.path.insert(0,setup["paths"]["nixWrapper"])
  114. import nixWrapper
  115. nixWrapper.loadLibrary("labkeyInterface")
  116. import labkeyInterface
  117. import labkeyDatabaseBrowser
  118. import labkeyFileBrowser
  119. nixWrapper.loadLibrary("parseConfig")
  120. import parseConfig
  121. with open(parameterFile) as f:
  122. pars=json.load(f)
  123. pars=parseConfig.convert(pars)
  124. pars=parseConfig.convertValues(pars)
  125. print(pars)
  126. #images=pars['images']
  127. #ctFile=os.path.join(pars['tempBase'],images['CT']['tempFile'])
  128. #maskFile=os.path.join(pars['tempBase'],images['patientmask']['tempFile'])
  129. #normalizeCT(ctFile,maskFile)
  130. def doSegmentation(parameterFile):
  131. fhome=os.path.expanduser('~')
  132. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  133. setup=json.load(f)
  134. sys.path.insert(0,setup["paths"]["nixWrapper"])
  135. import nixWrapper
  136. nixWrapper.loadLibrary("labkeyInterface")
  137. import labkeyInterface
  138. import labkeyDatabaseBrowser
  139. import labkeyFileBrowser
  140. nixWrapper.loadLibrary("parseConfig")
  141. import parseConfig
  142. fconfig=os.path.join(fhome,'.labkey','network.json')
  143. net=labkeyInterface.labkeyInterface()
  144. net.init(fconfig)
  145. db=labkeyDatabaseBrowser.labkeyDB(net)
  146. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  147. with open(parameterFile) as f:
  148. pars=json.load(f)
  149. pars=parseConfig.convert(pars,setup)
  150. pars=parseConfig.convertValues(pars)
  151. print(pars)
  152. #update the config
  153. cfg=pars['deepmedic']['config']
  154. for c in cfg:
  155. replacePatterns(cfg[c]['template'],\
  156. cfg[c]['out'],\
  157. pars['replacePattern'])
  158. project=pars['project']
  159. dataset=pars['targetQuery']
  160. schema=pars['targetSchema']
  161. tempBase=pars['tempBase']
  162. if not os.path.isdir(tempBase):
  163. os.makedirs(tempBase)
  164. #all images from database
  165. qFilter=pars['entryFilter']
  166. ds=db.selectRows(project,schema,dataset,qFilter)
  167. print('Got {} rows'.format(len(ds['rows'])))
  168. #input
  169. #use webdav to transfer file (even though it is localhost)
  170. i=0
  171. for row in ds["rows"]:
  172. #check if file is already there
  173. #dummy tf to get the suffix
  174. tf=getSegmentationFile(pars)
  175. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  176. outName=addVersion(\
  177. getSegmImagePath(\
  178. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  179. pars['version'])
  180. outFile=outpath+'/'+outName
  181. #check if file is there
  182. if not fb.entryExists(outFile):
  183. segFile=runSegmentation(fb,row,pars,setup)
  184. #copy file to file
  185. #normally I would update the targetQuery, but it contains previously set images
  186. #copy to labkey
  187. fb.writeFileToFile(segFile,outFile)
  188. else:
  189. print('File {} available'.format(outFile))
  190. #separate script (set version!)
  191. #update database
  192. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  193. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  194. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  195. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  196. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  197. if len(ds1['rows'])>0:
  198. mode='update'
  199. outRow=ds1['rows'][0]
  200. else:
  201. mode='insert'
  202. outRow={v:row[v] for v in copyFields}
  203. outRow['Version']= pars['version']
  204. outRow['Segmentation']= outName
  205. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  206. #push results back to LabKey
  207. i+=1
  208. if i==1 and pars['debug']:
  209. break
  210. print("Done")
  211. if __name__ == '__main__':
  212. #main(sys.argv[1])
  213. doSegmentation(sys.argv[1])
  214. #sys.exit()