runSegmentationnnUNet.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module performs no work; processing starts only when the file is run as a script.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. sfx=getSuffix(tempFile)
  43. return re.sub(sfx,'_Segm'+sfx,tempFile)
  44. def addVersion(tempFile,version):
  45. sfx=getSuffix(tempFile)
  46. return re.sub(sfx,'_'+version+sfx,tempFile)
  47. def addnnUNetCode(tempFile,fileNumber=0):
  48. sfx=getSuffix(tempFile)
  49. return re.sub(sfx,'_'+'{:04d}'.format(fileNumber)+sfx,tempFile)
  50. def runnnUNet(setup,pars):
  51. args=[]
  52. #set the environment
  53. args.append(setup['paths']['nnUNetRunInference'])
  54. #location of input images
  55. args.extend(['-i',os.path.join(pars['tempBase'],'CT')])
  56. #output path is segmentations
  57. args.extend(['-o',os.path.join(pars['tempBase'],'segmentations')])
  58. #modelid, nnUNet internal rules.
  59. args.extend(['-t',pars['nnUNet']['ModelId']])
  60. #specify configuration (3d_fullres)
  61. args.extend(['-m',pars['nnUNet']['configuration']])
  62. print(args)
  63. my_env = os.environ
  64. for key in pars['nnUNet']['env']:
  65. my_env[key]=pars['nnUNet']['env'][key]
  66. print(subprocess.run(args,env=my_env,check=True,stdout=subprocess.PIPE).stdout)
  67. def getSegmentationFile(pars):
  68. #this is how deep medic stores files
  69. return os.path.join(pars['tempBase'],'segmentations',\
  70. pars['images']['CT']['tempFile'])
  71. def runSegmentation(fb,row,pars,setup):
  72. #download to temp file (could be a fixed name)
  73. project=pars['project']
  74. images=pars['images']
  75. participantField=pars['participantField']
  76. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  77. getPatientLabel(row,participantField)+'/'+\
  78. getVisitLabel(row))
  79. #download CT
  80. ctDir=os.path.join(pars['tempBase'],'CT')
  81. if not os.path.isdir(ctDir):
  82. os.mkdir(ctDir)
  83. fullFile=os.path.join(ctDir,images['CT']['tempFile'])
  84. fullFile=addnnUNetCode(fullFile)
  85. fb.readFileToFile(baseDir+'/'+row[images['CT']['queryField']],fullFile)
  86. #debug
  87. #run deep medic
  88. runnnUNet(setup,pars)
  89. #processed file is
  90. segFile=getSegmentationFile(pars)
  91. #SimpleITK.WriteImage(outImg,segFile)
  92. return segFile
  93. def test(parameterFile):
  94. fhome=os.path.expanduser('~')
  95. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  96. setup=json.load(f)
  97. sys.path.insert(0,setup["paths"]["nixWrapper"])
  98. import nixWrapper
  99. nixWrapper.loadLibrary("labkeyInterface")#force reload
  100. import labkeyInterface
  101. import labkeyDatabaseBrowser
  102. import labkeyFileBrowser
  103. nixWrapper.loadLibrary("parseConfig")
  104. import parseConfig
  105. with open(parameterFile) as f:
  106. pars=json.load(f)
  107. pars=parseConfig.convert(pars)
  108. pars=parseConfig.convertValues(pars)
  109. #print(pars)
  110. def doSegmentation(parameterFile):
  111. fhome=os.path.expanduser('~')
  112. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  113. setup=json.load(f)
  114. sys.path.insert(0,setup["paths"]["nixWrapper"])
  115. import nixWrapper
  116. nixWrapper.loadLibrary("labkeyInterface")#force reload
  117. import labkeyInterface
  118. import labkeyDatabaseBrowser
  119. import labkeyFileBrowser
  120. nixWrapper.loadLibrary("parseConfig")
  121. import parseConfig
  122. with open(parameterFile) as f:
  123. pars=json.load(f)
  124. pars=parseConfig.convert(pars)
  125. pars=parseConfig.convertValues(pars)
  126. project=pars['project']
  127. dataset=pars['targetQuery']
  128. schema=pars['targetSchema']
  129. view=pars['viewName']
  130. tempBase=pars['tempBase']
  131. if not os.path.isdir(tempBase):
  132. os.makedirs(tempBase)
  133. #start the database interface
  134. fconfig=os.path.join(fhome,'.labkey','network.json')
  135. net=labkeyInterface.labkeyInterface()
  136. net.init(fconfig)
  137. db=labkeyDatabaseBrowser.labkeyDB(net)
  138. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  139. #all images from database
  140. ds=db.selectRows(project,schema,dataset,[],view)
  141. #input
  142. #use webdav to transfer file (even though it is localhost)
  143. i=0
  144. rows=[ds['rows'][0]]
  145. rows=ds['rows']
  146. for row in rows:
  147. #check if file is already there
  148. #dummy tf to get the suffix
  149. sfx=pars['images']['segmentation']['suffix']
  150. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  151. outName=addVersion(\
  152. getSegmImagePath(\
  153. getStudyLabel(row,pars['participantField'])+sfx),\
  154. pars['version'])
  155. outFile=outpath+'/'+outName
  156. #check if file is there
  157. if not fb.entryExists(outFile):
  158. segFile=getSegmentationFile(pars)
  159. #remove existing file
  160. if os.path.isfile(segFile):
  161. os.remove(segFile)
  162. segFile=runSegmentation(fb,row,pars,setup)
  163. #copy file to file
  164. #normally I would update the targetQuery, but it contains previously set images
  165. #copy to labkey
  166. fb.writeFileToFile(segFile,outFile)
  167. print(segFile)
  168. #debug
  169. #update database
  170. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  171. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  172. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  173. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  174. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  175. if len(ds1['rows'])>0:
  176. mode='update'
  177. outRow=ds1['rows'][0]
  178. else:
  179. mode='insert'
  180. outRow={v:row[v] for v in copyFields}
  181. outRow['Version']= pars['version']
  182. outRow['Segmentation']= outName
  183. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  184. #pull results back to LabKey
  185. print("Done")
  186. if __name__ == '__main__':
  187. #test(sys.argv[1])
  188. doSegmentation(sys.argv[1])
  189. #sys.exit()