runSegmentation.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
# Importing this module only defines functions; all work happens in main() when run as a script.
  12. def getPatientLabel(row,participantField='PatientId'):
  13. return row[participantField].replace('/','_')
  14. def getVisitLabel(row):
  15. return 'VISIT_'+str(int(row['SequenceNum']))
  16. def getStudyLabel(row,participantField='PatientId'):
  17. return getPatientLabel(row,participantField)+'-'+getVisitLabel(row)
  18. def updateRow(project,dataset,row,imageResampledField,gzFileNames,\
  19. participantField='PatientId'):
  20. row['patientCode']=getPatientLabel(row,participantField)
  21. row['visitCode']=getVisitLabel(row)
  22. for im in imageResampledField:
  23. row[imageResampledField[im]]=gzFileNames[im]
  24. db.modifyRows('update',project,'study',dataset,[row])
  25. def replacePatterns(infile,outfile,replacePatterns):
  26. of=open(outfile,'w')
  27. with open(infile,'r') as f:
  28. data=f.read()
  29. for p in replacePatterns:
  30. val=replacePatterns[p]
  31. data=re.sub(p,val,data)
  32. of.write(data)
  33. of.close()
  34. def valueSubstitution(pars,val):
  35. if val.find('__home__')>-1:
  36. val=re.sub(r'__home__',os.path.expanduser('~'),val)
  37. return path
  38. def getSuffix(tempFile):
  39. p=pathlib.Path(tempFile)
  40. return ''.join(p.suffixes)
  41. def getSegmImagePath(tempFile):
  42. sfx=getSuffix(tempFile)
  43. return re.sub(sfx,'_Segm'+sfx,tempFile)
  44. def addVersion(tempFile,version):
  45. sfx=getSuffix(tempFile)
  46. return re.sub(sfx,'_'+version+sfx,tempFile)
  47. def normalizeCT(ctFile,maskFile):
  48. im=SimpleITK.ReadImage(ctFile)
  49. mask=SimpleITK.ReadImage(maskFile)
  50. nm=SimpleITK.GetArrayFromImage(im)
  51. nmask=SimpleITK.GetArrayFromImage(mask)
  52. mu=numpy.mean(nm[nmask>0])
  53. st=numpy.std(nm[nmask>0])
  54. nm[nmask>0]=(nm[nmask>0]-mu)/st
  55. nm[nmask==0]=0
  56. im1=SimpleITK.GetImageFromArray(nm)
  57. im1.SetOrigin(im.GetOrigin())
  58. im1.SetSpacing(im.GetSpacing())
  59. im1.SetDirection(im.GetDirection())
  60. SimpleITK.WriteImage(im1,ctFile)
  61. def runDeepMedic(setup,pars):
  62. args=[]
  63. args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
  64. args.append(setup['paths']['deepMedicRun'])
  65. args.append('-model')
  66. args.append(pars['deepmedic']['config']['model']['out'])
  67. args.append('-test')
  68. args.append(pars['deepmedic']['config']['test']['out'])
  69. args.append('-dev')
  70. args.append('cpu')
  71. print(args)
  72. print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
  73. def getSegmentationFile(pars):
  74. #this is how deep medic stores files
  75. return getSegmImagePath(\
  76. os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
  77. pars['images']['images']['segmentations']['tempFile'])
  78. )
  79. def runSegmentation(fb,row,pars,setup):
  80. #download to temp file (could be a fixed name)
  81. project=pars['project']
  82. images=pars['images']['images']
  83. participantField=pars['participantField']
  84. baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
  85. getPatientLabel(row,participantField)+'/'+\
  86. getVisitLabel(row))
  87. #download
  88. for im in images:
  89. tmpFile=images[im]['tempFile']
  90. if 'queryField' in images[im]:
  91. fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],tmpFile)
  92. #normalize
  93. normalizeCT(images['CT']['tempFile'],images['patientmask']['tempFile'])
  94. #update templates to know which files to process
  95. #run deep medic
  96. #runDeepMedic(setup,pars)
  97. #segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
  98. #SimpleITK.WriteImage(outImg,segFile)
  99. return segFile
  100. def main(parameterFile):
  101. fhome=os.path.expanduser('~')
  102. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  103. setup=json.load(f)
  104. sys.path.insert(0,setup["paths"]["labkeyInterface"])
  105. import labkeyInterface
  106. import labkeyDatabaseBrowser
  107. import labkeyFileBrowser
  108. sys.path.append(setup['paths']['parseConfig'])
  109. import parseConfig
  110. fconfig=os.path.join(fhome,'.labkey','network.json')
  111. net=labkeyInterface.labkeyInterface()
  112. net.init(fconfig)
  113. db=labkeyDatabaseBrowser.labkeyDB(net)
  114. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  115. with open(parameterFile) as f:
  116. pars=json.load(f)
  117. pars=parseConfig.convert(pars)
  118. pars=parseConfig.convertValues(pars)
  119. print(pars)
  120. #update the config
  121. cfg=pars['deepmedic']['config']
  122. for c in cfg:
  123. replacePatterns(cfg[c]['template'],\
  124. cfg[c]['out'],\
  125. pars['replacePattern'])
  126. project=pars['project']
  127. dataset=pars['targetQuery']
  128. schema=pars['targetSchema']
  129. tempBase=pars['tempBase']
  130. if not os.path.isdir(tempBase):
  131. os.makedirs(tempBase)
  132. #all images from database
  133. ds=db.selectRows(project,schema,dataset,[])
  134. #input
  135. #use webdav to transfer file (even though it is localhost)
  136. i=0
  137. for row in ds["rows"]:
  138. #check if file is already there
  139. #dummy tf to get the suffix
  140. tf=getSegmentationFile(pars,'XX')
  141. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  142. outName=addVersion(\
  143. getSegmImagePath(\
  144. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  145. pars['version'])
  146. outFile=outpath+'/'+outName
  147. #check if file is there
  148. if not fb.entryExists(outFile):
  149. segFile=runSegmentation(fb,row,pars,setup)
  150. #copy file to file
  151. #normally I would update the targetQuery, but it contains previously set images
  152. #copy to labkey
  153. #fb.writeFileToFile(segFile,outFile)
  154. #separate script (set version!)
  155. #update database
  156. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  157. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  158. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  159. if len(ds1['rows'])>0:
  160. mode='update'
  161. outRow=ds1['rows'][0]
  162. else:
  163. mode='insert'
  164. outRow={v:row[v] for v in copyFields}
  165. outRow[pars['version']]= outName
  166. db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow])
  167. #pull results back to LabKey
  168. i+=1
  169. if i==1:
  170. break
  171. print("Done")
  172. if __name__ == '__main__':
  173. main(sys.argv[1])
  174. #sys.exit()