workflow.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. import config
  2. import getData
  3. import loadData
  4. import fitData
  5. import numpy
  6. import segmentation
  7. import plotData
  8. import os
  9. import datetime
  def getRows(setup, returnFB=False):
      """Query the patient rows selected by the setup's filter.

      Connects through getData.connectDB(setup['network']) and fetches rows with
      getData.getPatients using config.getFilter(setup).

      Returns rows, or the tuple (fb, db, rows) when returnFB is true.
      """
      patientFilter = config.getFilter(setup)
      database, fileServer = getData.connectDB(setup['network'])
      patientRows = getData.getPatients(database, setup, patientFilter)
      return (fileServer, database, patientRows) if returnFB else patientRows
  def getRobust(setup, var, default=0):
      """Return setup[var], or `default` when the key is not present."""
      if var in setup:
          return setup[var]
      return default
  def listRequiredFiles(stage, r, setup, db=None):
      """Return the input files a workflow stage needs, grouped by category.

      Keys are categories ('CT', 'Dummy', 'SPECT', 'center', 'fitIVF',
      'fitCompartment', 'segmentation'). Values are lists of file names built
      with config.getPattern, except 'segmentation', which comes from
      segmentation.getSegmentationFileName. An unknown stage yields {}.

      The 'plotCompartment' stage reads setup['nseg'] and raises KeyError if
      it is absent.
      """
      code = config.getCode(r, setup)
      nclass = setup['nclass'][0]
      nr = setup['nr']
      nt = 20  # number of SPECT timepoints listed
      qLambda = getRobust(setup, 'qLambda')
      qLambdaC = getRobust(setup, 'qLambdaC')

      def onePerKind(kinds):
          # one file per kind, named from the patient code only
          return {kind: [config.getPattern(kind, code)] for kind in kinds}

      def spectFrames():
          return [config.getPattern('SPECT', code=code, timepoint=it) for it in range(nt)]

      def classCenters(ir):
          return [config.getPattern('center', code=code, nclass=nclass, ir=ir, ic=ic)
                  for ic in range(nclass)]

      def ivfFit(ir):
          return config.getPattern('fitIVF', code=code, nclass=nclass, ir=ir, qLambda=qLambda)

      def centerMap(ir):
          return config.getPattern('centerMap', code=code, nclass=nclass, ir=ir)

      def segmentationFile():
          return [segmentation.getSegmentationFileName(r, setup, db=db)]

      if stage == 'setCenters':
          required = onePerKind(['CT', 'Dummy'])
          required['SPECT'] = spectFrames()
          return required

      if stage == 'fitIVF':
          required = onePerKind(['Dummy'])
          required['center'] = [name for ir in range(nr) for name in classCenters(ir)]
          return required

      if stage == 'plotIVF':
          required = onePerKind(['Dummy'])
          centerFiles, ivfFiles = [], []
          for ir in range(nr):
              centerFiles += classCenters(ir)
              ivfFiles.append(ivfFit(ir))
              centerFiles += [config.getPattern('centerNRRD', code=code, nclass=nclass, ir=ir, qaName=qa)
                              for qa in ['CT', 'SPECT']]
          required['center'] = centerFiles
          required['fitIVF'] = ivfFiles
          required.update(onePerKind(['CT']))
          required['SPECT'] = spectFrames()
          return required

      if stage == 'fitCompartment':
          centerFiles, ivfFiles = [], []
          for ir in range(nr):
              centerFiles += classCenters(ir)
              ivfFiles.append(ivfFit(ir))
              centerFiles.append(centerMap(ir))
          return {'center': centerFiles, 'fitIVF': ivfFiles, 'segmentation': segmentationFile()}

      if stage == 'plotCompartment':
          fitKinds = ['kmeansFit', 'localFit', 'kmeansTAC', 'localTAC']
          nseg = setup['nseg']
          centerFiles, ivfFiles, compartmentFiles = [], [], []
          for ir in range(nr):
              centerFiles += classCenters(ir)
              ivfFiles.append(ivfFit(ir))
              centerFiles.append(centerMap(ir))
              for iseg in range(nseg):
                  compartmentFiles += [
                      config.getPattern('fitCompartment', code=code, nclass=nclass, ir=ir, qaName=kind,
                                        iseg=iseg, qLambda=qLambda, qLambdaC=qLambdaC)
                      for kind in fitKinds
                  ]
          required = {
              'center': centerFiles,
              'fitIVF': ivfFiles,
              'fitCompartment': compartmentFiles,
              'segmentation': segmentationFile(),
          }
          required.update(onePerKind(['CT', 'Dummy']))
          required['SPECT'] = spectFrames()
          return required

      return {}
  def listCreatedFiles(stage, r, setup):
      """Return the files written by workflow `stage`, grouped by category.

      Keys are categories ('center', 'fitIVF', 'plotIVF', 'fitCompartment',
      'plotCompartment'). Values are lists of file names built with
      config.getPattern. An unknown stage yields an empty dict.
      """
      code = config.getCode(r, setup)
      nclass = setup['nclass'][0]
      qLambda = getRobust(setup, 'qLambda')
      qLambdaC = getRobust(setup, 'qLambdaC')
      nr = setup['nr']
      nseg = getRobust(setup, 'nseg')
      names = {}
      if stage == 'setCenters':
          centerFiles = []
          for ir in range(nr):
              centerFiles.extend(config.getPattern('center', code=code, nclass=nclass, ir=ir, ic=i)
                                 for i in range(nclass))
              centerFiles.append(config.getPattern('centerMap', code=code, nclass=nclass, ir=ir))
              centerFiles.extend(config.getPattern('centerNRRD', code=code, nclass=nclass, ir=ir, qaName=x)
                                 for x in ['CT', 'SPECT'])
          names['center'] = centerFiles
      elif stage == 'fitIVF':
          names['fitIVF'] = [config.getPattern('fitIVF', code=code, nclass=nclass, ir=ir, qLambda=qLambda)
                             for ir in range(nr)]
      elif stage == 'plotIVF':
          names['plotIVF'] = [
              config.getPattern('plotIVF', code=code, nclass=nclass, ir=ir, qaName=y, qLambda=qLambda)
              for ir in range(nr)
              for y in ['fits', 'diff', 'generatedIVF', 'centerIVFSPECT', 'centerIVFCT']
          ]
      elif stage in ('fitCompartment', 'plotCompartment'):
          if stage == 'fitCompartment':
              qaNames = ['kmeansFit', 'localFit', 'kmeansTAC', 'localTAC']
          else:
              qaNames = ['realizations', 'diff']
          # NOTE(review): workflow saves compartment results under iseg equal to
          # the segmentation label taken from non-zero pixels, while this lists
          # iseg in range(nseg) starting at 0. Confirm the two index schemes agree.
          names[stage] = [
              config.getPattern(stage, code=code, nclass=nclass, ir=ir, qaName=qn, iseg=iseg,
                                qLambda=qLambda, qLambdaC=qLambdaC)
              for ir in range(nr)
              for iseg in range(nseg)
              for qn in qaNames
          ]
      # The compartment stages used to build their list and then fall through
      # to `return []`, so their files were never returned (nor uploaded).
      return names
  def getRequiredFiles(stage, r, setup, fb, names=None, db=None):
      """Copy every required input file of `stage` from the file server `fb`.

      names maps a category to a list of file names. When it is empty or
      omitted, it is computed with listRequiredFiles. The 'segmentation'
      category is copied with segmentation.copyFromServer; all others use
      getData.copyFromServer.

      Returns the tuple (fb, r).
      """
      if not names:
          names = listRequiredFiles(stage, r, setup, db=db)
      for category in names:
          if category == 'segmentation':
              copier = segmentation.copyFromServer
          else:
              copier = getData.copyFromServer
          copier(fb, r, setup, names[category])
      return fb, r
  def checkRequiredFiles(stage, r, setup, names=None, fb=None, doPrint=False, db=None):
      """Check that every required input file of `stage` exists locally.

      names maps a category to a list of file names. When it is empty or
      omitted, it is computed with listRequiredFiles. For each missing file a
      'Missing' line is printed. When fb is given, a line reports whether the
      file exists on the server. With doPrint, every file's local status is
      printed as '[True/False] name'.

      Returns True only if all files are present locally.
      """
      if not names:
          names = listRequiredFiles(stage, r, setup, db=db)
      allPresent = True
      for category in names:
          for fileName in names[category]:
              present = os.path.isfile(getData.getLocalPath(r, setup, fileName))
              if not present:
                  print(f'Missing {fileName}')
                  if fb:
                      urlFor = segmentation.getURL if category == 'segmentation' else getData.getURL
                      remote = fb.entryExists(urlFor(fb, r, setup, fileName))
                      print(f'Available remote: {remote}')
                  allPresent = False
              if doPrint:
                  print(f'[{present}] {fileName}')
      return allPresent
  def uploadCreatedFiles(stage, fb, r, setup, names=None):
      """Upload the files produced by `stage` to the file server `fb`.

      names maps a category to a list of file names. When it is empty or
      omitted, it comes from listCreatedFiles. The 'segmentation' category
      goes through segmentation.copyToServer/getURL; all others go through
      getData. After a category is copied, one line per file reports whether
      the entry now exists on the server.

      When fb is None there is no server to upload to, so nothing is done.
      This matches the `if fb:` guards around downloads elsewhere, instead of
      passing None into the copy helpers.
      """
      if fb is None:
          print('No file server given, skipping upload')
          return
      if not names:
          names = listCreatedFiles(stage, r, setup)
      # iterate keys only: older listCreatedFiles versions may return a list
      for category in names:
          if category == 'segmentation':
              copyToServer, getURL = segmentation.copyToServer, segmentation.getURL
          else:
              copyToServer, getURL = getData.copyToServer, getData.getURL
          copyToServer(fb, r, setup, names[category])
          for x in names[category]:
              print('[{}] Uploaded {}'.format(fb.entryExists(getURL(fb, r, setup, x)), x))
  # Not a natural fit for the workflow module, but no better home was found for it.
  def makeMap(segs, kClass, tac):
      """Group per-pixel (class, TAC column) pairs by segment label.

      Column i of `tac` is paired with kClass[i]. The pairs are then matched
      to `segs` position by position (stopping at the shorter sequence).

      Returns a dict mapping each segment label to the list of
      (classId, tac[:, i]) tuples, in pixel order.
      """
      pairs = [(classId, tac[:, col]) for col, classId in enumerate(kClass)]
      grouped = {}
      for label, pair in zip(segs, pairs):
          grouped.setdefault(label, []).append(pair)
      return grouped
  188. #def getDataAtPixels(data,loc) replaced by loadData.getTACAtPixels(data,loc)
  def updateDatabase(r, setup, stage, fb=None, db=None, categories=None):
      """Insert a summary row for `stage` into the project's database list.

      Only the 'plotIVF' stage writes anything. It reads the IVF samples,
      fits them with the median chi2 (samples[0, :]) as threshold, and inserts
      one row into the 'SummaryIVF' list. The row holds the patient/visit ids,
      nclass, the fit's mu[0] and cov[0, 0], qLambda, and one plot file name
      per entry of `categories`. Nothing is inserted when db is falsy.

      fb is accepted for interface symmetry and is not used.
      """
      if categories is None:  # avoid a shared mutable default
          categories = []
      qLam = getRobust(setup, 'qLambda')
      nclass = setup['nclass'][0]
      code = config.getCode(r, setup)
      if stage == 'plotIVF':
          m, samples = loadData.readIVF(r, setup, qLambda=qLam)
          threshold = numpy.median(samples[0, :])  # chi2 row of the samples
          fit = fitData.getFit(samples, threshold)
          row = {x: r[x] for x in ['PatientId', 'visitCode']}
          row['nclass'] = nclass
          row['mean'] = fit.mu[0]
          # NOTE(review): cov[0, 0] looks like a variance, not a standard
          # deviation; confirm whether sqrt is intended for the 'std' column.
          row['std'] = fit.cov[0, 0]
          row['qLambda'] = qLam
          row.update({x: config.getPattern('plotIVF', code=code, ir=0, nclass=nclass, qaName=x, qLambda=qLam)
                      for x in categories})
          if db:
              db.modifyRows('insert', setup['project'], 'lists', 'SummaryIVF', [row])
  def _runSetCenters(r, setup, names, fb=None, db=None):
      """Stage 'setCenters': compute and save class centers via loadData.saveCenters."""
      loadData.saveCenters(r, setup)


  def _runFitIVF(r, setup, names, fb=None, db=None):
      """Stage 'fitIVF': fit the IVF (nfit=30) and save it via loadData.saveIVF."""
      loadData.saveIVF(r, setup, nfit=30, qLambda=getRobust(setup, 'qLambda'))


  def _runPlotIVF(r, setup, names, fb=None, db=None):
      """Stage 'plotIVF': plot the IVF fit for region 0 and record it in the database."""
      ir = 0
      qLambda = getRobust(setup, 'qLambda')
      print('Loading files to memory')
      m, samples = loadData.readIVF(r, setup, qLambda=qLambda)
      data = loadData.loadData(r, setup)
      ct = loadData.loadCT(r, setup)
      centerMapSPECT, centerMapCT = loadData.loadCenterMapNRRD(r, setup, ir=ir)
      t, _ = loadData.loadTime(r, setup)
      centers = loadData.loadCenters(r, setup, ir=ir)
      ivf = centers[m]
      threshold = numpy.median(samples[0, :])  # median of the chi2 row
      code = config.getCode(r, setup)
      # nclass used to be assigned only in the plotCompartment branch, so this
      # stage raised NameError when building the plot file names.
      nclass = setup['nclass'][0]
      categories = ['fits', 'diff', 'generatedIVF', 'centerIVFSPECT', 'centerIVFCT']
      files = {
          x: getData.getLocalPath(r, setup, config.getPattern(
              'plotIVF', code=code, ir=ir, nclass=nclass, qaName=x, qLambda=qLambda))
          for x in categories
      }
      plotData.plotIVF(t, ivf, samples, threshold, file0=files['fits'], file1=files['diff'])
      plotData.plotIVFRealizations(t, ivf, samples, threshold, file=files['generatedIVF'])
      plotData.plotIVFCenter(centerMapSPECT, centerMapCT, m, data, ct,
                             file0=files['centerIVFSPECT'], file1=files['centerIVFCT'])
      updateDatabase(r, setup, 'plotIVF', db=db, fb=fb, categories=categories)


  def _runFitCompartment(r, setup, names, fb=None, db=None):
      """Stage 'fitCompartment': fit the compartment model for every segment.

      For each segment label, two fits are run with fitData.fitCompartmentGlobal.
      'kmeansFit' uses the average of the class-center curves of the segment's
      pixels; 'localFit' uses the average of the pixels' own TACs. The samples
      and both averaged curves are saved under iseg=<label>. Sets
      setup['nseg'] to the number of segment labels found.
      """
      ir = 0
      qLambda = getRobust(setup, 'qLambda')
      qLambdaC = getRobust(setup, 'qLambdaC')
      classMap = loadData.loadCenterMap(r, setup)  # per-pixel class ids
      print(classMap.shape)
      seg = segmentation.getNRRDImage(r, setup, names)
      loc = numpy.nonzero(seg)  # only labelled (non-zero) pixels take part
      vClass = [int(x) for x in classMap[loc]]
      segments = [int(x) for x in seg[loc]]
      print(segments)
      data = loadData.loadData(r, setup)
      tac = loadData.getTACAtPixels(data, loc)
      segMap = makeMap(segments, vClass, tac)
      _, samples = loadData.readIVF(r, setup, qLambda=qLambda)
      ivfFit = fitData.getFit(samples, numpy.median(samples[0, :]))
      t, _ = loadData.loadTime(r, setup)
      centers = loadData.loadCenters(r, setup)
      setup['nseg'] = len(segMap)
      for iseg, members in segMap.items():
          qCenter = numpy.zeros(t.shape[0])
          qData = numpy.zeros(t.shape[0])
          kCenters = []
          # average over all pixels of this segment; members are (classId, tac) tuples
          for classId, pixelTAC in members:
              kCenters.append(classId)
              qCenter += centers[classId]
              qData += pixelTAC
          qCenter /= len(members)
          qData /= len(members)
          samplesC = fitData.fitCompartmentGlobal(ivfFit, t, qCenter, useJac=True, nfit=20, qLambda=qLambdaC)
          samplesC1 = fitData.fitCompartmentGlobal(ivfFit, t, qData, nfit=20, useJac=True, qLambda=qLambdaC)
          keys = dict(iseg=iseg, ir=ir, qLambda=qLambda, qLambdaC=qLambdaC)
          loadData.saveSamples(r, setup, samplesC, kCenters, 'kmeansFit', **keys)
          loadData.saveSamples(r, setup, samplesC1, [-1], 'localFit', **keys)
          loadData.saveTAC(r, setup, qCenter, 'kmeansTAC', **keys)
          loadData.saveTAC(r, setup, qData, 'localTAC', **keys)


  def _runPlotCompartment(r, setup, names, fb=None, db=None):
      """Stage 'plotCompartment': summarize and plot the compartment fits per segment.

      For every segment label, one 'Summary' row per option ('kmeansFit',
      'localFit') is inserted when db is given, and sample plots are written.
      Sets setup['nseg'] to the number of segment labels found.
      """
      ir = 0
      qLambda = getRobust(setup, 'qLambda')
      qLambdaC = getRobust(setup, 'qLambdaC')
      tag = 'plotCompartment'
      seg = segmentation.getNRRDImage(r, setup, names)
      loc = numpy.nonzero(seg)
      segmentIds = sorted({int(x) for x in seg[loc]})
      nclass = setup['nclass'][0]
      code = config.getCode(r, setup)
      setup['nseg'] = len(segmentIds)
      t, _ = loadData.loadTime(r, setup)
      for iseg in segmentIds:
          keys = dict(ir=ir, iseg=iseg, qLambda=qLambda, qLambdaC=qLambdaC)
          _, samplesC = loadData.readSamples(r, setup, 'kmeansFit', **keys)
          _, samplesC1 = loadData.readSamples(r, setup, 'localFit', **keys)
          qCenter = loadData.readTAC(r, setup, 'kmeansTAC', **keys)
          qData = loadData.readTAC(r, setup, 'localTAC', **keys)
          fit = fitData.getFit(samplesC, numpy.median(samplesC[0, :]))
          fit1 = fitData.getFit(samplesC1, numpy.median(samplesC1[0, :]))
          plotKeys = dict(code=code, ir=ir, nclass=nclass, iseg=iseg, qLambda=qLambda, qLambdaC=qLambdaC)
          row = {x: r[x] for x in ['PatientId', 'visitCode']}
          row['Date'] = datetime.datetime.now().isoformat()
          row['nclass'] = nclass
          row['option'] = 'kmeansFit'
          row['mean'] = fit.mu[0]
          # NOTE(review): cov[0, 0] looks like a variance; confirm the 'std' column.
          row['std'] = fit.cov[0, 0]
          row['regionId'] = iseg
          row['fitPlot'] = config.getPattern(tag, qaName='realizations', **plotKeys)
          row['diffPlot'] = config.getPattern(tag, qaName='diff', **plotKeys)
          row['qLambda'] = qLambda
          row['qLambdaC'] = qLambdaC
          # The localFit row used to be copied before qLambda/qLambdaC were set,
          # so it was inserted without them.
          row1 = dict(row, option='localFit', mean=fit1.mu[0], std=fit1.cov[0, 0])
          if db:
              db.modifyRows('insert', setup['project'], 'lists', 'Summary', [row, row1])
          evalArray = [(samplesC, qCenter, 'blue'),
                       (samplesC1, qData, 'orange')]
          file0 = getData.getLocalPath(r, setup, row['fitPlot'])
          file1 = getData.getLocalPath(r, setup, row['diffPlot'])
          plotData.plotSamples(t, evalArray, file0=file0, file1=file1)


  def workflow(r, setup, stage, fb=None, db=None):
      """Run one processing stage for patient row `r`.

      Known stages: 'setCenters', 'fitIVF', 'plotIVF', 'fitCompartment',
      'plotCompartment'. For a known stage, the required inputs are first
      downloaded from fb (when given) and checked locally. If any are missing,
      diagnostics are printed and None is returned without running the stage.
      Afterwards, the stage's created files are uploaded when fb is given.

      Returns setup (the compartment stages set setup['nseg']), or None when
      required inputs are missing.
      """
      runners = {
          'setCenters': _runSetCenters,
          'fitIVF': _runFitIVF,
          'plotIVF': _runPlotIVF,
          'fitCompartment': _runFitCompartment,
          'plotCompartment': _runPlotCompartment,
      }
      run = runners.get(stage)
      if run is not None:
          names = listRequiredFiles(stage, r, setup, db=db)
          if fb:
              # fitIVF previously skipped this download, unlike every other stage
              getRequiredFiles(stage, r, setup, fb, names=names, db=db)
          if not checkRequiredFiles(stage, r, setup, names=names, fb=fb, doPrint=True, db=db):
              return None
          run(r, setup, names, fb=fb, db=db)
      if fb:
          uploadCreatedFiles(stage, fb, r, setup)
      # setup may have been modified by the stage
      return setup
  def main(setupFile):
      """Run the stage named in a JSON setup file for every selected patient row.

      The setup dict must provide at least 'stage' (read here) and the keys
      used by getRows, e.g. 'network'.
      """
      import json  # json was never imported at the top of this file
      with open(setupFile, 'r', encoding='utf-8') as f:
          setup = json.load(f)
      # getRows returns (fb, db, rows); the old unpacking swapped db and fb
      fb, db, rows = getRows(setup, returnFB=True)
      for r in rows:
          workflow(r, setup, setup['stage'], fb=fb, db=db)
  if __name__ == "__main__":
      import sys  # sys was never imported at the top of this file
      main(sys.argv[1])