runSegmentation.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. import os
  2. import json
  3. import re
  4. import subprocess
  5. import nibabel
  6. import shutil
  7. import sys
  8. import pathlib
  9. import SimpleITK
  10. import numpy
  11. #nothing gets done if you do import
  def getPatientLabel(row, participantField='PatientId'):
      """Return the participant ID of *row* made usable as a path component.

      The ID is read from column *participantField*; every '/' in it becomes '_'.
      """
      patientId = row[participantField]
      return '_'.join(patientId.split('/'))
  def getVisitLabel(row):
      """Return the visit label 'VISIT_<n>', n being int(row['SequenceNum'])."""
      visitNumber = int(row['SequenceNum'])
      return 'VISIT_{}'.format(visitNumber)
  def getStudyLabel(row, participantField='PatientId'):
      """Return '<patient label>-<visit label>' identifying the study in *row*."""
      labelParts = (getPatientLabel(row, participantField), getVisitLabel(row))
      return '-'.join(labelParts)
  def updateRow(project, dataset, row, imageResampledField, gzFileNames,
                participantField='PatientId', db=None):
      """Fill patient/visit codes and image file names into *row* and push it to LabKey.

      Parameters
      ----------
      project, dataset : str
          Project path and query name passed to ``db.modifyRows`` (schema 'study').
      row : dict
          Dataset row; updated in place.
      imageResampledField : dict
          Maps an image key to the row column that receives that image's file name.
      gzFileNames : dict
          Maps the same image keys to the file names to store.
      participantField : str
          Column holding the participant ID.
      db : object
          Database browser providing ``modifyRows``. There is no module-level
          ``db`` (it is a local of ``main``), so it has to be passed in.

      Raises
      ------
      ValueError
          If *db* is not supplied; *row* is left untouched in that case.
      """
      if db is None:
          raise ValueError('updateRow requires a database browser: pass db=...')
      row['patientCode'] = getPatientLabel(row, participantField)
      row['visitCode'] = getVisitLabel(row)
      for im, column in imageResampledField.items():
          row[column] = gzFileNames[im]
      db.modifyRows('update', project, 'study', dataset, [row])
  def replacePatterns(infile, outfile, replacePatterns):
      """Copy *infile* to *outfile*, applying regular-expression substitutions.

      Parameters
      ----------
      infile : str
          Template file to read.
      outfile : str
          Destination file; overwritten. May be the same path as *infile*.
      replacePatterns : dict
          Maps a regex pattern to its replacement. Substitutions are applied with
          ``re.sub`` in iteration order, so later patterns see the result of
          earlier ones; replacements are ``re.sub`` templates (backslashes and
          group references are interpreted).
      """
      # Read the whole template before opening the output: a missing template
      # must not truncate outfile, in-place rewriting must not read an emptied
      # file, and both handles are always closed.
      with open(infile, 'r') as f:
          data = f.read()
      for pattern, replacement in replacePatterns.items():
          data = re.sub(pattern, replacement, data)
      with open(outfile, 'w') as of:
          of.write(data)
  def valueSubstitution(pars, val):
      """Return *val* with every '__home__' placeholder replaced by the home directory.

      *pars* is accepted for interface compatibility and is not used.
      """
      # Plain string replacement: re.sub would interpret backslashes in the
      # home path (e.g. on Windows) as escape sequences. The original returned
      # an undefined name ('path') and always raised NameError.
      return val.replace('__home__', os.path.expanduser('~'))
  def getCroppedImagePath(tempFile, crop):
      """Return the path of the *crop* variant of *tempFile*.

      *crop* is inserted directly before the complete extension of the final
      path component, e.g. ('/tmp/ct.nii.gz', '_top') -> '/tmp/ct_top.nii.gz'.
      A name without an extension simply gets *crop* appended.
      """
      p = pathlib.Path(tempFile)
      path = str(p)
      sfx = ''.join(p.suffixes)
      # Cut the extension off the end instead of re.sub: the suffix is not a
      # literal regex ('.' matches any character), it may also match inside a
      # directory name, and an empty suffix would make re.sub insert *crop*
      # between every character.
      stem = path[:len(path) - len(sfx)]
      return stem + crop + sfx
  def getSuffix(tempFile):
      """Return the complete extension of the last component of *tempFile*.

      All suffixes are concatenated, e.g. '/tmp/ct.nii.gz' -> '.nii.gz'; a name
      without a dot yields ''. Note that everything from the first dot of the
      file name on counts as extension.
      """
      extensionParts = pathlib.PurePath(tempFile).suffixes
      return ''.join(extensionParts)
  def getSegmImagePath(tempFile):
      """Return *tempFile* with '_Segm' inserted before its complete extension.

      E.g. 'ct_top.nii.gz' -> 'ct_top_Segm.nii.gz'; a name without an extension
      gets '_Segm' appended.
      """
      sfx = getSuffix(tempFile)
      # Slice the extension off the end instead of re.sub: the suffix is not a
      # literal regex ('.' matches any character), could match earlier in the
      # path, and an empty suffix would insert '_Segm' between every character.
      stem = tempFile[:len(tempFile) - len(sfx)]
      return stem + '_Segm' + sfx
  def addVersion(tempFile, version):
      """Return *tempFile* with '_<version>' inserted before its complete extension.

      E.g. ('P1-VISIT_1_Segm.nii.gz', 'v2') -> 'P1-VISIT_1_Segm_v2.nii.gz';
      a name without an extension gets '_<version>' appended.
      """
      sfx = getSuffix(tempFile)
      # Slice the extension off the end instead of re.sub: the suffix is not a
      # literal regex, could match earlier in the path, an empty suffix would
      # insert the tag between every character, and re.sub would also
      # interpret backslashes in *version*.
      stem = tempFile[:len(tempFile) - len(sfx)]
      return stem + '_' + version + sfx
  def normalizeCT(ctFile, maskFile):
      """Z-score normalize the image in *ctFile* inside the mask in *maskFile*, in place.

      Voxels where the mask is > 0 are shifted to zero mean and scaled to unit
      standard deviation, using statistics of those voxels only; voxels where
      the mask is 0 are set to 0. The result overwrites *ctFile* with the
      original origin, spacing and direction. Both images must have the same
      array shape.
      """
      im = SimpleITK.ReadImage(ctFile)
      mask = SimpleITK.ReadImage(maskFile)
      nm = SimpleITK.GetArrayFromImage(im)
      # Integer pixel types (e.g. int16 CT) would truncate the z-scores to whole
      # numbers when assigned back, so compute in floating point.
      if not numpy.issubdtype(nm.dtype, numpy.floating):
          nm = nm.astype(numpy.float32)
      nmask = SimpleITK.GetArrayFromImage(mask)
      inside = nmask > 0
      values = nm[inside]
      if values.size > 0:
          mu = numpy.mean(values)
          st = numpy.std(values)
          # A constant region has zero spread: only centre it rather than
          # dividing by zero.
          nm[inside] = (values - mu) / st if st > 0 else values - mu
      nm[nmask == 0] = 0
      im1 = SimpleITK.GetImageFromArray(nm)
      im1.CopyInformation(im)
      SimpleITK.WriteImage(im1, ctFile)
  def cropImage(tempFile, crop, cropData):
      """Cut a slab out of the image in *tempFile* and write it as the *crop* variant.

      cropData['axis'] selects the image axis (SimpleITK index order) and
      cropData['range'] gives two fractions [lo, hi] of that axis' length n; the
      slab [int(lo*n), int(hi*n)) is saved to getCroppedImagePath(tempFile, crop).

      Side effect: when cropData['n'] is "NONE" it is set to this image's axis
      length. If it already holds a different length, only a warning is
      printed and the crop uses this image's own length.
      """
      image = SimpleITK.ReadImage(tempFile)
      size = image.GetSize()
      axis = int(cropData['axis'])
      fractions = [float(v) for v in cropData['range']]
      axisLength = size[axis]

      # remember the reference length for later use of cropData['n']
      if cropData['n'] == "NONE":
          cropData['n'] = axisLength
      if axisLength != cropData['n']:
          print('Size mismatch {}:{}'.format(axisLength, cropData['n']))

      bounds = [int(f * axisLength) for f in fractions]
      window = [slice(None)] * len(size)
      window[axis] = slice(bounds[0], bounds[1])
      cropped = image[window]

      outPath = getCroppedImagePath(tempFile, crop)
      SimpleITK.WriteImage(cropped, outPath)
      print("Written {}".format(outPath))
  def runDeepMedic(setup, pars):
      """Run DeepMedic on the CPU with the generated model and test configs.

      Executes setup['paths']['deepMedicRun'] with the interpreter
      <setup['paths']['deepMedicVE']>/bin/python, passing the config files
      pars['deepmedic']['config']['model'|'test']['out']. The command list and
      the captured stdout (bytes) are printed; a non-zero exit status raises
      subprocess.CalledProcessError.
      """
      paths = setup['paths']
      configs = pars['deepmedic']['config']
      command = [
          os.path.join(paths['deepMedicVE'], 'bin', 'python'),
          paths['deepMedicRun'],
          '-model', configs['model']['out'],
          '-test', configs['test']['out'],
          '-dev', 'cpu',
      ]
      print(command)
      completed = subprocess.run(command, check=True, stdout=subprocess.PIPE)
      print(completed.stdout)
  def getSegmentationFile(pars, crop):
      """Return the path of DeepMedic's segmentation output for crop *crop*.

      The file lives in <tempBase>/output/predictions/currentSession/predictions
      and is named after the segmentations temp file, with the crop tag and
      then '_Segm' inserted before the extension.
      """
      predictionDir = os.path.join(pars['tempBase'], 'output', 'predictions',
                                   'currentSession', 'predictions')
      segName = pars['images']['images']['segmentations']['tempFile']
      croppedPath = getCroppedImagePath(os.path.join(predictionDir, segName), crop)
      return getSegmImagePath(croppedPath)
  def getWeight(x, w):
      """Evaluate a piecewise weight function at *x*.

      *w* is a sequence of segments, each a dict with 'range' ([lo, hi], values
      convertible to float), 'n' and an optional slope 'k'. The first segment
      that does not exclude *x* (lo <= x <= hi) yields k*x + n, or n when
      there is no 'k'. Returns 0 if no segment covers *x*.
      """
      for segment in w:
          bounds = [float(v) for v in segment['range']]
          if x > bounds[1] or x < bounds[0]:
              continue
          offset = float(segment['n'])
          if 'k' in segment:
              return float(segment['k']) * x + offset
          return offset
      return 0
  def runSegmentation(fb, row, pars, setup):
      """Segment one study with DeepMedic and return the path of the merged result.

      Steps: download the study's images, z-score normalize the CT inside the
      patient mask, crop every image with every configured crop (listing the
      crop paths in each image's 'fileList'), renormalize the CT crops, run
      DeepMedic, then merge the per-crop predictions and write them to
      <tempBase>/<segmentations tempFile>.

      Parameters
      ----------
      fb : object
          LabKey file browser (formatPathURL, readFileToFile).
      row : dict
          Study row; gives participant/visit and, through each image's
          'queryField', the name of the file to download.
      pars : dict
          Parsed parameter file.
      setup : dict
          Local setup, forwarded to runDeepMedic.

      Returns
      -------
      str
          Path of the written merged segmentation.
      """
      project = pars['project']
      images = pars['images']['images']
      participantField = pars['participantField']
      baseDir = fb.formatPathURL(project, '/'.join([
          pars['imageDir'],
          getPatientLabel(row, participantField),
          getVisitLabel(row)]))
      cropData = pars['images']['crop']

      # forget slice counts of a previous study; cropImage sets them again
      for crop in cropData:
          cropData[crop]['n'] = "NONE"

      # download every image that is stored per study (has a queryField)
      for im in images:
          if 'queryField' in images[im]:
              fb.readFileToFile(baseDir + '/' + row[images[im]['queryField']],
                                images[im]['tempFile'])

      normalizeCT(images['CT']['tempFile'], images['patientmask']['tempFile'])

      # crop the images and record the crop paths in each image's file list
      for im in images:
          tmpFile = images[im]['tempFile']
          with open(images[im]['fileList'], 'w') as f:
              for crop in cropData:
                  print('n={}'.format(cropData[crop]['n']))
                  # Images whose temp file does not exist are not cropped, but
                  # their crop paths are still listed (presumably so DeepMedic
                  # gets output names for the segmentations entry - confirm).
                  if os.path.isfile(tmpFile):
                      cropImage(tmpFile, crop, cropData[crop])
                  print('n={}'.format(cropData[crop]['n']))
                  f.write(getCroppedImagePath(tmpFile, crop) + '\n')

      # renormalize each CT crop inside the matching patient-mask crop
      for crop in cropData:
          normalizeCT(getCroppedImagePath(images['CT']['tempFile'], crop),
                      getCroppedImagePath(images['patientmask']['tempFile'], crop))

      runDeepMedic(setup, pars)

      outImg = mergeSegmentations(pars)
      segFile = os.path.join(pars['tempBase'], images['segmentations']['tempFile'])
      SimpleITK.WriteImage(outImg, segFile)
      return segFile
  def mergeSegmentations(pars):
      """Combine the per-crop DeepMedic predictions into one full-size label image.

      Each crop's prediction (read from getSegmentationFile) is padded with -1
      back to the full length n = cropData[c]['n'] along the crop axis
      cropData[c]['axis'] (SimpleITK index order, default 2), undoing the cut
      made by cropImage. Slice i along that axis gets the weight
      getWeight((i + 0.5) / n, cropData[c]['w']). Each voxel takes the label of
      the crop with the highest weight; ties keep the earlier crop.

      Raises
      ------
      ValueError
          If pars['images']['crop'] defines no crops.
      """
      cropData = pars['images']['crop']
      if not cropData:
          raise ValueError('mergeSegmentations: no crops configured in pars["images"]["crop"]')
      w0 = nout = img = None
      for c in cropData:
          si = SimpleITK.ReadImage(getSegmentationFile(pars, c))
          rng = [float(v) for v in cropData[c]['range']]
          n = cropData[c]['n']
          axis = int(cropData[c].get('axis', 2))
          print(n)
          # pad to the full extent using the same int(fraction*n) bounds as cropImage
          lower = [0] * si.GetDimension()
          upper = [0] * si.GetDimension()
          lower[axis] = int(rng[0] * n)
          upper[axis] = n - int(rng[1] * n)
          img = SimpleITK.ConstantPad(si, lower, upper, -1)
          print(img.GetSize())
          ni = SimpleITK.GetArrayFromImage(img)
          print(ni.shape)
          # numpy arrays list the axes in reverse SimpleITK order
          npAxis = ni.ndim - 1 - axis
          aw = numpy.array([getWeight((x + 0.5) / n, cropData[c]['w']) for x in range(n)],
                           dtype=float)
          shape = [1] * ni.ndim
          shape[npAxis] = n
          w1 = numpy.zeros(ni.shape)
          w1[...] = aw.reshape(shape)
          if nout is None:
              w0, nout = w1, ni
              continue
          better = w1 > w0
          nout[better] = ni[better]
          w0[better] = w1[better]
      iout = SimpleITK.GetImageFromArray(nout)
      iout.CopyInformation(img)
      return iout
  192. def main(parameterFile):
  # Batch driver: segments studies listed in the LabKey target query.
  # Reads ~/.labkey/setup.json (local paths) and ~/.labkey/network.json
  # (server connection), loads *parameterFile* (JSON) and converts it with
  # parseConfig, renders the DeepMedic config files from their templates, then
  # loops over the rows of pars['targetQuery'] in pars['targetSchema'].
  193. fhome=os.path.expanduser('~')
  194. with open(os.path.join(fhome,".labkey","setup.json")) as f:
  195. setup=json.load(f)
  # The LabKey client modules and parseConfig are imported from the
  # directories named in setup.json, which are added to sys.path here.
  196. sys.path.insert(0,setup["paths"]["labkeyInterface"])
  197. import labkeyInterface
  198. import labkeyDatabaseBrowser
  199. import labkeyFileBrowser
  200. sys.path.append(setup['paths']['parseConfig'])
  201. import parseConfig
  202. fconfig=os.path.join(fhome,'.labkey','network.json')
  203. net=labkeyInterface.labkeyInterface()
  204. net.init(fconfig)
  205. db=labkeyDatabaseBrowser.labkeyDB(net)
  206. fb=labkeyFileBrowser.labkeyFileBrowser(net)
  207. with open(parameterFile) as f:
  208. pars=json.load(f)
  209. pars=parseConfig.convert(pars)
  210. pars=parseConfig.convertValues(pars)
  211. print(pars)
  212. #update the config
  # Write each DeepMedic config (runDeepMedic uses 'model' and 'test') from
  # its template, applying the regex substitutions in pars['replacePattern'].
  213. cfg=pars['deepmedic']['config']
  214. for c in cfg:
  215. replacePatterns(cfg[c]['template'],\
  216. cfg[c]['out'],\
  217. pars['replacePattern'])
  218. project=pars['project']
  219. dataset=pars['targetQuery']
  220. schema=pars['targetSchema']
  221. tempBase=pars['tempBase']
  222. if not os.path.isdir(tempBase):
  223. os.makedirs(tempBase)
  224. #all images from database
  225. ds=db.selectRows(project,schema,dataset,[])
  226. #input
  227. #use webdav to transfer file (even though it is localhost)
  228. i=0
  229. for row in ds["rows"]:
  230. #check if file is already there
  231. #dummy tf to get the suffix
  # NOTE: tf does not depend on row (the crop tag 'XX' is a placeholder); only
  # its suffix is used below, so it could be computed once before the loop.
  232. tf=getSegmentationFile(pars,'XX')
  # NOTE(review): the upload path uses row['patientCode']/row['visitCode'],
  # while runSegmentation builds its download path with getPatientLabel and
  # getVisitLabel - presumably these agree; confirm.
  233. outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
  # Versioned upload name: <study label>_Segm_<pars['version']><extension>.
  234. outName=addVersion(\
  235. getSegmImagePath(\
  236. getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
  237. pars['version'])
  238. outFile=outpath+'/'+outName
  239. #check if file is there
  # NOTE(review): this copy of the file lost its indentation, so it cannot be
  # determined here which of the following statements (upload, database
  # update, counter/break) belong to this `if` - confirm against the original.
  240. if not fb.entryExists(outFile):
  241. segFile=runSegmentation(fb,row,pars,setup)
  242. #copy file to file
  243. #normally I would update the targetQuery, but it contains previously set images
  244. #copy to labkey
  245. fb.writeFileToFile(segFile,outFile)
  246. #separate script (set version!)
  247. #update database
  # Upsert into pars['segmentationQuery']: update the first row matching this
  # participant/visit or insert a new one; the column named pars['version']
  # receives the uploaded file name.
  248. copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
  249. filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
  250. ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
  251. if len(ds1['rows'])>0:
  252. mode='update'
  253. outRow=ds1['rows'][0]
  254. else:
  255. mode='insert'
  256. outRow={v:row[v] for v in copyFields}
  257. outRow[pars['version']]= outName
  258. db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow])
  259. #pull results back to LabKey
  # NOTE(review): the loop stops as soon as i reaches 1, i.e. after a single
  # counted iteration - looks like a debugging limit; confirm it should stay.
  260. i+=1
  261. if i==1:
  262. break
  263. print("Done")
  if __name__ == '__main__':
      # Exit with a usage message instead of an IndexError traceback when the
      # parameter file argument is missing.
      if len(sys.argv) < 2:
          sys.exit('Usage: {} parameterFile.json'.format(sys.argv[0]))
      main(sys.argv[1])