# runSegmentation.py
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module only defines functions; nothing runs until it is executed as a script.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. sfx=getSuffix(tempFile)
  43. return re.sub(sfx,'_Segm'+sfx,tempFile)
  44. def addVersion(tempFile,version):
  45. sfx=getSuffix(tempFile)
  46. return re.sub(sfx,'_'+version+sfx,tempFile)
  47. def normalizeCT(ctFile,maskFile):
  48. im=SimpleITK.ReadImage(ctFile)
  49. mask=SimpleITK.ReadImage(maskFile)
  50. nm=SimpleITK.GetArrayFromImage(im)
  51. nmask=SimpleITK.GetArrayFromImage(mask)
  52. mu=numpy.mean(nm[nmask>0])
  53. st=numpy.std(nm[nmask>0])
  54. nm[nmask>0]=(nm[nmask>0]-mu)/st
  55. nm[nmask==0]=0
  56. im1=SimpleITK.GetImageFromArray(nm)
  57. im1.SetOrigin(im.GetOrigin())
  58. im1.SetSpacing(im.GetSpacing())
  59. im1.SetDirection(im.GetDirection())
  60. SimpleITK.WriteImage(im1,ctFile)
  61. def runDeepMedic(setup,pars):
  62. args=[]
  63. args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
  64. args.append(setup['paths']['deepMedicRun'])
  65. args.append('-model')
  66. args.append(pars['deepmedic']['config']['model']['out'])
  67. args.append('-test')
  68. args.append(pars['deepmedic']['config']['test']['out'])
  69. args.append('-dev')
  70. args.append('cpu')
  71. print(args)
  72. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  73. def runDeepMedicDocker(setup,pars):
  74. args=[]
  75. args.extend(['docker-compose','-f',pars['deepmedic']['segmentationdmYAML'],'up'])
  76. print(args)
  77. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  78. def getSegmentationFile(pars):
  79. #this is how deep medic stores files
  80. return getSegmImagePath(\
  81. os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
  82. pars['images']['segmentations']['tempFile'])
  83. )
  84. def runSegmentation(fb,row,pars,setup):
  85. #download to temp file (could be a fixed name)
  86. project=pars['project']
  87. images=pars['images']
  88. participantField=pars['participantField']
  89. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  90. getPatientLabel(row,participantField)+'/'+\
  91. getVisitLabel(row))
  92. #download
  93. fullFile={key:os.path.join(pars['tempBase'],images[key]['tempFile']) for key in images}
  94. for im in images:
  95. if 'queryField' in images[im]:
  96. fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],fullFile[im])
  97. #normalize
  98. normalizeCT(fullFile['CT'],fullFile['patientmask'])
  99. #update templates to know which files to process
  100. #run deep medic
  101. runDeepMedicDocker(setup,pars)
  102. #processed file is
  103. segFile=getSegmentationFile(pars)
  104. #SimpleITK.WriteImage(outImg,segFile)
  105. return segFile
  106. def main(parameterFile):
  107. fhome=os.path.expanduser('~')
  108. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  109. setup=json.load(f)
  110. sys.path.insert(0,setup["paths"]["labkeyInterface"])
  111. import labkeyInterface
  112. import labkeyDatabaseBrowser
  113. import labkeyFileBrowser
  114. sys.path.append(setup['paths']['parseConfig'])
  115. import parseConfig
  116. with open(parameterFile) as f:
  117. pars=json.load(f)
  118. pars=parseConfig.convert(pars)
  119. pars=parseConfig.convertValues(pars)
  120. print(pars)
  121. #images=pars['images']
  122. #ctFile=os.path.join(pars['tempBase'],images['CT']['tempFile'])
  123. #maskFile=os.path.join(pars['tempBase'],images['patientmask']['tempFile'])
  124. #normalizeCT(ctFile,maskFile)
  125. def doSegmentation(parameterFile):
  126. fhome=os.path.expanduser('~')
  127. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  128. setup=json.load(f)
  129. sys.path.insert(0,setup["paths"]["labkeyInterface"])
  130. import labkeyInterface
  131. import labkeyDatabaseBrowser
  132. import labkeyFileBrowser
  133. sys.path.append(setup['paths']['parseConfig'])
  134. import parseConfig
  135. fconfig=os.path.join(fhome,'.labkey','network.json')
  136. net=labkeyInterface.labkeyInterface()
  137. net.init(fconfig)
  138. db=labkeyDatabaseBrowser.labkeyDB(net)
  139. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  140. with open(parameterFile) as f:
  141. pars=json.load(f)
  142. pars=parseConfig.convert(pars)
  143. pars=parseConfig.convertValues(pars)
  144. print(pars)
  145. #update the config
  146. cfg=pars['deepmedic']['config']
  147. for c in cfg:
  148. replacePatterns(cfg[c]['template'],\
  149. cfg[c]['out'],\
  150. pars['replacePattern'])
  151. project=pars['project']
  152. dataset=pars['targetQuery']
  153. schema=pars['targetSchema']
  154. tempBase=pars['tempBase']
  155. if not os.path.isdir(tempBase):
  156. os.makedirs(tempBase)
  157. #all images from database
  158. ds=db.selectRows(project,schema,dataset,[])
  159. #input
  160. #use webdav to transfer file (even though it is localhost)
  161. i=0
  162. for row in ds["rows"]:
  163. #check if file is already there
  164. #dummy tf to get the suffix
  165. tf=getSegmentationFile(pars,'XX')
  166. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  167. outName=addVersion(\
  168. getSegmImagePath(\
  169. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  170. pars['version'])
  171. outFile=outpath+'/'+outName
  172. #check if file is there
  173. if not fb.entryExists(outFile):
  174. segFile=runSegmentation(fb,row,pars,setup)
  175. #copy file to file
  176. #normally I would update the targetQuery, but it contains previously set images
  177. #copy to labkey
  178. #fb.writeFileToFile(segFile,outFile)
  179. #separate script (set version!)
  180. #update database
  181. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  182. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  183. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  184. if len(ds1['rows'])>0:
  185. mode='update'
  186. outRow=ds1['rows'][0]
  187. else:
  188. mode='insert'
  189. outRow={v:row[v] for v in copyFields}
  190. outRow[pars['version']]= outName
  191. db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow])
  192. #pull results back to LabKey
  193. i+=1
  194. if i==1:
  195. break
  196. print("Done")
  197. if __name__ == '__main__':
  198. main(sys.argv[1])
  199. #sys.exit()