generateFigures.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import nibabel
  2. import os
  3. import json
  4. import sys
  5. import numpy
  6. import matplotlib.pyplot
  7. #import chardet
  def buildPath(server,project,imageDir,patientCode,visitCode,imageType):
      """Return the LabKey WebDAV URL of a patient/visit NIfTI image.

      The directory part is
      ``<server>/labkey/_webdav/<project>/@files/<imageDir>/<patientCode>/<visitCode>``
      and the file name is
      ``<patientCode>-<visitCode>_<imageType><suffix>.nii.gz``, where the
      suffix is ``_v5`` for segmentations (``imageType == 'Segm'``) and
      ``_notCropped_2mmVoxel`` for every other image type.
      """
      suffix = '_v5' if imageType == 'Segm' else '_notCropped_2mmVoxel'
      directory = '/'.join([server, 'labkey/_webdav', project, '@files',
                            imageDir, patientCode, visitCode])
      fileName = patientCode + '-' + visitCode + '_' + imageType + suffix + '.nii.gz'
      return directory + '/' + fileName
  def getCOWAxis(seg,val,axis):
      """Return the extent and centre of weight of one segment along one axis.

      The mask ``seg == val`` is summed over every axis except ``axis``,
      giving a 1-D profile of voxel counts along ``axis``.

      Parameters
      ----------
      seg : numpy.ndarray
          Label image (the script passes 3-D segmentations; any rank >= 1
          works).
      val : number
          Label value selecting the segment.
      axis : int
          Axis along which the profile is taken; must be in
          ``range(seg.ndim)``.

      Returns
      -------
      list
          ``[first, centre, stop]``:
          - ``first`` is the first index with a non-zero count.
          - ``stop`` is one past the last such index, so ``first:stop``
            slices the segment's extent.
          - ``centre`` is the count-weighted mean index (centre of weight).

      Raises
      ------
      ValueError
          If ``axis`` is out of range or ``val`` does not occur in ``seg``.
      """
      mask = numpy.asarray(seg) == val
      if not 0 <= axis < mask.ndim:
          raise ValueError('axis {} out of range for {}-D image'.format(axis, mask.ndim))
      # collapse every other axis so only the profile along `axis` remains
      otherAxes = tuple(a for a in range(mask.ndim) if a != axis)
      profile = mask.sum(axis=otherAxes)
      occupied = numpy.flatnonzero(profile)
      if occupied.size == 0:
          # numpy.average would fail with "weights sum to zero" here
          raise ValueError('segment value {} not present in image'.format(val))
      centre = numpy.average(numpy.arange(profile.size), weights=profile)
      return [int(occupied[0]), centre, int(occupied[-1]) + 1]
  def getGeometry(seg,val):
      """Return the per-axis geometry of segment ``val`` in the 3-D image ``seg``.

      The result is a list with one entry for each of axes 0, 1 and 2.
      Each entry is the ``[first, centre, stop]`` triple produced by
      ``getCOWAxis`` for that axis. Use ``getCOW`` to extract the centres
      and ``getRange`` to extract the extents.
      """
      geometry = []
      for axisIndex in range(3):
          geometry.append(getCOWAxis(seg, val, axisIndex))
      return geometry
  def getCOW(geom):
      """Return the centre-of-weight coordinate of every per-axis entry in ``geom``.

      ``geom`` is a list as returned by ``getGeometry``; element 1 of each
      entry is the centre.
      """
      return [centre for _first, centre, *_rest in geom]
  def getRange(geom):
      """Return ``[first, stop]`` (elements 0 and 2) of every per-axis entry in ``geom``.

      ``geom`` is a list as returned by ``getGeometry``.
      """
      return [[first, stop] for first, _centre, stop, *_rest in geom]
  def plot(imgs,t,val,tempBase):
      """Render one slice through a segment over the CT or PET image and save it as PNG.

      The slice is taken at the (truncated) centre-of-weight index of segment
      ``val`` along axis 1. It is cropped along axes 0 and 2 to the segment's
      extent plus a margin of ``delta`` voxels (40 for label 4, otherwise 20),
      clamped to the image bounds. The segment mask is drawn with alpha 0.2
      on top of the selected image.

      Parameters
      ----------
      imgs : dict
          Must contain 'CT', 'PET' and 'Segm' arrays indexed as [i, j, k].
          The crop bounds are computed from the CT shape and applied to all
          three arrays, so they are expected to share one grid.
      t : str
          'CT' (window/level grey-scale, overlay colour from ``segColors``)
          or 'PET' (inferno colormap, white overlay).
      val : int
          Segment label; also indexes ``segColors`` (3 lungs, 4 thyroid,
          5 bowel).
      tempBase : str
          Directory the PNG is written to.

      Returns
      -------
      str
          Path of the written file, ``<tempBase>/slice<t>_<val>.png``.

      Raises
      ------
      ValueError
          If ``t`` is neither 'CT' nor 'PET'.
      """
      import matplotlib.colors  # LinearSegmentedColormap below; not guaranteed by importing pyplot

      if t not in ('CT', 'PET'):
          raise ValueError("image type must be 'CT' or 'PET', got {!r}".format(t))
      segColors=[[0,0,0],[0.1,0.1,0.1],[0,0.2,0],[1,0,0],[0,0,1],[1,0,1]]
      #3-lungs 4-thyroid 5-bowel
      delta = 40 if val == 4 else 20
      window=350
      level=40

      geo=getGeometry(imgs['Segm'],val)
      cowI=[int(x) for x in getCOW(geo)]
      rng=getRange(geo)
      shape=imgs['CT'].shape
      # crop window: segment extent plus margin, clamped to the image
      i0 = max(rng[0][0] - delta, 0)
      i1 = min(rng[0][1] + delta, shape[0])
      k0 = max(rng[2][0] - delta, 0)
      k1 = min(rng[2][1] + delta, shape[2])
      j = cowI[1]

      # A fresh figure per call: drawing on pyplot's implicit current figure
      # would stack images from earlier calls on the same axes, and
      # invert_yaxis() would toggle the orientation on every call.
      fig, ax = matplotlib.pyplot.subplots()
      try:
          if t == 'CT':
              v0 = level - 0.5 * window
              ax.imshow(imgs['CT'][i0:i1, j, k0:k1].transpose(), cmap='gray',
                        vmin=v0, vmax=v0 + window)
              rgb = segColors[val]
          else:
              ax.imshow(imgs['PET'][i0:i1, j, k0:k1].transpose(), cmap='inferno')
              rgb = [1, 1, 1]
          # overlay colormap: fixed colour, alpha ramping from 0 to 1
          colors = [rgb + [c] for c in numpy.linspace(0, 1, 100)]
          cmap = matplotlib.colors.LinearSegmentedColormap.from_list('mycmap', colors, N=5)
          segment = imgs['Segm'] == val
          ax.imshow(segment[i0:i1, j, k0:k1].transpose(), cmap=cmap, alpha=0.2)
          # imshow puts row 0 (k0 after the transpose) at the top; flip it to the bottom
          ax.invert_yaxis()
          outfile = os.path.join(tempBase, 'slice{}_{}.png'.format(t, val))
          fig.savefig(outfile)
      finally:
          matplotlib.pyplot.close(fig)
      return outfile
  def main(parameterFile):
      """Generate, upload and register organ slice figures for rows of a LabKey query.

      Reads the LabKey setup from ``~/.labkey/setup.json``, connection
      settings from ``~/.labkey/network.json``, and run parameters from the
      JSON file ``parameterFile``. Keys used: project, targetQuery,
      targetSchema, viewName, reportSchema, reportQuery, participantField,
      version, segmentationSchema, segmentationQuery, imageDir.

      Each row of ``targetSchema.targetQuery`` (view ``viewName``) is
      processed only if the report query lacks a row of type 'CT' or 'PET'
      for its patientCode/visitCode/version. Processing a row:
      - downloads the resampled CT and PET images and the segmentation to
        ``~/temp``;
      - renders one PNG per image type ('CT', 'PET') and organ label
        (3, 4, 5) with ``plot``;
      - uploads it to ``<imageDir>/<patientCode>/<visitCode>``;
      - inserts or updates the matching report row.

      Rows without a segmentation entry are skipped with a message.
      """
      setupJSON=os.path.join(os.path.expanduser('~'),'.labkey','setup.json')
      with open(setupJSON) as f:
          setup=json.load(f)
      # the LabKey client libraries are loaded via nixWrapper from the configured path
      sys.path.insert(0,setup["paths"]["nixWrapper"])
      import nixWrapper
      nixWrapper.loadLibrary("labkeyInterface")
      import labkeyInterface
      import labkeyDatabaseBrowser
      import labkeyFileBrowser
      fconfig=os.path.join(os.path.expanduser('~'),'.labkey','network.json')
      net=labkeyInterface.labkeyInterface()
      net.init(fconfig)
      db=labkeyDatabaseBrowser.labkeyDB(net)
      fb=labkeyFileBrowser.labkeyFileBrowser(net)
      tempBase=os.path.join(os.path.expanduser('~'),'temp')

      with open(parameterFile) as f:
          pars=json.load(f)
      project=pars['project']
      dataset=pars['targetQuery']
      schema=pars['targetSchema']
      view=pars['viewName']
      reportSchema=pars['reportSchema']
      reportQuery=pars['reportQuery']
      participantField=pars['participantField']
      version=pars['version']

      def eqFilter(variable,value):
          # equality filter in the format passed to db.selectRows
          return {'variable':variable,'value':value,'oper':'eq'}

      # all images from database
      ds=db.selectRows(project,schema,dataset,[],view)
      # row field (after the '_labkeyurl_' prefix) holding each resampled image URL
      imageResampledField={"CT":"ctResampled","PET":"petResampled"}
      iTypes=['CT','PET','Segm']

      for r in ds['rows']:
          print(r)
          idFilter=eqFilter('patientCode',r['patientCode'])
          visitFilter=eqFilter('visitCode',r['visitCode'])
          verFilter=eqFilter('version',version)

          # Skip rows that already have report rows for both CT and PET.
          # Only the existence of some row per type is checked, not one per
          # organ. any() stops at the first type without rows.
          needToCalculate=any(
              len(db.selectRows(project,reportSchema,reportQuery,
                  [idFilter,visitFilter,verFilter,eqFilter('type',t)])['rows'])==0
              for t in ['CT','PET'])
          if not needToCalculate:
              continue

          segRows=db.selectRows(project,pars['segmentationSchema'],pars['segmentationQuery'],
              [idFilter,visitFilter,verFilter])['rows']
          if not segRows:
              print('No segmentation for {}-{}, skipping'.format(r['patientCode'],r['visitCode']))
              continue
          # CT/PET URLs come from the row itself; a missing field is an error,
          # never a reason to fall back to the segmentation
          imagePaths={t:r['_labkeyurl_'+imageResampledField[t]] for t in ['CT','PET']}
          imagePaths['Segm']=segRows[0]['_labkeyurl_segmentation']

          imgs={}
          for t in iTypes:
              localPath=os.path.join(tempBase,'image'+t+'.nii.gz')
              if os.path.isfile(localPath):
                  os.remove(localPath)
              fb.readFileToFile(imagePaths[t],localPath)
              imgs[t]=nibabel.load(localPath).get_fdata()
          print('Loading completed')

          remoteDir=fb.buildPathURL(project,[pars['imageDir'],r['patientCode'],r['visitCode']])
          for t in ['CT','PET']:
              for val in [3,4,5]:
                  outfile=plot(imgs,t,val,tempBase)
                  imageFile=r['patientCode']+'-'+r['visitCode']+'_'+t+'_{}'.format(val)+'_'+version+'.png'
                  remoteFile='/'.join([remoteDir,imageFile])
                  fb.writeFileToFile(outfile,remoteFile)
                  print('Uploaded {}'.format(remoteFile))
                  os.remove(outfile)

                  # register the figure: update an existing report row or insert a new one
                  organ='{}'.format(val)
                  ds3=db.selectRows(project,reportSchema,reportQuery,
                      [idFilter,visitFilter,verFilter,eqFilter('organ',organ),eqFilter('type',t)])
                  if ds3['rows']:
                      mode='update'
                      frow=ds3['rows'][0]
                  else:
                      mode='insert'
                      frow={}
                  for field in [participantField,'patientCode','visitCode']:
                      frow[field]=r[field]
                  frow['organ']=organ
                  frow['type']=t
                  frow['version']=version
                  frow['file']=imageFile
                  db.modifyRows(mode,project,reportSchema,reportQuery,[frow])
          print('Images uploaded')
      print('Done')
  if __name__ == '__main__':
      # expects the path of the JSON parameter file as the first argument
      if len(sys.argv) < 2:
          sys.exit('Usage: {} parameterFile.json'.format(sys.argv[0]))
      main(sys.argv[1])