workflow.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
import datetime
import json
import os
import sys

import numpy

import config
import fitData
import getData
import loadData
import plotData
import segmentation
  10. def getRows(setup,returnFB=False):
  11. qFilter=config.getFilter(setup)
  12. db,fb=getData.connectDB(setup['network'])
  13. rows=getData.getPatients(db,setup,qFilter)
  14. #print(rows)
  15. if returnFB:
  16. return fb,db,rows
  17. return rows
  18. def getRobust(setup,var,default=0):
  19. #get from dict, return default if not set
  20. try:
  21. return setup[var]
  22. except KeyError:
  23. return default
  24. def listRequiredFiles(stage,r,setup,db=None):
  25. code=config.getCode(r,setup)
  26. nclass=setup['nclass'][0]
  27. nr=setup['nr']
  28. nt=20
  29. qLambda=getRobust(setup,'qLambda')
  30. qLambdaC=getRobust(setup,'qLambdaC')
  31. if stage=='setCenters':
  32. names={x:[config.getPattern(x,code)] for x in ['CT','Dummy']}
  33. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  34. return names
  35. if stage=='fitIVF':
  36. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  37. names['center']=[]
  38. for ir in range(nr):
  39. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  40. names['center'].extend(rel)
  41. return names
  42. if stage=='plotIVF':
  43. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  44. names['center']=[]
  45. names['fitIVF']=[]
  46. for ir in range(nr):
  47. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  48. names['center'].extend(rel)
  49. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  50. #names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  51. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  52. names['center'].extend(rel)
  53. names.update({x:[config.getPattern(x,code)] for x in ['CT']})
  54. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  55. return names
  56. if stage=='fitCompartment':
  57. names={}
  58. names['center']=[]
  59. names['fitIVF']=[]
  60. for ir in range(nr):
  61. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  62. names['center'].extend(rel)
  63. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  64. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  65. names['segmentation']=[segmentation.getSegmentationFileName(r,setup,db=db)]
  66. return names
  67. if stage=='plotCompartment':
  68. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  69. names={}
  70. names['center']=[]
  71. names['fitIVF']=[]
  72. names['fitCompartment']=[]
  73. nseg=setup['nseg']
  74. for ir in range(nr):
  75. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  76. names['center'].extend(rel)
  77. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  78. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  79. for iseg in range(nseg):
  80. xc='fitCompartment'
  81. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC) for qn in sNames]
  82. names['fitCompartment'].extend(rel)
  83. names['segmentation']=[segmentation.getSegmentationFileName(r,setup,db=db)]
  84. names.update({x:[config.getPattern(x,code)] for x in ['CT','Dummy']})
  85. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  86. return names
  87. return {}
  88. def listCreatedFiles(stage,r,setup):
  89. code=config.getCode(r,setup)
  90. nclass=setup['nclass'][0]
  91. qLambda=getRobust(setup,'qLambda')
  92. qLambdaC=getRobust(setup,'qLambdaC')
  93. nr=setup['nr']
  94. nseg=getRobust(setup,'nseg')
  95. print(f'List created files nseg={nseg}')
  96. names={}
  97. getPattern=config.getPattern
  98. if stage=='setCenters':
  99. names['center']=[]
  100. for ir in range(nr):
  101. rel=[getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  102. names['center'].extend(rel)
  103. #rel=[config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  104. #names['center'].extend(rel)
  105. names['center'].append(getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  106. rel=[getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  107. names['center'].extend(rel)
  108. return names
  109. if stage=='fitIVF':
  110. names['fitIVF']=[]
  111. for ir in range(nr):
  112. names['fitIVF'].append(getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  113. return names
  114. if stage=='plotIVF':
  115. names['plotIVF']=[]
  116. sNames=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  117. for ir in range(nr):
  118. x=[getPattern('plotIVF',code=code,nclass=nclass,ir=ir,qaName=y,qLambda=qLambda) for y in sNames]
  119. names['plotIVF'].extend(x)
  120. return names
  121. if stage=='fitCompartment':
  122. xc='fitCompartment'
  123. names[xc]=[]
  124. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  125. for ir in range(nr):
  126. for iseg in range(nseg):
  127. _q=qLambda
  128. _qC=qLambdaC
  129. rel=[getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=_q,qLambdaC=_qC) for qn in sNames]
  130. names[xc].extend(rel)
  131. return names
  132. if stage=='plotCompartment':
  133. xc=stage
  134. names[xc]=[]
  135. sNames=['realizations','diff']
  136. for ir in range(nr):
  137. for s in range(nseg):
  138. iseg=s+1
  139. _q=qLambda
  140. _qC=qLambdaC
  141. rel=[getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=_q,qLambdaC=_qC) for qn in sNames]
  142. names[xc].extend(rel)
  143. return names
  144. return []
  145. def getRequiredFiles(stage,r,setup,fb,names=None,db=None):
  146. #fb,r=getRow(setup,True)
  147. if not names:
  148. names=listRequiredFiles(stage,r,setup,db=db)
  149. for f in names:
  150. _copyFromServer=getData.copyFromServer
  151. if f=='segmentation':
  152. _copyFromServer=segmentation.copyFromServer
  153. _copyFromServer(fb,r,setup,names[f])
  154. return fb,r
  155. def checkRequiredFiles(stage,r,setup,names=None,fb=None,doPrint=False,db=None):
  156. ok=True
  157. if not names:
  158. names=listRequiredFiles(stage,r,setup,db=db)
  159. for f in names:
  160. nm=names[f]
  161. for x in nm:
  162. avail=os.path.isfile(getData.getLocalPath(r,setup,x))
  163. if not avail:
  164. print(f'Missing {x}')
  165. if fb:
  166. _getURL=getData.getURL
  167. if f=='segmentation':
  168. _getURL=segmentation.getURL
  169. availRemote=fb.entryExists(_getURL(fb,r,setup,x))
  170. print(f'Available remote: {availRemote}')
  171. ok=False
  172. if doPrint:
  173. print(f'[{avail}] {x}')
  174. return ok
  175. def uploadCreatedFiles(stage,fb,r,setup,names=None):
  176. if not names:
  177. names=listCreatedFiles(stage,r,setup)
  178. for f in names:
  179. _copyToServer=getData.copyToServer
  180. _getURL=getData.getURL
  181. if f=='segmentation':
  182. _copyToServer=segmentation.copyToServer
  183. _getURL=segmentation.getURL
  184. _copyToServer(fb,r,setup,names[f])
  185. for x in names[f]:
  186. print('[{}] Uploaded {}'.format(fb.entryExists(_getURL(fb,r,setup,x)),x))
# NOTE: makeMap is a poor fit for the workflow module, but no better logical home was found, so it lives here.
  188. def makeMap(segs,kClass,tac):
  189. map={}
  190. vals=[(kClass[i],tac[:,i]) for i in range(len(kClass))]
  191. for (i,v) in zip(segs,vals):
  192. try:
  193. map[i].append(v)
  194. except KeyError:
  195. map[i]=[v]
  196. return map
  197. #def getDataAtPixels(data,loc) replaced by loadData.getTACAtPixels(data,loc)
  198. def updateDatabase(r,setup,stage,fb=None,db=None,categories=[]):
  199. #set database entry
  200. try:
  201. qLam=setup['qLambda']
  202. except KeyError:
  203. qLam=0
  204. nclass=setup['nclass'][0]
  205. code=config.getCode(r,setup)
  206. if stage=='plotIVF':
  207. m,samples=loadData.readIVF(r,setup,qLambda=qLam)
  208. chi2=samples[0,:]
  209. threshold=numpy.median(chi2)
  210. fit=fitData.getFit(samples,threshold)
  211. row={x:r[x] for x in ['PatientId','visitCode']}
  212. row['nclass']=nclass
  213. row['mean']=fit.mu[0]
  214. row['std']=fit.cov[0,0]
  215. row['qLambda']=qLam
  216. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLam) for x in categories}
  217. row.update(fNames)
  218. if db:
  219. db.modifyRows('insert',setup['project'],'lists','SummaryIVF',[row])
  220. def workflow(r,setup,stage,fb=None,db=None):
  221. setCenters=False
  222. setIVF=False
  223. plotIVF=False
  224. setC=True
  225. qLambda=getRobust(setup,'qLambda')
  226. qLambdaC=getRobust(setup,'qLambdaC')
  227. if stage=='setCenters':
  228. names=listRequiredFiles(stage,r,setup,db=db)
  229. if fb:
  230. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  231. if not checkRequiredFiles(stage,r,setup,names=names,fb=fb,doPrint=True,db=db):
  232. return
  233. loadData.saveCenters(r,setup)
  234. if stage=='fitIVF':
  235. #get required files
  236. #stage='fitIVF'
  237. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,db=db):
  238. return
  239. loadData.saveIVF(r,setup,nfit=30,qLambda=qLambda)
  240. if stage=='plotIVF':
  241. ir=0
  242. names=listRequiredFiles(stage,r,setup,db=db)
  243. if fb:
  244. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  245. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  246. return
  247. print('Loading files to memory')
  248. m,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  249. data=loadData.loadData(r,setup)
  250. ct=loadData.loadCT(r,setup)
  251. centerMapSPECT,centerMapCT=loadData.loadCenterMapNRRD(r,setup,ir=ir)
  252. t,dt=loadData.loadTime(r,setup)
  253. centers=loadData.loadCenters(r,setup,ir=ir)
  254. ivf=centers[m]
  255. chi2=samples[0,:]
  256. threshold=numpy.median(chi2)
  257. code=config.getCode(r,setup)
  258. ir=0
  259. categories=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  260. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLambda) for x in categories}
  261. files={x:getData.getLocalPath(r,setup,fNames[x]) for x in fNames}
  262. plotData.plotIVF(t,ivf,samples,threshold,file0=files['fits'],file1=files['diff'])
  263. plotData.plotIVFRealizations(t,ivf,samples,threshold,file=files['generatedIVF'])
  264. #temporarily blocking center generation
  265. plotData.plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=files['centerIVFSPECT'],
  266. file1=files['centerIVFCT'])
  267. updateDatabase(r,setup,stage,db=db,fb=fb,categories=categories)
  268. if stage=='fitCompartment':
  269. ir=0
  270. names=listRequiredFiles(stage,r,setup,db=db)
  271. if fb:
  272. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  273. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  274. return
  275. #load class classification
  276. u=loadData.loadCenterMap(r,setup)
  277. print(u.shape)
  278. #load segmentation
  279. seg=segmentation.getNRRDImage(r,setup,names)
  280. loc=numpy.nonzero(seg)
  281. vClass=[int(x) for x in u[loc]]
  282. segments=[int(x) for x in seg[loc]]
  283. print(segments)
  284. data=loadData.loadData(r,setup)
  285. tac=loadData.getTACAtPixels(data,loc)
  286. segMap=makeMap(segments,vClass,tac)
  287. #for x in segMap:
  288. # print('{} {}'.format(x,segMap[x]))
  289. #return
  290. m1,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  291. chi2=samples[0,:]
  292. threshold=numpy.median(chi2)
  293. ivfFit=fitData.getFit(samples,threshold)
  294. t,dt=loadData.loadTime(r,setup)
  295. centers=loadData.loadCenters(r,setup)
  296. #save segmentation pixels
  297. setup['nseg']=len(segMap.keys())
  298. for x in segMap:
  299. mArray=segMap[x]
  300. qCenter=numpy.zeros(t.shape[0])
  301. qData=numpy.zeros(t.shape[0])
  302. s=0
  303. #average over contributions for each segmentation included in map
  304. kCenters=[]
  305. for m in mArray:
  306. #m is a tuple of classId and tac
  307. kCenters.append(m[0])
  308. qCenter+=centers[m[0]]
  309. qData+=m[1]
  310. s+=1
  311. qCenter/=s
  312. qData/=s
  313. samplesC=fitData.fitCompartmentGlobal(ivfFit,t,qCenter,useJac=True,nfit=20,qLambda=qLambdaC)
  314. samplesC1=fitData.fitCompartmentGlobal(ivfFit,t,qData,nfit=20,useJac=True,qLambda=qLambdaC)
  315. loadData.saveSamples(r,setup,samplesC,kCenters,'kmeansFit',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  316. loadData.saveSamples(r,setup,samplesC1,[-1],'localFit',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  317. loadData.saveTAC(r,setup,qCenter,'kmeansTAC',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  318. loadData.saveTAC(r,setup,qData,'localTAC',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  319. if stage=='plotCompartment':
  320. ir=0
  321. names=listRequiredFiles(stage,r,setup,db=db)
  322. if fb:
  323. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  324. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  325. return
  326. tag='plotCompartment'
  327. seg=segmentation.getNRRDImage(r,setup,names)
  328. loc=numpy.nonzero(seg)
  329. segmentIds=list(set([int(x) for x in seg[loc]]))
  330. nclass=setup['nclass'][0]
  331. code=config.getCode(r,setup)
  332. setup['nseg']=len(segmentIds)
  333. print('Setting setup[nseg]={}'.format(setup['nseg']))
  334. t,dt=loadData.loadTime(r,setup)
  335. for iseg in segmentIds:
  336. m,samplesC=loadData.readSamples(r,setup,'kmeansFit',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  337. m1,samplesC1=loadData.readSamples(r,setup,'localFit',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  338. qCenter=loadData.readTAC(r,setup,'kmeansTAC',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  339. qData=loadData.readTAC(r,setup,'localTAC',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  340. chi2C=samplesC[0,:]
  341. threshold=numpy.median(chi2C)
  342. chi2C1=samplesC1[0,:]
  343. threshold1=numpy.median(chi2C1)
  344. fit=fitData.getFit(samplesC,threshold)
  345. fit1=fitData.getFit(samplesC1,threshold1)
  346. k1=fit.mu[0]
  347. stdK1=fit.cov[0,0]
  348. k11=fit1.mu[0]
  349. stdK11=fit1.cov[0,0]
  350. #update database with entries
  351. row={x:r[x] for x in ['PatientId','visitCode']}
  352. row['Date']=datetime.datetime.now().isoformat()
  353. row['nclass']=nclass
  354. row['option']='kmeansFit'
  355. row['mean']=k1
  356. row['std']=stdK1
  357. row['regionId']=iseg
  358. row['qLambda']=qLambda
  359. row['qLambdaC']=qLambdaC
  360. row['fitPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='realizations',\
  361. iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  362. row['diffPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='diff',\
  363. iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  364. row1={x:row[x] for x in row}
  365. row1['option']='localFit'
  366. row1['mean']=k11
  367. row1['std']=stdK11
  368. if db:
  369. db.modifyRows('insert',setup['project'],'lists','Summary',[row,row1])
  370. evalArray=[(samplesC,qCenter,'blue'),
  371. (samplesC1,qData,'orange')]
  372. file0=getData.getLocalPath(r,setup,row['fitPlot'])
  373. file1=getData.getLocalPath(r,setup,row['diffPlot'])
  374. plotData.plotSamples(t,evalArray,file0=file0,file1=file1)
  375. nseg=getRobust(setup,'nseg')
  376. print(f'At upload nseg={nseg}')
  377. uploadCreatedFiles(stage,fb,r,setup)
  378. #setup could be modified
  379. return setup
  380. def main(setupFile):
  381. with open(setupFile,'r') as f:
  382. setup=json.load(f)
  383. db,fb,rows=getRows(setup,returnFB=True)
  384. for r in rows:
  385. workflow(r,setup,setup['stage'],fb=fb,db=db)
  386. if __name__=="__main__":
  387. main(sys.argv[1])