generateFigures.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. import nibabel
  2. import os
  3. import json
  4. import sys
  5. import numpy
  6. import matplotlib.pyplot
  7. #import chardet
  def buildPath(server, project, imageDir, patientCode, visitCode, imageType):
      """Compose the LabKey WebDAV URL of one NIfTI image of a patient visit.

      The file sits under
      <server>/labkey/_webdav/<project>/@files/<imageDir>/<patientCode>/<visitCode>/
      and is named <patientCode>-<visitCode>_<imageType><suffix>.nii.gz, where
      the suffix is '_v5' for segmentations ('Segm') and
      '_notCropped_2mmVoxel' for every other image type.
      """
      suffix = '_v5' if imageType == 'Segm' else '_notCropped_2mmVoxel'
      directory = '/'.join([server, 'labkey/_webdav', project, '@files',
                            imageDir, patientCode, visitCode])
      fileName = patientCode + '-' + visitCode + '_' + imageType + suffix + '.nii.gz'
      return directory + '/' + fileName
  def getCOWAxis(seg, val, axis):
      """Return [start, centre-of-weight, end] of segment ``val`` along ``axis``.

      The mask ``seg == val`` is summed over every axis except ``axis``,
      giving a 1-D profile of voxel counts along ``axis``.

      Returns:
          ``[s0, cow, s1]`` where ``s0`` is the first index with a non-zero
          count, ``s1`` is one past the last such index (so ``s0:s1`` slices
          the segment's extent) and ``cow`` is the count-weighted mean index.

      Raises:
          ValueError: if ``axis`` is not a valid axis of ``seg``.
          ZeroDivisionError: if ``val`` does not occur in ``seg`` (raised by
              numpy.average because all weights are zero; callers catch it).
      """
      seg = numpy.asarray(seg)
      if not 0 <= axis < seg.ndim:
          raise ValueError('getCOWAxis - invalid axis {} for a {}-D image'.format(axis, seg.ndim))
      # Collapse every other axis so only the profile along `axis` remains.
      otherAxes = tuple(a for a in range(seg.ndim) if a != axis)
      profile = numpy.sum(seg == val, axis=otherAxes)
      x = numpy.arange(len(profile))
      try:
          cow = numpy.average(x, weights=profile)
      except ZeroDivisionError:
          print('getCOWaxis - Zero division error')
          raise
      # Non-empty here: a zero profile would have raised above.
      occupied = numpy.flatnonzero(profile)
      return [occupied[0], cow, occupied[-1] + 1]
  def getGeometry(seg, val):
      """Return one getCOWAxis() triplet per axis 0, 1 and 2 for segment ``val``.

      Each entry is ``[start, centre-of-weight, end]`` along that axis;
      any exception raised by getCOWAxis propagates unchanged.
      """
      geometry = []
      for axis in range(3):
          geometry.append(getCOWAxis(seg, val, axis))
      return geometry
  def getCOW(geom):
      """Return the middle element (centre of weight) of every per-axis entry in ``geom``.

      ``geom`` is expected in the getGeometry() layout: one
      ``[start, centre, end]`` entry per axis.
      """
      return [axisInfo[1] for axisInfo in geom]
  def getRange(geom):
      """Return ``[start, end]`` for every per-axis entry in ``geom``.

      ``geom`` is expected in the getGeometry() layout: one
      ``[start, centre, end]`` entry per axis; the centre is dropped.
      """
      bounds = []
      for axisInfo in geom:
          bounds.append([axisInfo[0], axisInfo[2]])
      return bounds
  def _cropBounds(lo, hi, delta, size):
      """Widen the range [lo, hi) by ``delta`` on both sides, clamped to [0, size]."""
      return max(lo - delta, 0), min(hi + delta, size)


  def plot(imgs, t, val, tempBase):
      """Render one slice with segment ``val`` overlaid and save it as a PNG.

      The slice is taken at the truncated centre-of-weight index of the
      segment along axis 1, and cropped along axes 0 and 2 to the segment's
      extent plus a margin of 40 voxels for segment 4 and 20 otherwise.
      CT is shown in grayscale with a fixed window/level of 350/40, PET with
      the 'inferno' colormap; the segment mask is overlaid with alpha 0.2.

      Parameters:
          imgs: dict with 'CT', 'PET' and 'Segm' volumes. All three are
              indexed with the same slice, so they are assumed to share one
              voxel grid (NOTE(review): confirm for the resampled inputs).
          t: 'CT' or 'PET' -- which image to display.
          val: segment label (3-lungs 4-thyroid 5-bowel).
          tempBase: directory where the PNG is written.

      Returns:
          Path of the written file, ``<tempBase>/slice<t>_<val>.png``.

      Raises:
          ValueError: if ``t`` is neither 'CT' nor 'PET'.
          ZeroDivisionError: propagated from getCOWAxis when ``val`` is absent
              from imgs['Segm'].
      """
      if t not in ('CT', 'PET'):
          raise ValueError('plot - unsupported image type {}'.format(t))
      # Overlay colour per segment label when displaying CT.
      segColors = [[0, 0, 0], [0.1, 0.1, 0.1], [0, 0.2, 0], [1, 0, 0], [0, 0, 1], [1, 0, 1]]
      delta = 40 if val == 4 else 20
      window = 350
      level = 40
      geo = getGeometry(imgs['Segm'], val)
      j = int(getCOW(geo)[1])  # slice index along axis 1
      rng = getRange(geo)
      shape = imgs['CT'].shape
      i0, i1 = _cropBounds(rng[0][0], rng[0][1], delta, shape[0])
      k0, k1 = _cropBounds(rng[2][0], rng[2][1], delta, shape[2])
      # Only the displayed slice of the mask is needed.
      segment = imgs['Segm'][i0:i1, j, k0:k1] == val

      # A fresh figure per call: drawing into pyplot's implicit current
      # figure made successive calls stack images on the same axes.
      fig = matplotlib.pyplot.figure()
      try:
          if t == 'CT':
              v0 = level - 0.5 * window
              matplotlib.pyplot.imshow(imgs['CT'][i0:i1, j, k0:k1].transpose(), cmap='gray',
                                       vmin=v0, vmax=v0 + window)
              rgb = segColors[val]
          else:
              matplotlib.pyplot.imshow(imgs['PET'][i0:i1, j, k0:k1].transpose(), cmap='inferno')
              rgb = [1, 1, 1]
          # Colormap with a constant colour and opacity ramping from 0 to 1.
          colors = [rgb + [c] for c in numpy.linspace(0, 1, 100)]
          cmap = matplotlib.colors.LinearSegmentedColormap.from_list('mycmap', colors, N=5)
          matplotlib.pyplot.imshow(segment.transpose(), cmap=cmap, alpha=0.2)
          matplotlib.pyplot.gca().invert_yaxis()
          outfile = os.path.join(tempBase, 'slice{}_{}.png'.format(t, val))
          fig.savefig(outfile)
      finally:
          matplotlib.pyplot.close(fig)
      return outfile
  def main(parameterFile):
      """Render organ-overlay PNGs for the visits of a LabKey dataset and register them.

      For each row of the target dataset that has both a patientCode and a
      visitCode, the report query is checked for existing CT and PET entries
      of ``pars['version']``. If either is missing, the resampled CT, the
      resampled PET and the segmentation are fetched into ``~/temp``. A slice
      for each image type (CT, PET) and organ (3, 4, 5) is then rendered with
      plot(), written next to the images with fb.writeFileToFile, and recorded
      (insert or update) in the report query.

      Parameters:
          parameterFile: path to a JSON file providing the keys project,
              targetQuery, targetSchema, viewName, reportSchema, reportQuery,
              participantField, version, segmentationSchema,
              segmentationQuery and imageDir.
      """
      setupJSON = os.path.join(os.path.expanduser('~'), '.labkey', 'setup.json')
      with open(setupJSON) as f:
          setup = json.load(f)
      # The LabKey client libraries are located through nixWrapper.
      sys.path.insert(0, setup["paths"]["nixWrapper"])
      import nixWrapper
      nixWrapper.loadLibrary("labkeyInterface")
      import labkeyInterface
      import labkeyDatabaseBrowser
      import labkeyFileBrowser
      fconfig = os.path.join(os.path.expanduser('~'), '.labkey', 'network.json')
      net = labkeyInterface.labkeyInterface()
      net.init(fconfig)
      db = labkeyDatabaseBrowser.labkeyDB(net)
      fb = labkeyFileBrowser.labkeyFileBrowser(net)
      tempBase = os.path.join(os.path.expanduser('~'), 'temp')
      # Downloads and rendered slices are written here; make sure it exists.
      os.makedirs(tempBase, exist_ok=True)
      with open(parameterFile) as f:
          pars = json.load(f)
      project = pars['project']
      dataset = pars['targetQuery']
      schema = pars['targetSchema']
      view = pars['viewName']
      reportSchema = pars['reportSchema']
      reportQuery = pars['reportQuery']
      participantField = pars['participantField']
      # all images from database
      ds = db.selectRows(project, schema, dataset, [], view)
      # Dataset fields holding the resampled image files; the segmentation
      # comes from the separate segmentation query instead.
      imageResampledField = {"CT": "ctResampled", "PET": "petResampled", "patientmask": "ROImask"}
      iTypes = ['CT', 'PET', 'Segm']
      for r in ds['rows']:
          if any(r[f] is None for f in ['patientCode', 'visitCode']):
              print('[{}/{}] - Skipping, missing codes'.
                    format(r[participantField], r['SequenceNum']))
              continue
          idFilter = {'variable': 'patientCode', 'value': r['patientCode'], 'oper': 'eq'}
          visitFilter = {'variable': 'visitCode', 'value': r['visitCode'], 'oper': 'eq'}
          verFilter = {'variable': 'version', 'value': pars['version'], 'oper': 'eq'}
          # Work is needed if the report query has no entry for CT or for PET.
          # Entries exist per organ, so any one entry counts as done.
          needToCalculate = False
          for t in ['CT', 'PET']:
              typeFilter = {'variable': 'type', 'value': t, 'oper': 'eq'}
              ds2 = db.selectRows(project, reportSchema, reportQuery,
                                  [idFilter, visitFilter, verFilter, typeFilter])
              if len(ds2['rows']) == 0:
                  needToCalculate = True
                  break
          if not needToCalculate:
              print('[{}/{}] - done'.format(r[participantField], r['SequenceNum']))
              continue
          ds1 = db.selectRows(project, pars['segmentationSchema'], pars['segmentationQuery'],
                              [idFilter, visitFilter, verFilter])
          segRows = ds1['rows']
          # check if CT, PET and Segm images are set
          imagesAvailable = [r[imageResampledField[t]] is not None for t in ['CT', 'PET']]
          imagesAvailable.append(len(segRows) > 0 and segRows[0]['segmentation'] is not None)
          if not all(imagesAvailable):
              print('[{}/{}] Skipping - not all images available :{}'.
                    format(r[participantField], r['SequenceNum'], imagesAvailable))
              continue
          imgs = {}
          for t in iTypes:
              # Resolve Segm explicitly: catching KeyError here would also hide a
              # missing CT/PET URL field and silently load the segmentation instead.
              if t == 'Segm':
                  imagePath = segRows[0]['_labkeyurl_segmentation']
              else:
                  imagePath = r['_labkeyurl_' + imageResampledField[t]]
              localPath = os.path.join(tempBase, 'image' + t + '.nii.gz')
              if os.path.isfile(localPath):
                  os.remove(localPath)
              fb.readFileToFile(imagePath, localPath)
              imgs[t] = nibabel.load(localPath).get_fdata()
          print('Loading completed')
          for t in ['CT', 'PET']:
              for val in [3, 4, 5]:
                  try:
                      outfile = plot(imgs, t, val, tempBase)
                  except ZeroDivisionError:
                      # plot() raises this when organ `val` is absent from the segmentation.
                      print('[{}/{}] - Skipping. Failed to plot for organ [{}]'.
                            format(r[participantField], r['SequenceNum'], val))
                      continue
                  remoteDir = fb.buildPathURL(project,
                                              [pars['imageDir'], r['patientCode'], r['visitCode']])
                  imageFile = (r['patientCode'] + '-' + r['visitCode'] + '_' + t + '_{}'.format(val)
                               + '_' + pars['version'] + '.png')
                  remoteFile = '/'.join([remoteDir, imageFile])
                  fb.writeFileToFile(outfile, remoteFile)
                  print('Uploaded {}'.format(remoteFile))
                  os.remove(outfile)
                  organFilter = {'variable': 'organ', 'value': '{}'.format(val), 'oper': 'eq'}
                  typeFilter = {'variable': 'type', 'value': t, 'oper': 'eq'}
                  ds3 = db.selectRows(project, reportSchema, reportQuery,
                                      [idFilter, visitFilter, verFilter, organFilter, typeFilter])
                  if len(ds3['rows']) > 0:
                      mode = 'update'
                      frow = ds3['rows'][0]
                  else:
                      mode = 'insert'
                      frow = {}
                  for f in [participantField, 'patientCode', 'visitCode']:
                      frow[f] = r[f]
                  frow['organ'] = '{}'.format(val)
                  frow['type'] = t
                  frow['version'] = pars['version']
                  frow['file'] = imageFile
                  db.modifyRows(mode, project, reportSchema, reportQuery, [frow])
          print('Images uploaded')
      print('Done')
  if __name__ == '__main__':
      # Report a usage message instead of an IndexError traceback when the
      # parameter file argument is missing.
      if len(sys.argv) < 2:
          sys.exit('Usage: {} parameterFile.json'.format(sys.argv[0]))
      main(sys.argv[1])