runSegmentationDM.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module only defines functions; the work is started from the __main__ guard at the bottom of the file.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. #return tempFile
  43. sfx=getSuffix(tempFile)
  44. return re.sub(sfx,'_Segm'+sfx,tempFile)
  45. def addVersion(tempFile,version):
  46. sfx=getSuffix(tempFile)
  47. return re.sub(sfx,'_'+version+sfx,tempFile)
  48. def normalizeCT(ctFile,maskFile):
  49. im=SimpleITK.ReadImage(ctFile)
  50. mask=SimpleITK.ReadImage(maskFile)
  51. nm=SimpleITK.GetArrayFromImage(im)
  52. nmask=SimpleITK.GetArrayFromImage(mask)
  53. mu=numpy.mean(nm[nmask>0])
  54. st=numpy.std(nm[nmask>0])
  55. nm[nmask>0]=(nm[nmask>0]-mu)/st
  56. nm[nmask==0]=0
  57. im1=SimpleITK.GetImageFromArray(nm)
  58. im1.SetOrigin(im.GetOrigin())
  59. im1.SetSpacing(im.GetSpacing())
  60. im1.SetDirection(im.GetDirection())
  61. SimpleITK.WriteImage(im1,ctFile)
  62. def runDeepMedic(setup,pars):
  63. args=[]
  64. #args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
  65. args.append(setup['paths']['deepMedicRun'])
  66. args.append('-model')
  67. args.append(pars['deepmedic']['config']['model']['out'])
  68. args.append('-test')
  69. args.append(pars['deepmedic']['config']['test']['out'])
  70. args.append('-dev')
  71. args.append('cpu')
  72. print(args)
  73. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  74. def runDeepMedicDocker(setup,pars):
  75. args=[]
  76. args.extend(['docker-compose','-f',pars['deepmedic']['segmentationdmYAML'],'up'])
  77. print(args)
  78. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  79. def getSegmentationFile(pars):
  80. #this is how deep medic stores files
  81. return getSegmImagePath(\
  82. os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
  83. pars['images']['segmentation']['tempFile'])
  84. )
  85. def runSegmentation(fb,row,pars,setup):
  86. #download to temp file (could be a fixed name)
  87. project=pars['project']
  88. images=pars['images']
  89. participantField=pars['participantField']
  90. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  91. getPatientLabel(row,participantField)+'/'+\
  92. getVisitLabel(row))
  93. #download
  94. fullFile={key:os.path.join(pars['tempBase'],images[key]['tempFile']) for key in images}
  95. for im in images:
  96. if 'queryField' in images[im]:
  97. fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],fullFile[im])
  98. print('Loaded {}'.format(fullFile[im]))
  99. #normalize
  100. normalizeCT(fullFile['CT'],fullFile['patientmask'])
  101. #update templates to know which files to process
  102. #run deep medic
  103. #runDeepMedicDocker(setup,pars)
  104. runDeepMedic(setup,pars)
  105. #processed file is
  106. segFile=getSegmentationFile(pars)
  107. #SimpleITK.WriteImage(outImg,segFile)
  108. return segFile
  109. def main(parameterFile):
  110. fhome=os.path.expanduser('~')
  111. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  112. setup=json.load(f)
  113. sys.path.insert(0,setup["paths"]["nixWrapper"])
  114. import nixWrapper
  115. nixWrapper.loadLibrary("labkeyInterface")
  116. import labkeyInterface
  117. import labkeyDatabaseBrowser
  118. import labkeyFileBrowser
  119. nixWrapper.loadLibrary("parseConfig")
  120. import parseConfig
  121. with open(parameterFile) as f:
  122. pars=json.load(f)
  123. pars=parseConfig.convert(pars)
  124. pars=parseConfig.convertValues(pars)
  125. print(pars)
  126. #images=pars['images']
  127. #ctFile=os.path.join(pars['tempBase'],images['CT']['tempFile'])
  128. #maskFile=os.path.join(pars['tempBase'],images['patientmask']['tempFile'])
  129. #normalizeCT(ctFile,maskFile)
  130. def doSegmentation(parameterFile):
  131. fhome=os.path.expanduser('~')
  132. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  133. setup=json.load(f)
  134. sys.path.insert(0,setup["paths"]["nixWrapper"])
  135. import nixWrapper
  136. nixWrapper.loadLibrary("labkeyInterface")
  137. import labkeyInterface
  138. import labkeyDatabaseBrowser
  139. import labkeyFileBrowser
  140. nixWrapper.loadLibrary("parseConfig")
  141. import parseConfig
  142. fconfig=os.path.join(fhome,'.labkey','network.json')
  143. net=labkeyInterface.labkeyInterface()
  144. net.init(fconfig)
  145. db=labkeyDatabaseBrowser.labkeyDB(net)
  146. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  147. with open(parameterFile) as f:
  148. pars=json.load(f)
  149. pars=parseConfig.convert(pars,setup)
  150. pars=parseConfig.convertValues(pars)
  151. print(pars)
  152. #update the config
  153. cfg=pars['deepmedic']['config']
  154. for c in cfg:
  155. replacePatterns(cfg[c]['template'],\
  156. cfg[c]['out'],\
  157. pars['replacePattern'])
  158. project=pars['project']
  159. dataset=pars['targetQuery']
  160. schema=pars['targetSchema']
  161. tempBase=pars['tempBase']
  162. if not os.path.isdir(tempBase):
  163. os.makedirs(tempBase)
  164. #all images from database
  165. qFilter=pars['entryFilter']
  166. ds=db.selectRows(project,schema,dataset,qFilter)
  167. print('Got {} rows'.format(len(ds['rows'])))
  168. #input
  169. #use webdav to transfer file (even though it is localhost)
  170. i=0
  171. for row in ds["rows"]:
  172. #check if file is already there
  173. #dummy tf to get the suffix
  174. tf=getSegmentationFile(pars)
  175. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  176. outName=addVersion(\
  177. getSegmImagePath(\
  178. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  179. pars['version'])
  180. outFile=outpath+'/'+outName
  181. #check if file is there
  182. if not fb.entryExists(outFile):
  183. segFile=runSegmentation(fb,row,pars,setup)
  184. #copy file to file
  185. #normally I would update the targetQuery, but it contains previously set images
  186. #copy to labkey
  187. fb.writeFileToFile(segFile,outFile)
  188. else:
  189. print('File {} available'.format(outFile))
  190. #separate script (set version!)
  191. #update database
  192. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  193. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  194. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  195. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  196. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  197. if len(ds1['rows'])>0:
  198. mode='update'
  199. outRow=ds1['rows'][0]
  200. else:
  201. mode='insert'
  202. outRow={v:row[v] for v in copyFields}
  203. outRow['Version']= pars['version']
  204. outRow['Segmentation']= outName
  205. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  206. #push results back to LabKey
  207. i+=1
  208. if i==1 and pars['debug']:
  209. break
  210. print("Done")
  211. if __name__ == '__main__':
  212. #main(sys.argv[1])
  213. doSegmentation(sys.argv[1])
  214. #sys.exit()