workflow.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. import config
  2. import getData
  3. import loadData
  4. import fitData
  5. import numpy
  6. import segmentation
  7. import plotData
  8. import os
  9. import datetime
  10. import sys
  11. import json
  12. def getRows(setup,returnFB=False):
  13. qFilter=config.getFilter(setup)
  14. db,fb=getData.connectDB(setup['network'])
  15. rows=getData.getPatients(db,setup,qFilter)
  16. #print(rows)
  17. if returnFB:
  18. return fb,db,rows
  19. return rows
  20. def getRobust(setup,var,default=0):
  21. #get from dict, return default if not set
  22. try:
  23. return setup[var]
  24. except KeyError:
  25. return default
  26. def listRequiredFiles(stage,r,setup,db=None):
  27. code=config.getCode(r,setup)
  28. nclass=setup['nclass'][0]
  29. nr=setup['nr']
  30. nt=20
  31. qLambda=getRobust(setup,'qLambda')
  32. qLambdaC=getRobust(setup,'qLambdaC')
  33. if stage=='setCenters':
  34. names={x:[config.getPattern(x,code)] for x in ['CT','Dummy']}
  35. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  36. return names
  37. if stage=='fitIVF':
  38. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  39. names['center']=[]
  40. for ir in range(nr):
  41. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  42. names['center'].extend(rel)
  43. return names
  44. if stage=='plotIVF':
  45. names={x:[config.getPattern(x,code)] for x in ['Dummy']}
  46. names['center']=[]
  47. names['fitIVF']=[]
  48. for ir in range(nr):
  49. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  50. names['center'].extend(rel)
  51. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  52. #names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  53. rel=[config.getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  54. names['center'].extend(rel)
  55. names.update({x:[config.getPattern(x,code)] for x in ['CT']})
  56. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  57. return names
  58. if stage=='fitCompartment':
  59. names={}
  60. names['center']=[]
  61. names['fitIVF']=[]
  62. for ir in range(nr):
  63. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  64. names['center'].extend(rel)
  65. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  66. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  67. names['segmentation']=[segmentation.getSegmentationFileName(r,setup,db=db)]
  68. return names
  69. if stage=='plotCompartment':
  70. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  71. names={}
  72. names['center']=[]
  73. names['fitIVF']=[]
  74. names['fitCompartment']=[]
  75. nseg=setup['nseg']
  76. for ir in range(nr):
  77. rel=[config.getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  78. names['center'].extend(rel)
  79. names['fitIVF'].append(config.getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  80. names['center'].append(config.getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  81. for iseg in range(nseg):
  82. xc='fitCompartment'
  83. rel=[config.getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC) for qn in sNames]
  84. names['fitCompartment'].extend(rel)
  85. names['segmentation']=[segmentation.getSegmentationFileName(r,setup,db=db)]
  86. names.update({x:[config.getPattern(x,code)] for x in ['CT','Dummy']})
  87. names['SPECT']=[config.getPattern('SPECT',code=code,timepoint=i) for i in range(nt)]
  88. return names
  89. return {}
  90. def listCreatedFiles(stage,r,setup):
  91. code=config.getCode(r,setup)
  92. nclass=setup['nclass'][0]
  93. qLambda=getRobust(setup,'qLambda')
  94. qLambdaC=getRobust(setup,'qLambdaC')
  95. nr=setup['nr']
  96. nseg=getRobust(setup,'nseg')
  97. print(f'List created files nseg={nseg}')
  98. names={}
  99. getPattern=config.getPattern
  100. if stage=='setCenters':
  101. names['center']=[]
  102. for ir in range(nr):
  103. rel=[getPattern('center',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  104. names['center'].extend(rel)
  105. #rel=[config.getPattern('centerWeight',code=code,nclass=nclass,ir=ir,ic=i) for i in range(nclass)]
  106. #names['center'].extend(rel)
  107. names['center'].append(getPattern('centerMap',code=code,nclass=nclass,ir=ir))
  108. rel=[getPattern('centerNRRD',code=code,nclass=nclass,ir=ir,qaName=x) for x in ['CT','SPECT']]
  109. names['center'].extend(rel)
  110. return names
  111. if stage=='fitIVF':
  112. names['fitIVF']=[]
  113. for ir in range(nr):
  114. names['fitIVF'].append(getPattern('fitIVF',code=code,nclass=nclass,ir=ir,qLambda=qLambda))
  115. return names
  116. if stage=='plotIVF':
  117. names['plotIVF']=[]
  118. sNames=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  119. for ir in range(nr):
  120. x=[getPattern('plotIVF',code=code,nclass=nclass,ir=ir,qaName=y,qLambda=qLambda) for y in sNames]
  121. names['plotIVF'].extend(x)
  122. return names
  123. if stage=='fitCompartment':
  124. xc='fitCompartment'
  125. names[xc]=[]
  126. sNames=['kmeansFit','localFit','kmeansTAC','localTAC']
  127. for ir in range(nr):
  128. for iseg in range(nseg):
  129. _q=qLambda
  130. _qC=qLambdaC
  131. rel=[getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=_q,qLambdaC=_qC) for qn in sNames]
  132. names[xc].extend(rel)
  133. return names
  134. if stage=='plotCompartment':
  135. xc=stage
  136. names[xc]=[]
  137. sNames=['realizations','diff']
  138. for ir in range(nr):
  139. for s in range(nseg):
  140. iseg=s+1
  141. _q=qLambda
  142. _qC=qLambdaC
  143. rel=[getPattern(xc,code=code,nclass=nclass,ir=ir,qaName=qn,iseg=iseg,qLambda=_q,qLambdaC=_qC) for qn in sNames]
  144. names[xc].extend(rel)
  145. return names
  146. return []
  147. def getRequiredFiles(stage,r,setup,fb,names=None,db=None):
  148. #fb,r=getRow(setup,True)
  149. if not names:
  150. names=listRequiredFiles(stage,r,setup,db=db)
  151. for f in names:
  152. _copyFromServer=getData.copyFromServer
  153. if f=='segmentation':
  154. _copyFromServer=segmentation.copyFromServer
  155. _copyFromServer(fb,r,setup,names[f])
  156. return fb,r
  157. def checkRequiredFiles(stage,r,setup,names=None,fb=None,doPrint=False,db=None):
  158. ok=True
  159. if not names:
  160. names=listRequiredFiles(stage,r,setup,db=db)
  161. for f in names:
  162. nm=names[f]
  163. for x in nm:
  164. avail=os.path.isfile(getData.getLocalPath(r,setup,x))
  165. if not avail:
  166. print(f'Missing {x}')
  167. if fb:
  168. _getURL=getData.getURL
  169. if f=='segmentation':
  170. _getURL=segmentation.getURL
  171. availRemote=fb.entryExists(_getURL(fb,r,setup,x))
  172. print(f'Available remote: {availRemote}')
  173. ok=False
  174. if doPrint:
  175. print(f'[{avail}] {x}')
  176. return ok
  177. def uploadCreatedFiles(stage,fb,r,setup,names=None):
  178. if not names:
  179. names=listCreatedFiles(stage,r,setup)
  180. for f in names:
  181. _copyToServer=getData.copyToServer
  182. _getURL=getData.getURL
  183. if f=='segmentation':
  184. _copyToServer=segmentation.copyToServer
  185. _getURL=segmentation.getURL
  186. _copyToServer(fb,r,setup,names[f])
  187. for x in names[f]:
  188. print('[{}] Uploaded {}'.format(fb.entryExists(_getURL(fb,r,setup,x)),x))
# NOTE: makeMap is a poor fit for this workflow module, but no better logical unit was found, so it lives here.
  190. def makeMap(segs,kClass,tac):
  191. map={}
  192. vals=[(kClass[i],tac[:,i]) for i in range(len(kClass))]
  193. for (i,v) in zip(segs,vals):
  194. try:
  195. map[i].append(v)
  196. except KeyError:
  197. map[i]=[v]
  198. return map
# getDataAtPixels(data, loc) was removed; use loadData.getTACAtPixels(data, loc) instead.
  200. def updateDatabase(r,setup,stage,fb=None,db=None,categories=[]):
  201. #set database entry
  202. try:
  203. qLam=setup['qLambda']
  204. except KeyError:
  205. qLam=0
  206. nclass=setup['nclass'][0]
  207. code=config.getCode(r,setup)
  208. if stage=='plotIVF':
  209. m,samples=loadData.readIVF(r,setup,qLambda=qLam)
  210. chi2=samples[0,:]
  211. threshold=numpy.median(chi2)
  212. fit=fitData.getFit(samples,threshold)
  213. row={x:r[x] for x in ['PatientId','visitCode']}
  214. row['nclass']=nclass
  215. row['mean']=fit.mu[0]
  216. row['std']=fit.cov[0,0]
  217. row['qLambda']=qLam
  218. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLam) for x in categories}
  219. row.update(fNames)
  220. if db:
  221. db.modifyRows('insert',setup['project'],'lists','SummaryIVF',[row])
  222. def workflow(r,setup,stage,fb=None,db=None):
  223. setCenters=False
  224. setIVF=False
  225. plotIVF=False
  226. setC=True
  227. qLambda=getRobust(setup,'qLambda')
  228. qLambdaC=getRobust(setup,'qLambdaC')
  229. print('Running {} for {}'.format(stage,config.getCode(r,setup)))
  230. if stage=='setCenters':
  231. names=listRequiredFiles(stage,r,setup,db=db)
  232. if fb:
  233. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  234. if not checkRequiredFiles(stage,r,setup,names=names,fb=fb,doPrint=True,db=db):
  235. return
  236. loadData.saveCenters(r,setup)
  237. if stage=='fitIVF':
  238. #get required files
  239. #stage='fitIVF'
  240. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,db=db):
  241. return
  242. loadData.saveIVF(r,setup,nfit=30,qLambda=qLambda)
  243. if stage=='plotIVF':
  244. ir=0
  245. names=listRequiredFiles(stage,r,setup,db=db)
  246. if fb:
  247. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  248. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  249. return
  250. print('Loading files to memory')
  251. m,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  252. data=loadData.loadData(r,setup)
  253. ct=loadData.loadCT(r,setup)
  254. centerMapSPECT,centerMapCT=loadData.loadCenterMapNRRD(r,setup,ir=ir)
  255. t,dt=loadData.loadTime(r,setup)
  256. centers=loadData.loadCenters(r,setup,ir=ir)
  257. ivf=centers[m]
  258. chi2=samples[0,:]
  259. threshold=numpy.median(chi2)
  260. code=config.getCode(r,setup)
  261. ir=0
  262. categories=['fits','diff','generatedIVF','centerIVFSPECT','centerIVFCT']
  263. fNames={x:config.getPattern('plotIVF',code=code,ir=0,nclass=nclass,qaName=x,qLambda=qLambda) for x in categories}
  264. files={x:getData.getLocalPath(r,setup,fNames[x]) for x in fNames}
  265. plotData.plotIVF(t,ivf,samples,threshold,file0=files['fits'],file1=files['diff'])
  266. plotData.plotIVFRealizations(t,ivf,samples,threshold,file=files['generatedIVF'])
  267. #temporarily blocking center generation
  268. plotData.plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=files['centerIVFSPECT'],
  269. file1=files['centerIVFCT'])
  270. updateDatabase(r,setup,stage,db=db,fb=fb,categories=categories)
  271. if stage=='fitCompartment':
  272. ir=0
  273. names=listRequiredFiles(stage,r,setup,db=db)
  274. if fb:
  275. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  276. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  277. return
  278. #load class classification
  279. u=loadData.loadCenterMap(r,setup)
  280. print(u.shape)
  281. #load segmentation
  282. seg=segmentation.getNRRDImage(r,setup,names)
  283. loc=numpy.nonzero(seg)
  284. vClass=[int(x) for x in u[loc]]
  285. segments=[int(x) for x in seg[loc]]
  286. print(segments)
  287. data=loadData.loadData(r,setup)
  288. tac=loadData.getTACAtPixels(data,loc)
  289. segMap=makeMap(segments,vClass,tac)
  290. #for x in segMap:
  291. # print('{} {}'.format(x,segMap[x]))
  292. #return
  293. m1,samples=loadData.readIVF(r,setup,qLambda=qLambda)
  294. chi2=samples[0,:]
  295. threshold=numpy.median(chi2)
  296. ivfFit=fitData.getFit(samples,threshold)
  297. t,dt=loadData.loadTime(r,setup)
  298. centers=loadData.loadCenters(r,setup)
  299. #save segmentation pixels
  300. setup['nseg']=len(segMap.keys())
  301. for x in segMap:
  302. mArray=segMap[x]
  303. qCenter=numpy.zeros(t.shape[0])
  304. qData=numpy.zeros(t.shape[0])
  305. s=0
  306. #average over contributions for each segmentation included in map
  307. kCenters=[]
  308. for m in mArray:
  309. #m is a tuple of classId and tac
  310. kCenters.append(m[0])
  311. qCenter+=centers[m[0]]
  312. qData+=m[1]
  313. s+=1
  314. qCenter/=s
  315. qData/=s
  316. samplesC=fitData.fitCompartmentGlobal(ivfFit,t,qCenter,useJac=True,nfit=20,qLambda=qLambdaC)
  317. samplesC1=fitData.fitCompartmentGlobal(ivfFit,t,qData,nfit=20,useJac=True,qLambda=qLambdaC)
  318. loadData.saveSamples(r,setup,samplesC,kCenters,'kmeansFit',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  319. loadData.saveSamples(r,setup,samplesC1,[-1],'localFit',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  320. loadData.saveTAC(r,setup,qCenter,'kmeansTAC',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  321. loadData.saveTAC(r,setup,qData,'localTAC',iseg=x,ir=ir,qLambda=qLambda,qLambdaC=qLambdaC)
  322. if stage=='plotCompartment':
  323. ir=0
  324. names=listRequiredFiles(stage,r,setup,db=db)
  325. if fb:
  326. getRequiredFiles(stage,r,setup,fb,names=names,db=db)
  327. if not checkRequiredFiles(stage,r,setup,fb=fb,doPrint=True,names=names,db=db):
  328. return
  329. tag='plotCompartment'
  330. seg=segmentation.getNRRDImage(r,setup,names)
  331. loc=numpy.nonzero(seg)
  332. segmentIds=list(set([int(x) for x in seg[loc]]))
  333. nclass=setup['nclass'][0]
  334. code=config.getCode(r,setup)
  335. setup['nseg']=len(segmentIds)
  336. print('Setting setup[nseg]={}'.format(setup['nseg']))
  337. t,dt=loadData.loadTime(r,setup)
  338. for iseg in segmentIds:
  339. m,samplesC=loadData.readSamples(r,setup,'kmeansFit',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  340. m1,samplesC1=loadData.readSamples(r,setup,'localFit',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  341. qCenter=loadData.readTAC(r,setup,'kmeansTAC',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  342. qData=loadData.readTAC(r,setup,'localTAC',ir=ir,iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  343. chi2C=samplesC[0,:]
  344. threshold=numpy.median(chi2C)
  345. chi2C1=samplesC1[0,:]
  346. threshold1=numpy.median(chi2C1)
  347. fit=fitData.getFit(samplesC,threshold)
  348. fit1=fitData.getFit(samplesC1,threshold1)
  349. k1=fit.mu[0]
  350. stdK1=fit.cov[0,0]
  351. k11=fit1.mu[0]
  352. stdK11=fit1.cov[0,0]
  353. #update database with entries
  354. row={x:r[x] for x in ['PatientId','visitCode']}
  355. row['Date']=datetime.datetime.now().isoformat()
  356. row['nclass']=nclass
  357. row['option']='kmeansFit'
  358. row['mean']=k1
  359. row['std']=stdK1
  360. row['regionId']=iseg
  361. row['qLambda']=qLambda
  362. row['qLambdaC']=qLambdaC
  363. row['fitPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='realizations',\
  364. iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  365. row['diffPlot']=config.getPattern(tag,code=code,ir=0,nclass=nclass,qaName='diff',\
  366. iseg=iseg,qLambda=qLambda,qLambdaC=qLambdaC)
  367. row1={x:row[x] for x in row}
  368. row1['option']='localFit'
  369. row1['mean']=k11
  370. row1['std']=stdK11
  371. if db:
  372. db.modifyRows('insert',setup['project'],'lists','Summary',[row,row1])
  373. evalArray=[(samplesC,qCenter,'blue'),
  374. (samplesC1,qData,'orange')]
  375. file0=getData.getLocalPath(r,setup,row['fitPlot'])
  376. file1=getData.getLocalPath(r,setup,row['diffPlot'])
  377. plotData.plotSamples(t,evalArray,file0=file0,file1=file1)
  378. nseg=getRobust(setup,'nseg')
  379. print(f'At upload nseg={nseg}')
  380. uploadCreatedFiles(stage,fb,r,setup)
  381. #setup could be modified
  382. return setup
  383. def main(setupFile):
  384. with open(setupFile,'r') as f:
  385. setup=json.load(f)
  386. db,fb,rows=getRows(setup,returnFB=True)
  387. for r in rows:
  388. workflow(r,setup,setup['stage'],fb=fb,db=db)
  389. if __name__=="__main__":
  390. main(sys.argv[1])