# workflow.py (16 KB)


import datetime
import json
import os
import sys

import numpy

import config
import fitData
import getData
import loadData
import plotData
import segmentation
  9. def getRows(setup,returnFB=False):
  10. qFilter=config.getFilter(setup)
  11. db,fb=getData.connectDB(setup['network'])
  12. rows=getData.getPatients(db,setup,qFilter)
  13. #print(rows)
  14. if returnFB:
  15. return fb,db,rows
  16. return rows
  17. def listRequiredFiles(stage,r,setup):
  18. code=config.getCode(r,setup)
  19. nclass=setup['nclass'][0]
  20. nr=setup['nr']
  21. nt=20
  22. if stage=='setCenters':
  23. names={x:[config.getPattern(x,code)] for x in ['CT','Dummy']}
  24. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  25. return names
  26. if stage=='fitIVF':
  27. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  28. names['center']=[]
  29. for ir in range(nr):
  30. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  31. names['center'].extend(rel)
  32. return names
  33. if stage=='plotIVF':
  34. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  35. names['center']=[]
  36. names['fitIVF']=[]
  37. qLambda=setup['qLambda']
  38. for ir in range(nr):
  39. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  40. names['center'].extend(rel)
  41. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  42. #names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  43. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  44. names['center'].extend(rel)
  45. #names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  46. names.update({x:[config.getPattern(x,code)] for x in ['CT']})
  47. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  48. return names
  49. if stage=='fitCompartment':
  50. names={}
  51. names['center']=[]
  52. names['fitIVF']=[]
  53. qLambda=setup['qLambda']
  54. for ir in range(nr):
  55. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  56. names['center'].extend(rel)
  57. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  58. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  59. names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  60. return names
  61. if stage=='plotCompartment':
  62. names={}
  63. names['center']=[]
  64. names['fitIVF']=[]
  65. names['fitCompartment']=[]
  66. nseg=setup['nseg']
  67. qLambda=setup['qLambda']
  68. for ir in range(nr):
  69. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  70. names['center'].extend(rel)
  71. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  72. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  73. for iseg in range(nseg):
  74. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  75. names['fitCompartment'].extend(rel)
  76. names['segmentation']=[segmentation.getSegmentationFileName(r,setup)]
  77. names.update({x:[config.getPattern(x,code)] for x in ['CT','Dummy']})
  78. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  79. return names
  80. return {}
  81. def listCreatedFiles(stage,r,setup):
  82. code=config.getCode(r,setup)
  83. nclass=setup['nclass'][0]
  84. qLambda=setup['qLambda']
  85. nr=setup['nr']
  86. try:
  87. nseg=setup['nseg']
  88. except KeyError:
  89. nseg=0
  90. names={}
  91. if stage=='setCenters':
  92. names['center']=[]
  93. for ir in range(nr):
  94. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  95. names['center'].extend(rel)
  96. #rel=[config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  97. #names['center'].extend(rel)
  98. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  99. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  100. names['center'].extend(rel)
  101. return names
  102. if stage=='fitIVF':
  103. names['fitIVF']=[]
  104. for ir in range(nr):
  105. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  106. return names
  107. if stage=='plotIVF':
  108. names['plotIVF']=[]
  109. for ir in range(nr):
  110. x=[config.getPattern('plotIVF',code=code,nclass=nclass,ir=ir,qaName=y,qLambda=qLambda)
  111. for y in ['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']]
  112. names['plotIVF'].extend(x)
  113. return names
  114. if stage=='fitCompartment':
  115. xc='fitCompartment'
  116. names[xc]=[]
  117. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  118. for ir in range(nr):
  119. for iseg in range(nseg):
  120. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  121. names[xc].extend(rel)
  122. if stage=='plotCompartment':
  123. xc='plotCompartment'
  124. names[xc]=[]
  125. sNames=['realizations','diff']
  126. for ir in range(nr):
  127. for iseg in range(nseg):
  128. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda) for qn in sNames]
  129. names[xc].extend(rel)
  130. return []
  131. def getRequiredFiles(stage,r,setup,fb,names=None):
  132. #fb,r=getRow(setup,True)
  133. if not names:
  134. names=listRequiredFiles(stage,r,setup)
  135. for f in names:
  136. _copyFromServer=getData.copyFromServer
  137. if f=='segmentation':
  138. _copyFromServer=segmentation.copyFromServer
  139. _copyFromServer(fb,r,setup,names[f])
  140. return fb,r
  141. def checkRequiredFiles(stage,r,setup,names=None,fb=None,doPrint=False):
  142. ok=True
  143. if not names:
  144. names=listRequiredFiles(stage,r,setup)
  145. for f in names:
  146. nm=names[f]
  147. for x in nm:
  148. avail=os.path.isfile(getData.getLocalPath(r,setup,x))
  149. if not avail:
  150. print(f'Missing {x}')
  151. if fb:
  152. _getURL=getData.getURL
  153. if f=='segmentation':
  154. _getURL=segmentation.getURL
  155. availRemote=fb.entryExists(_getURL(fb,r,setup,x))
  156. print(f'Available remote: {availRemote}')
  157. ok=False
  158. if doPrint:
  159. print(f'[{avail}] {x}')
  160. return ok
  161. def uploadCreatedFiles(stage,fb,r,setup,names=None):
  162. if not names:
  163. names=listCreatedFiles(stage,r,setup)
  164. for f in names:
  165. _copyToServer=getData.copyToServer
  166. _getURL=getData.getURL
  167. if f=='segmentation':
  168. _copyToServer=segmentation.copyToServer
  169. _getURL=segmentation.getURL
  170. _copyToServer(fb,r,setup,names[f])
  171. for x in names[f]:
  172. print('[{}] Uploaded {}'.format(fb.entryExists(_getURL(fb,r,setup,x)),x))
  173. #this is a poor fit for workflow, but no better logical unit was found, so here it is
  174. def makeMap(segs,kClass,tac):
  175. map={}
  176. vals=[(kClass[i],tac[:,i]) for i in range(len(kClass))]
  177. for (i,v) in zip(segs,vals):
  178. try:
  179. map[i].append(v)
  180. except KeyError:
  181. map[i]=[v]
  182. return map
  183. #def getDataAtPixels(data,loc) replaced by loadData.getTACAtPixels(data,loc)
  184. def updateDatabase(r,setup,stage,fb=None,db=None,categories=[]):
  185. #set database entry
  186. try:
  187. qLam=setup['qLambda']
  188. except KeyError:
  189. qLam=0
  190. nclass=setup['nclass'][0]
  191. code=config.getCode(r,setup)
  192. if stage=='plotIVF':
  193. m,samples=loadData.readIVF(r,setup,qLambda=qLam)
  194. chi2=samples[0,:]
  195. threshold=numpy.median(chi2)
  196. fit=fitData.getFit(samples,threshold)
  197. row={x:r[x] for x in ['PatientId','visitCode']}
  198. row['nclass']=nclass
  199. row['mean']=fit.mu[0]
  200. row['std']=fit.cov[0,0]
  201. row['qLambda']=qLam
  202. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLam) for x in categories}
  203. row.update(fNames)
  204. if db:
  205. db.modifyRows('insert',setup['project'],'lists','SummaryIVF',[row])
  206. def workflow(r,setup,stage,fb=None,db=None):
  207. setCenters=False
  208. setIVF=False
  209. plotIVF=False
  210. setC=True
  211. try:
  212. qLambda=setup['qLambda']
  213. except KeyError:
  214. qLambda=0
  215. if stage=='setCenters':
  216. names=listRequiredFiles(stage,r,setup)
  217. if fb:
  218. getRequiredFiles(stage,r,setup,fb,names=names)
  219. if not checkRequiredFiles(stage,r,setup,names=names,fb=fb,doPrint=True):
  220. return
  221. loadData.saveCenters(r,setup)
  222. if stage=='fitIVF':
  223. #get required files
  224. #stage='fitIVF'
  225. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True):
  226. return
  227. loadData.saveIVF(r,setup,nfit=30,qLambda=qLambda)
  228. if stage=='plotIVF':
  229. ir=0
  230. names=listRequiredFiles(stage,r,setup)
  231. if fb:
  232. getRequiredFiles(stage,r,setup,fb,names=names)
  233. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  234. return
  235. print('Loading files to memory')
  236. m,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  237. data=loadData.loadData(r,setup)
  238. ct=loadData.loadCT(r,setup)
  239. centerMapSPECT,centerMapCT=loadData.loadCenterMapNRRD(r,setup,ir=ir)
  240. t,dt=loadData.loadTime(r,setup)
  241. centers=loadData.loadCenters(r,setup,ir=ir)
  242. ivf=centers[m]
  243. chi2=samples[0,:]
  244. threshold=numpy.median(chi2)
  245. code=config.getCode(r,setup)
  246. ir=0
  247. categories=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  248. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLambda) for x in categories}
  249. files={x:getData.getLocalPath(r,setup,fNames[x]) for x in fNames}
  250. plotData.plotIVF(t,ivf,samples,threshold,file0=files['fits'],file1=files['diff'])
  251. plotData.plotIVFRealizations(t,ivf,samples,threshold,file=files['generatedIVF'])
  252. #temporarily blocking center generation
  253. plotData.plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=files['centerIVFSPECT'],
  254. file1=files['centerIVFCT'])
  255. updateDatabase(r,setup,stage,db=db,fb=fb,categories=categories)
  256. if stage=='fitCompartment':
  257. ir=0
  258. names=listRequiredFiles(stage,r,setup)
  259. if fb:
  260. getRequiredFiles(stage,r,setup,fb,names=names)
  261. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  262. return
  263. #load class classification
  264. u=loadData.loadCenterMap(r,setup)
  265. print(u.shape)
  266. #load segmentation
  267. seg=segmentation.getNRRDImage(r,setup,names)
  268. loc=numpy.nonzero(seg)
  269. vClass=[int(x) for x in u[loc]]
  270. segments=[int(x) for x in seg[loc]]
  271. print(segments)
  272. data=loadData.loadData(r,setup)
  273. tac=loadData.getTACAtPixels(data,loc)
  274. segMap=makeMap(segments,vClass,tac)
  275. #for x in segMap:
  276. # print('{} {}'.format(x,segMap[x]))
  277. #return
  278. m1,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  279. chi2=samples[0,:]
  280. threshold=numpy.median(chi2)
  281. ivfFit=fitData.getFit(samples,threshold)
  282. t,dt=loadData.loadTime(r,setup)
  283. centers=loadData.loadCenters(r,setup)
  284. #save segmentation pixels
  285. setup['nseg']=len(segMap.keys())
  286. for x in segMap:
  287. mArray=segMap[x]
  288. qCenter=numpy.zeros(t.shape[0])
  289. qData=numpy.zeros(t.shape[0])
  290. s=0
  291. #average over contributions for each segmentation included in map
  292. kCenters=[]
  293. for m in mArray:
  294. #m is a tuple of classId and tac
  295. kCenters.append(m[0])
  296. qCenter+=centers[m[0]]
  297. qData+=m[1]
  298. s+=1
  299. qCenter/=s
  300. qData/=s
  301. samplesC=fitData.fitCompartmentGlobal(ivfFit,t,qCenter,useJac=True,nfit=20)
  302. samplesC1=fitData.fitCompartmentGlobal(ivfFit,t,qData,nfit=20,useJac=True)
  303. loadData.saveSamples(r,setup,samplesC,kCenters,'kmeansFit',iseg=x,ir=ir,qLambda=qLambda)
  304. loadData.saveSamples(r,setup,samplesC1,[-1],'localFit',iseg=x,ir=ir,qLambda=qLambda)
  305. loadData.saveTAC(r,setup,qCenter,'kmeansTAC',iseg=x,ir=ir,qLambda=qLambda)
  306. loadData.saveTAC(r,setup,qData,'localTAC',iseg=x,ir=ir,qLambda=qLambda)
  307. if stage=='plotCompartment':
  308. ir=0
  309. names=listRequiredFiles(stage,r,setup)
  310. if fb:
  311. getRequiredFiles(stage,r,setup,fb,names=names)
  312. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names):
  313. return
  314. tag='plotCompartment'
  315. seg=segmentation.getNRRDImage(r,setup,names)
  316. loc=numpy.nonzero(seg)
  317. segmentIds=list(set([int(x) for x in seg[loc]]))
  318. nclass=setup['nclass'][0]
  319. code=config.getCode(r,setup)
  320. setup['nseg']=len(segmentIds)
  321. t,dt=loadData.loadTime(r,setup)
  322. for iseg in segmentIds:
  323. m,samplesC=loadData.readSamples(r,setup,'kmeansFit',ir=ir,iseg=iseg,qLambda=qLambda)
  324. m1,samplesC1=loadData.readSamples(r,setup,'localFit',ir=ir,iseg=iseg,qLambda=qLambda)
  325. qCenter=loadData.readTAC(r,setup,'kmeansTAC',ir=ir,iseg=iseg,qLambda=qLambda)
  326. qData=loadData.readTAC(r,setup,'localTAC',ir=ir,iseg=iseg,qLambda=qLambda)
  327. chi2C=samplesC[0,:]
  328. threshold=numpy.median(chi2C)
  329. chi2C1=samplesC1[0,:]
  330. threshold1=numpy.median(chi2C1)
  331. fit=fitData.getFit(samplesC,threshold)
  332. fit1=fitData.getFit(samplesC1,threshold1)
  333. k1=fit.mu[0]
  334. stdK1=fit.cov[0,0]
  335. k11=fit1.mu[0]
  336. stdK11=fit1.cov[0,0]
  337. #update database with entries
  338. row={x:r[x] for x in ['PatientId','visitCode']}
  339. row['Date']=datetime.datetime.now().isoformat()
  340. row['nclass']=nclass
  341. row['option']='kmeansFit'
  342. row['mean']=k1
  343. row['std']=stdK1
  344. row['regionId']=iseg
  345. row['fitPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='realizations',iseg=iseg,qLambda=qLambda)
  346. row['diffPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='diff',iseg=iseg,qLambda=qLambda)
  347. row1={x:row[x] for x in row}
  348. row1['option']='localFit'
  349. row1['mean']=k11
  350. row1['std']=stdK11
  351. row['qLambda']=qLambda
  352. if db:
  353. db.modifyRows('insert',setup['project'],'lists','Summary',[row,row1])
  354. evalArray=[(samplesC,qCenter,'blue'),
  355. (samplesC1,qData,'orange')]
  356. file0=getData.getLocalPath(r,setup,row['fitPlot'])
  357. file1=getData.getLocalPath(r,setup,row['diffPlot'])
  358. plotData.plotSamples(t,evalArray,file0=file0,file1=file1)
  359. uploadCreatedFiles(stage,fb,r,setup)
  360. def main(setupFile):
  361. with open(setupFile,'r') as f:
  362. setup=json.load(f)
  363. db,fb,rows=getRows(setup,returnFB=True)
  364. for r in rows:
  365. workflow(r,setup,setup['stage'],fb=fb,db=db)
  366. if __name__=="__main__":
  367. main(sys.argv[1])