runSegmentationDM.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module only defines functions; no work is done until main() or doSegmentation() is called.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. sfx=getSuffix(tempFile)
  43. return re.sub(sfx,'_Segm'+sfx,tempFile)
  44. def addVersion(tempFile,version):
  45. sfx=getSuffix(tempFile)
  46. return re.sub(sfx,'_'+version+sfx,tempFile)
  47. def normalizeCT(ctFile,maskFile):
  48. im=SimpleITK.ReadImage(ctFile)
  49. mask=SimpleITK.ReadImage(maskFile)
  50. nm=SimpleITK.GetArrayFromImage(im)
  51. nmask=SimpleITK.GetArrayFromImage(mask)
  52. mu=numpy.mean(nm[nmask>0])
  53. st=numpy.std(nm[nmask>0])
  54. nm[nmask>0]=(nm[nmask>0]-mu)/st
  55. nm[nmask==0]=0
  56. im1=SimpleITK.GetImageFromArray(nm)
  57. im1.SetOrigin(im.GetOrigin())
  58. im1.SetSpacing(im.GetSpacing())
  59. im1.SetDirection(im.GetDirection())
  60. SimpleITK.WriteImage(im1,ctFile)
  61. def runDeepMedic(setup,pars):
  62. args=[]
  63. args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
  64. args.append(setup['paths']['deepMedicRun'])
  65. args.append('-model')
  66. args.append(pars['deepmedic']['config']['model']['out'])
  67. args.append('-test')
  68. args.append(pars['deepmedic']['config']['test']['out'])
  69. args.append('-dev')
  70. args.append('cpu')
  71. print(args)
  72. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  73. def runDeepMedicDocker(setup,pars):
  74. args=[]
  75. args.extend(['docker-compose','-f',pars['deepmedic']['segmentationdmYAML'],'up'])
  76. print(args)
  77. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  78. def getSegmentationFile(pars):
  79. #this is how deep medic stores files
  80. return getSegmImagePath(\
  81. os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
  82. pars['images']['segmentations']['tempFile'])
  83. )
  84. def runSegmentation(fb,row,pars,setup):
  85. #download to temp file (could be a fixed name)
  86. project=pars['project']
  87. images=pars['images']
  88. participantField=pars['participantField']
  89. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  90. getPatientLabel(row,participantField)+'/'+\
  91. getVisitLabel(row))
  92. #download
  93. fullFile={key:os.path.join(pars['tempBase'],images[key]['tempFile']) for key in images}
  94. for im in images:
  95. if 'queryField' in images[im]:
  96. fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],fullFile[im])
  97. #normalize
  98. normalizeCT(fullFile['CT'],fullFile['patientmask'])
  99. #update templates to know which files to process
  100. #run deep medic
  101. runDeepMedicDocker(setup,pars)
  102. #processed file is
  103. segFile=getSegmentationFile(pars)
  104. #SimpleITK.WriteImage(outImg,segFile)
  105. return segFile
  106. def main(parameterFile):
  107. fhome=os.path.expanduser('~')
  108. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  109. setup=json.load(f)
  110. sys.path.insert(0,setup["paths"]["nixWrapper"])
  111. import nixWrapper
  112. nixWrapper.loadLibrary("labkeyInterface")
  113. import labkeyInterface
  114. import labkeyDatabaseBrowser
  115. import labkeyFileBrowser
  116. nixWrapper.loadLibrary("parseConfig")
  117. import parseConfig
  118. with open(parameterFile) as f:
  119. pars=json.load(f)
  120. pars=parseConfig.convert(pars)
  121. pars=parseConfig.convertValues(pars)
  122. print(pars)
  123. #images=pars['images']
  124. #ctFile=os.path.join(pars['tempBase'],images['CT']['tempFile'])
  125. #maskFile=os.path.join(pars['tempBase'],images['patientmask']['tempFile'])
  126. #normalizeCT(ctFile,maskFile)
  127. def doSegmentation(parameterFile):
  128. fhome=os.path.expanduser('~')
  129. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  130. setup=json.load(f)
  131. sys.path.insert(0,setup["paths"]["nixWrapper"])
  132. import nixWrapper
  133. nixWrapper.loadLibrary("labkeyInterface")
  134. import labkeyInterface
  135. import labkeyDatabaseBrowser
  136. import labkeyFileBrowser
  137. nixWrapper.loadLibrary("parseConfig")
  138. import parseConfig
  139. fconfig=os.path.join(fhome,'.labkey','network.json')
  140. net=labkeyInterface.labkeyInterface()
  141. net.init(fconfig)
  142. db=labkeyDatabaseBrowser.labkeyDB(net)
  143. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  144. with open(parameterFile) as f:
  145. pars=json.load(f)
  146. pars=parseConfig.convert(pars)
  147. pars=parseConfig.convertValues(pars)
  148. print(pars)
  149. #update the config
  150. cfg=pars['deepmedic']['config']
  151. for c in cfg:
  152. replacePatterns(cfg[c]['template'],\
  153. cfg[c]['out'],\
  154. pars['replacePattern'])
  155. project=pars['project']
  156. dataset=pars['targetQuery']
  157. schema=pars['targetSchema']
  158. tempBase=pars['tempBase']
  159. if not os.path.isdir(tempBase):
  160. os.makedirs(tempBase)
  161. #all images from database
  162. ds=db.selectRows(project,schema,dataset,[])
  163. #input
  164. #use webdav to transfer file (even though it is localhost)
  165. i=0
  166. for row in ds["rows"]:
  167. #check if file is already there
  168. #dummy tf to get the suffix
  169. tf=getSegmentationFile(pars,'XX')
  170. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  171. outName=addVersion(\
  172. getSegmImagePath(\
  173. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  174. pars['version'])
  175. outFile=outpath+'/'+outName
  176. #check if file is there
  177. if not fb.entryExists(outFile):
  178. segFile=runSegmentation(fb,row,pars,setup)
  179. #copy file to file
  180. #normally I would update the targetQuery, but it contains previously set images
  181. #copy to labkey
  182. fb.writeFileToFile(segFile,outFile)
  183. #separate script (set version!)
  184. #update database
  185. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  186. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  187. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  188. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  189. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  190. if len(ds1['rows'])>0:
  191. mode='update'
  192. outRow=ds1['rows'][0]
  193. else:
  194. mode='insert'
  195. outRow={v:row[v] for v in copyFields}
  196. outRow['Version']= pars['version']
  197. outRow['Segmentation']= outName
  198. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  199. #push results back to LabKey
  200. i+=1
  201. if i==1:
  202. break
  203. print("Done")
  204. if __name__ == '__main__':
  205. main(sys.argv[1])
  206. #sys.exit()