runSegmentationnnUNet.py 7.2 KB


  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module only defines functions; run it as a script to perform the segmentation.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. sfx=getSuffix(tempFile)
  43. return re.sub(sfx,'_Segm'+sfx,tempFile)
  44. def addVersion(tempFile,version):
  45. sfx=getSuffix(tempFile)
  46. return re.sub(sfx,'_'+version+sfx,tempFile)
  47. def addnnUNetCode(tempFile,fileNumber=0):
  48. sfx=getSuffix(tempFile)
  49. return re.sub(sfx,'_'+'{:04d}'.format(fileNumber)+sfx,tempFile)
  50. def runnnUNet(setup,pars):
  51. args=[]
  52. #set the environment
  53. args.append(setup['paths']['nnUNetRunInference'])
  54. #location of input images
  55. args.extend(['-i',os.path.join(pars['tempBase'],'CT')])
  56. #output path is segmentations
  57. args.extend(['-o',os.path.join(pars['tempBase'],'segmentations')])
  58. #modelid, nnUNet internal rules.
  59. args.extend(['-t',pars['nnUNet']['ModelId']])
  60. #specify configuration (3d_fullres)
  61. args.extend(['-m',pars['nnUNet']['configuration']])
  62. print(args)
  63. my_env = os.environ
  64. for key in pars['nnUNet']['env']:
  65. my_env[key]=pars['nnUNet']['env'][key]
  66. print(subprocess.run(args,env=my_env,check=True,stdout=subprocess.PIPE).stdout)
  67. def getSegmentationFile(pars):
  68. #this is how deep medic stores files
  69. return os.path.join(pars['tempBase'],'segmentations',\
  70. pars['images']['CT']['tempFile'])
  71. def runSegmentation(fb,row,pars,setup):
  72. #download to temp file (could be a fixed name)
  73. project=pars['project']
  74. images=pars['images']
  75. participantField=pars['participantField']
  76. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  77. getPatientLabel(row,participantField)+'/'+\
  78. getVisitLabel(row))
  79. #download CT
  80. ctDir=os.path.join(pars['tempBase'],'CT')
  81. if not os.path.isdir(ctDir):
  82. os.mkdir(ctDir)
  83. fullFile=os.path.join(ctDir,images['CT']['tempFile'])
  84. fullFile=addnnUNetCode(fullFile)
  85. fb.readFileToFile(baseDir+'/'+row[images['CT']['queryField']],fullFile)
  86. #debug
  87. #run deep medic
  88. runnnUNet(setup,pars)
  89. #processed file is
  90. segFile=getSegmentationFile(pars)
  91. #SimpleITK.WriteImage(outImg,segFile)
  92. return segFile
  93. def test(parameterFile):
  94. fhome=os.path.expanduser('~')
  95. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  96. setup=json.load(f)
  97. sys.path.insert(0,setup["paths"]["nixWrapper"])
  98. import nixWrapper
  99. nixWrapper.loadLibrary("labkeyInterface")#force reload
  100. import labkeyInterface
  101. import labkeyDatabaseBrowser
  102. import labkeyFileBrowser
  103. nixWrapper.loadLibrary("parseConfig")
  104. import parseConfig
  105. with open(parameterFile) as f:
  106. pars=json.load(f)
  107. pars=parseConfig.convert(pars)
  108. pars=parseConfig.convertValues(pars)
  109. #print(pars)
  110. def doSegmentation(parameterFile):
  111. fhome=os.path.expanduser('~')
  112. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  113. setup=json.load(f)
  114. sys.path.insert(0,setup["paths"]["nixWrapper"])
  115. import nixWrapper
  116. nixWrapper.loadLibrary("labkeyInterface")#force reload
  117. import labkeyInterface
  118. import labkeyDatabaseBrowser
  119. import labkeyFileBrowser
  120. nixWrapper.loadLibrary("parseConfig")
  121. import parseConfig
  122. with open(parameterFile) as f:
  123. pars=json.load(f)
  124. pars=parseConfig.convert(pars)
  125. pars=parseConfig.convertValues(pars)
  126. project=pars['project']
  127. dataset=pars['targetQuery']
  128. schema=pars['targetSchema']
  129. view=pars['viewName']
  130. tempBase=pars['tempBase']
  131. if not os.path.isdir(tempBase):
  132. os.makedirs(tempBase)
  133. #start the database interface
  134. fconfig=os.path.join(fhome,'.labkey','network.json')
  135. net=labkeyInterface.labkeyInterface()
  136. net.init(fconfig)
  137. db=labkeyDatabaseBrowser.labkeyDB(net)
  138. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  139. #all images from database
  140. ds=db.selectRows(project,schema,dataset,[],view)
  141. #input
  142. #use webdav to transfer file (even though it is localhost)
  143. i=0
  144. rows=[ds['rows'][0]]
  145. rows=ds['rows']
  146. for row in rows:
  147. #check if file is already there
  148. #dummy tf to get the suffix
  149. sfx=pars['images']['segmentation']['suffix']
  150. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  151. outName=addVersion(\
  152. getSegmImagePath(\
  153. getStudyLabel(row,pars['participantField'])+sfx),\
  154. pars['version'])
  155. outFile=outpath+'/'+outName
  156. #check if file is there
  157. if not fb.entryExists(outFile):
  158. segFile=getSegmentationFile(pars)
  159. #remove existing file
  160. if os.path.isfile(segFile):
  161. os.remove(segFile)
  162. segFile=runSegmentation(fb,row,pars,setup)
  163. #copy file to file
  164. #normally I would update the targetQuery, but it contains previously set images
  165. #copy to labkey
  166. fb.writeFileToFile(segFile,outFile)
  167. print(segFile)
  168. #debug
  169. #update database
  170. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  171. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  172. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  173. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  174. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  175. if len(ds1['rows'])>0:
  176. mode='update'
  177. outRow=ds1['rows'][0]
  178. else:
  179. mode='insert'
  180. outRow={v:row[v] for v in copyFields}
  181. outRow['Version']= pars['version']
  182. outRow['Segmentation']= outName
  183. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  184. #pull results back to LabKey
  185. print("Done")
  186. if __name__ == '__main__':
  187. #test(sys.argv[1])
  188. doSegmentation(sys.argv[1])
  189. #sys.exit()