# workflow.py (15 KB)


import datetime
import os

import numpy

import config
import fitData
import getData
import loadData
import plotData
import segmentation
  9. def listRequiredFiles(stage,r,setup):
  10. code=config.getCode(r,setup)
  11. nclass=setup['nclass'][0]
  12. nr=setup['nr']
  13. nt=20
  14. if stage=='setCenters':
  15. names={x:[config.getPattern(x,code)] for x in ['CT','Dummy']}
  16. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  17. return names
  18. if stage=='fitIVF':
  19. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  20. names['center']=[]
  21. for ir in range(nr):
  22. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  23. names['center'].extend(rel)
  24. return names
  25. if stage=='plotIVF':
  26. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  27. names['center']=[]
  28. names['fitIVF']=[]
  29. qLambda=setup['qLambda']
  30. for ir in range(nr):
  31. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  32. names['center'].extend(rel)
  33. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  34. #names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  35. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  36. names['center'].extend(rel)
  37. #names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  38. names.update({x:[config.getPattern(x,code)] for x in ['CT']})
  39. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  40. return names
  41. if stage=='fitCompartment':
  42. names={}
  43. names['center']=[]
  44. names['fitIVF']=[]
  45. qLambda=setup['qLambda']
  46. for ir in range(nr):
  47. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  48. names['center'].extend(rel)
  49. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  50. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  51. names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  52. return names
  53. if stage=='plotCompartment':
  54. names={}
  55. names['center']=[]
  56. names['fitIVF']=[]
  57. names['fitCompartment']=[]
  58. nseg=setup['nseg']
  59. qLambda=setup['qLambda']
  60. for ir in range(nr):
  61. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  62. names['center'].extend(rel)
  63. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  64. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  65. for iseg in range(nseg):
  66. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  67. names['fitCompartment'].extend(rel)
  68. names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  69. names.update({x:[config.getPattern(x,code)] for x in ['CT','Dummy']})
  70. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  71. return names
  72. return {}
  73. def listCreatedFiles(stage,r,setup):
  74. code=config.getCode(r,setup)
  75. nclass=setup['nclass'][0]
  76. qLambda=setup['qLambda']
  77. nr=setup['nr']
  78. try:
  79. nseg=setup['nseg']
  80. except KeyError:
  81. nseg=0
  82. names={}
  83. if stage=='setCenters':
  84. names['center']=[]
  85. for ir in range(nr):
  86. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  87. names['center'].extend(rel)
  88. #rel=[config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  89. #names['center'].extend(rel)
  90. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  91. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  92. names['center'].extend(rel)
  93. return names
  94. if stage=='fitIVF':
  95. names['fitIVF']=[]
  96. for ir in range(nr):
  97. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  98. return names
  99. if stage=='plotIVF':
  100. names['plotIVF']=[]
  101. for ir in range(nr):
  102. x=[config.getPattern('plotIVF',code=code,nclass=nclass,ir=ir,qaName=y,qLambda=qLambda)
  103. for y in ['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']]
  104. names['plotIVF'].extend(x)
  105. return names
  106. if stage=='fitCompartment':
  107. xc='fitCompartment'
  108. names[xc]=[]
  109. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  110. for ir in range(nr):
  111. for iseg in range(nseg):
  112. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  113. names[xc].extend(rel)
  114. if stage=='plotCompartment':
  115. xc='plotCompartment'
  116. names[xc]=[]
  117. sNames=['realizations','diff']
  118. for ir in range(nr):
  119. for iseg in range(nseg):
  120. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  121. names[xc].extend(rel)
  122. return []
  123. def getRequiredFiles(stage,r,setup,fb,names=None):
  124. #fb,r=getRow(setup,True)
  125. if not names:
  126. names=listRequiredFiles(stage,r,setup)
  127. for f in names:
  128. _copyFromServer=getData.copyFromServer
  129. if f=='segmentation':
  130. _copyFromServer=segmentation.copyFromServer
  131. _copyFromServer(fb,r,setup,names[f])
  132. return fb,r
  133. def checkRequiredFiles(stage,r,setup,names=None,fb=None,doPrint=False):
  134. ok=True
  135. if not names:
  136. names=listRequiredFiles(stage,r,setup)
  137. for f in names:
  138. nm=names[f]
  139. for x in nm:
  140. avail=os.path.isfile(getData.getLocalPath(r,setup,x))
  141. if not avail:
  142. print(f'Missing {x}')
  143. if fb:
  144. _getURL=getData.getURL
  145. if f=='segmentation':
  146. _getURL=segmentation.getURL
  147. availRemote=fb.entryExists(_getURL(fb,r,setup,x))
  148. print(f'Available remote: {availRemote}')
  149. ok=False
  150. if doPrint:
  151. print(f'[{avail}] {x}')
  152. return ok
  153. def uploadCreatedFiles(stage,fb,r,setup,names=None):
  154. if not names:
  155. names=listCreatedFiles(stage,r,setup)
  156. for f in names:
  157. _copyToServer=getData.copyToServer
  158. _getURL=getData.getURL
  159. if f=='segmentation':
  160. _copyToServer=segmentation.copyToServer
  161. _getURL=segmentation.getURL
  162. _copyToServer(fb,r,setup,names[f])
  163. for x in names[f]:
  164. print('[{}] Uploaded {}'.format(fb.entryExists(_getURL(fb,r,setup,x)),x))
  165. #this is a poor fit for workflow, but no better logical unit was found, so here it is
  166. def makeMap(segs,kClass,tac):
  167. map={}
  168. vals=[(kClass[i],tac[:,i]) for i in range(len(kClass))]
  169. for (i,v) in zip(segs,vals):
  170. try:
  171. map[i].append(v)
  172. except KeyError:
  173. map[i]=[v]
  174. return map
# getDataAtPixels(data, loc) was removed; use loadData.getTACAtPixels(data, loc) instead.
  176. def updateDatabase(r,setup,stage,fb=None,db=None,categories=[]):
  177. #set database entry
  178. try:
  179. qLam=setup['qLambda']
  180. except KeyError:
  181. qLam=0
  182. nclass=setup['nclass'][0]
  183. code=config.getCode(r,setup)
  184. if stage=='plotIVF':
  185. m,samples=loadData.readIVF(r,setup,qLambda=qLam)
  186. chi2=samples[0,:]
  187. threshold=numpy.median(chi2)
  188. fit=fitData.getFit(samples,threshold)
  189. row={x:r[x] for x in ['PatientId','visitCode']}
  190. row['nclass']=nclass
  191. row['mean']=fit.mu[0]
  192. row['std']=fit.cov[0,0]
  193. row['qLambda']=qLam
  194. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLam) for x in categories}
  195. row.update(fNames)
  196. if db:
  197. db.modifyRows('insert',setup['project'],'lists','SummaryIVF',[row])
  198. def workflow(r,setup,stage,fb=None,db=None):
  199. setCenters=False
  200. setIVF=False
  201. plotIVF=False
  202. setC=True
  203. try:
  204. qLambda=setup['qLambda']
  205. except KeyError:
  206. qLambda=0
  207. if stage=='setCenters':
  208. names=listRequiredFiles(stage,r,setup)
  209. if fb:
  210. getRequiredFiles(stage,r,setup,fb,names=names)
  211. if not checkRequiredFiles(stage,r,setup,names=names,fb=fb,doPrint=True):
  212. return
  213. loadData.saveCenters(r,setup)
  214. if stage=='fitIVF':
  215. #get required files
  216. #stage='fitIVF'
  217. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True):
  218. return
  219. loadData.saveIVF(r,setup,nfit=30,qLambda=qLambda)
  220. if stage=='plotIVF':
  221. ir=0
  222. names=listRequiredFiles(stage,r,setup)
  223. if fb:
  224. getRequiredFiles(stage,r,setup,fb,names=names)
  225. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  226. return
  227. print('Loading files to memory')
  228. m,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  229. data=loadData.loadData(r,setup)
  230. ct=loadData.loadCT(r,setup)
  231. centerMapSPECT,centerMapCT=loadData.loadCenterMapNRRD(r,setup,ir=ir)
  232. t,dt=loadData.loadTime(r,setup)
  233. centers=loadData.loadCenters(r,setup,ir=ir)
  234. ivf=centers[m]
  235. chi2=samples[0,:]
  236. threshold=numpy.median(chi2)
  237. code=config.getCode(r,setup)
  238. ir=0
  239. categories=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  240. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLambda) for x in categories}
  241. files={x:getData.getLocalPath(r,setup,fNames[x]) for x in fNames}
  242. plotData.plotIVF(t,ivf,samples,threshold,file0=files['fits'],file1=files['diff'])
  243. plotData.plotIVFRealizations(t,ivf,samples,threshold,file=files['generatedIVF'])
  244. #temporarily blocking center generation
  245. plotData.plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=files['centerIVFSPECT'],
  246. file1=files['centerIVFCT'])
  247. updateDatabase(r,setup,stage,db=db,fb=fb,categories=categories)
  248. if stage=='fitCompartment':
  249. ir=0
  250. names=listRequiredFiles(stage,r,setup)
  251. if fb:
  252. getRequiredFiles(stage,r,setup,fb,names=names)
  253. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  254. return
  255. #load class classification
  256. u=loadData.loadCenterMap(r,setup)
  257. print(u.shape)
  258. #load segmentation
  259. seg=segmentation.getNRRDImage(r,setup,names)
  260. loc=numpy.nonzero(seg)
  261. vClass=[int(x) for x in u[loc]]
  262. segments=[int(x) for x in seg[loc]]
  263. print(segments)
  264. data=loadData.loadData(r,setup)
  265. tac=loadData.getTACAtPixels(data,loc)
  266. segMap=makeMap(segments,vClass,tac)
  267. #for x in segMap:
  268. # print('{} {}'.format(x,segMap[x]))
  269. #return
  270. m1,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  271. chi2=samples[0,:]
  272. threshold=numpy.median(chi2)
  273. ivfFit=fitData.getFit(samples,threshold)
  274. t,dt=loadData.loadTime(r,setup)
  275. centers=loadData.loadCenters(r,setup)
  276. #save segmentation pixels
  277. setup['nseg']=len(segMap.keys())
  278. for x in segMap:
  279. mArray=segMap[x]
  280. qCenter=numpy.zeros(t.shape[0])
  281. qData=numpy.zeros(t.shape[0])
  282. s=0
  283. #average over contributions for each segmentation included in map
  284. kCenters=[]
  285. for m in mArray:
  286. #m is a tuple of classId and tac
  287. kCenters.append(m[0])
  288. qCenter+=centers[m[0]]
  289. qData+=m[1]
  290. s+=1
  291. qCenter/=s
  292. qData/=s
  293. samplesC=fitData.fitCompartmentGlobal(ivfFit,t,qCenter,useJac=True,nfit=20)
  294. samplesC1=fitData.fitCompartmentGlobal(ivfFit,t,qData,nfit=20,useJac=True)
  295. loadData.saveSamples(r,setup,samplesC,kCenters,'kmeansFit',iseg=x,ir=ir,qLambda=qLambda)
  296. loadData.saveSamples(r,setup,samplesC1,[-1],'localFit',iseg=x,ir=ir,qLambda=qLambda)
  297. loadData.saveTAC(r,setup,qCenter,'kmeansTAC',iseg=x,ir=ir,qLambda=qLambda)
  298. loadData.saveTAC(r,setup,qData,'localTAC',iseg=x,ir=ir,qLambda=qLambda)
  299. if stage=='plotCompartment':
  300. ir=0
  301. names=listRequiredFiles(stage,r,setup)
  302. if fb:
  303. getRequiredFiles(stage,r,setup,fb,names=names)
  304. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  305. return
  306. tag='plotCompartment'
  307. seg=segmentation.getNRRDImage(r,setup,names)
  308. loc=numpy.nonzero(seg)
  309. segmentIds=list(set([int(x) for x in seg[loc]]))
  310. nclass=setup['nclass'][0]
  311. code=config.getCode(r,setup)
  312. setup['nseg']=len(segmentIds)
  313. t,dt=loadData.loadTime(r,setup)
  314. for iseg in segmentIds:
  315. m,samplesC=loadData.readSamples(r,setup,'kmeansFit',ir=ir,iseg=iseg,qLambda=qLambda)
  316. m1,samplesC1=loadData.readSamples(r,setup,'localFit',ir=ir,iseg=iseg,qLambda=qLambda)
  317. qCenter=loadData.readTAC(r,setup,'kmeansTAC',ir=ir,iseg=iseg,qLambda=qLambda)
  318. qData=loadData.readTAC(r,setup,'localTAC',ir=ir,iseg=iseg,qLambda=qLambda)
  319. chi2C=samplesC[0,:]
  320. threshold=numpy.median(chi2C)
  321. chi2C1=samplesC1[0,:]
  322. threshold1=numpy.median(chi2C1)
  323. fit=fitData.getFit(samplesC,threshold)
  324. fit1=fitData.getFit(samplesC1,threshold1)
  325. k1=fit.mu[0]
  326. stdK1=fit.cov[0,0]
  327. k11=fit1.mu[0]
  328. stdK11=fit1.cov[0,0]
  329. #update database with entries
  330. row={x:r[x] for x in ['PatientId','visitCode']}
  331. row['Date']=datetime.datetime.now().isoformat()
  332. row['nclass']=nclass
  333. row['option']='kmeansFit'
  334. row['mean']=k1
  335. row['std']=stdK1
  336. row['regionId']=iseg
  337. row['fitPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='realizations',iseg=iseg,qLambda=qLambda)
  338. row['diffPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='diff',iseg=iseg,qLambda=qLambda)
  339. row1={x:row[x] for x in row}
  340. row1['option']='localFit'
  341. row1['mean']=k11
  342. row1['std']=stdK11
  343. row['qLambda']=qLambda
  344. if db:
  345. db.modifyRows('insert',setup['project'],'lists','Summary',[row,row1])
  346. evalArray=[(samplesC,qCenter,'blue'),
  347. (samplesC1,qData,'orange')]
  348. file0=getData.getLocalPath(r,setup,row['fitPlot'])
  349. file1=getData.getLocalPath(r,setup,row['diffPlot'])
  350. plotData.plotSamples(t,evalArray,file0=file0,file1=file1)
  351. uploadCreatedFiles(stage,fb,r,setup)