runSegmentationnnUNet.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import numpy
  10. #nothing gets done if you do import
  11. def getPatientLabel(row,participantField='PatientId'):
  12. return row[participantField].replace('/','_')
  13. def getVisitLabel(row):
  14. return 'VISIT_'+str(int(row['SequenceNum']))
  15. def getStudyLabel(row,participantField='PatientId'):
  16. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  17. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  18. participantField='PatientId'):
  19. row['patientCode']=getPatientLabel(row,participantField)
  20. row['visitCode']=getVisitLabel(row)
  21. for im in imageResampledField:
  22. row[imageResampledField[im]]=gzFileNames[im]
  23. db.modifyRows('update',project,'study',dataset,[row])
  24. def replacePatterns(infile,outfile,replacePatterns):
  25. of=open(outfile,'w')
  26. with open(infile,'r') as f:
  27. data=f.read()
  28. for p in replacePatterns:
  29. val=replacePatterns[p]
  30. data=re.sub(p,val,data)
  31. of.write(data)
  32. of.close()
  33. def valueSubstitution(pars,val):
  34. if val.find('__home__')>-1:
  35. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  36. return path
  37. def getSuffix(tempFile):
  38. p=pathlib.Path(tempFile)
  39. return ''.join(p.suffixes)
  40. def getSegmImagePath(tempFile):
  41. sfx=getSuffix(tempFile)
  42. return re.sub(sfx,'_Segm'+sfx,tempFile)
  43. def addVersion(tempFile,version):
  44. sfx=getSuffix(tempFile)
  45. return re.sub(sfx,'_'+version+sfx,tempFile)
  46. def addnnUNetCode(tempFile,fileNumber=0):
  47. sfx=getSuffix(tempFile)
  48. return re.sub(sfx,'_'+'{:04d}'.format(fileNumber)+sfx,tempFile)
  49. def runnnUNet(setup,pars):
  50. args=[]
  51. #set the environment
  52. args.append(setup['paths']['nnUNetRunInference'])
  53. #location of input images
  54. args.extend(['-i',os.path.join(pars['tempBase'],'CT')])
  55. #output path is segmentations
  56. args.extend(['-o',os.path.join(pars['tempBase'],'segmentations')])
  57. #modelid, nnUNet internal rules.
  58. args.extend(['-t',pars['nnUNet']['ModelId']])
  59. #specify configuration (3d_fullres)
  60. args.extend(['-m',pars['nnUNet']['configuration']])
  61. print(args)
  62. my_env = os.environ
  63. for key in pars['nnUNet']['env']:
  64. my_env[key]=pars['nnUNet']['env'][key]
  65. print(subprocess.run(args,env=my_env,check=True,stdout=subprocess.PIPE).stdout)
  66. def getSegmentationFile(pars):
  67. #this is how deep medic stores files
  68. return os.path.join(pars['tempBase'],'segmentations',\
  69. pars['images']['CT']['tempFile'])
  70. def runSegmentation(fb,row,pars,setup):
  71. #download to temp file (could be a fixed name)
  72. project=pars['project']
  73. images=pars['images']
  74. participantField=pars['participantField']
  75. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  76. getPatientLabel(row,participantField)+'/'+\
  77. getVisitLabel(row))
  78. #download CT
  79. ctDir=os.path.join(pars['tempBase'],'CT')
  80. if not os.path.isdir(ctDir):
  81. os.mkdir(ctDir)
  82. fullFile=os.path.join(ctDir,images['CT']['tempFile'])
  83. fullFile=addnnUNetCode(fullFile)
  84. fb.readFileToFile(baseDir+'/'+row[images['CT']['queryField']],fullFile)
  85. #debug
  86. #run deep medic
  87. runnnUNet(setup,pars)
  88. #processed file is
  89. segFile=getSegmentationFile(pars)
  90. #SimpleITK.WriteImage(outImg,segFile)
  91. return segFile
  92. def test(parameterFile):
  93. fhome=os.path.expanduser('~')
  94. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  95. setup=json.load(f)
  96. sys.path.insert(0,setup["paths"]["nixWrapper"])
  97. import nixWrapper
  98. nixWrapper.loadLibrary("labkeyInterface")#force reload
  99. import labkeyInterface
  100. import labkeyDatabaseBrowser
  101. import labkeyFileBrowser
  102. nixWrapper.loadLibrary("parseConfig")
  103. import parseConfig
  104. with open(parameterFile) as f:
  105. pars=json.load(f)
  106. pars=parseConfig.convert(pars)
  107. pars=parseConfig.convertValues(pars)
  108. #print(pars)
  109. def doSegmentation(parameterFile):
  110. fhome=os.path.expanduser('~')
  111. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  112. setup=json.load(f)
  113. sys.path.insert(0,setup["paths"]["nixWrapper"])
  114. import nixWrapper
  115. nixWrapper.loadLibrary("labkeyInterface")#force reload
  116. import labkeyInterface
  117. import labkeyDatabaseBrowser
  118. import labkeyFileBrowser
  119. nixWrapper.loadLibrary("parseConfig")
  120. import parseConfig
  121. with open(parameterFile) as f:
  122. pars=json.load(f)
  123. pars=parseConfig.convert(pars)
  124. pars=parseConfig.convertValues(pars)
  125. project=pars['project']
  126. dataset=pars['targetQuery']
  127. schema=pars['targetSchema']
  128. view=pars['viewName']
  129. participantField=pars['participantField']
  130. tempBase=pars['tempBase']
  131. if not os.path.isdir(tempBase):
  132. os.makedirs(tempBase)
  133. #start the database interface
  134. fconfig=os.path.join(fhome,'.labkey','network.json')
  135. net=labkeyInterface.labkeyInterface()
  136. net.init(fconfig)
  137. db=labkeyDatabaseBrowser.labkeyDB(net)
  138. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  139. #all images from database
  140. ds=db.selectRows(project,schema,dataset,[],view)
  141. #input
  142. #use webdav to transfer file (even though it is localhost)
  143. i=0
  144. #for debugging
  145. rows=[ds['rows'][0]]
  146. #production mode
  147. rows=ds['rows']
  148. for row in rows:
  149. ctField=row[pars['images']['CT']['queryField']]
  150. if ctField==None:
  151. print('{}/{}: missing resampled CT'.format(row[participantField],row['SequenceNum']))
  152. continue
  153. #build file name
  154. sfx=pars['images']['segmentation']['suffix']
  155. outpath=fb.buildPathURL(pars['project'],\
  156. [pars['imageDir'],row['patientCode'],row['visitCode']])
  157. outName=addVersion(\
  158. getSegmImagePath(\
  159. getStudyLabel(row,pars['participantField'])+sfx),\
  160. pars['version'])
  161. outFile=outpath+'/'+outName
  162. #check if file is there
  163. if not fb.entryExists(outFile):
  164. segFile=getSegmentationFile(pars)
  165. #remove existing file
  166. if os.path.isfile(segFile):
  167. os.remove(segFile)
  168. segFile=runSegmentation(fb,row,pars,setup)
  169. #copy file to file
  170. #normally I would update the targetQuery, but it contains previously set images
  171. #copy to labkey
  172. fb.writeFileToFile(segFile,outFile)
  173. print(segFile)
  174. #debug
  175. #update database
  176. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  177. row['SequenceNum']+=0.001*float(pars['versionNumber'])
  178. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  179. filters.append({'variable':'Version','value':pars['version'],'oper':'eq'})
  180. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  181. if len(ds1['rows'])>0:
  182. mode='update'
  183. outRow=ds1['rows'][0]
  184. else:
  185. mode='insert'
  186. outRow={v:row[v] for v in copyFields}
  187. outRow['Version']= pars['version']
  188. outRow['Segmentation']= outName
  189. print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow]))
  190. #push results back to LabKey
  191. print("Done")
  192. if __name__ == '__main__':
  193. #test(sys.argv[1])
  194. doSegmentation(sys.argv[1])
  195. #sys.exit()