123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- import os
- import json
- import re
- import subprocess
- import nibabel
- import shutil
- import sys
- import pathlib
- import SimpleITK
- import numpy
- #nothing gets done if you do import
def getPatientLabel(row,participantField='PatientId'):
    """Return the participant id of *row* made path-safe.

    Every '/' in the value of column *participantField* becomes '_', so the
    label can be used as a single directory/file name component.
    """
    rawId=row[participantField]
    return '_'.join(rawId.split('/'))
def getVisitLabel(row):
    """Return the visit label 'VISIT_<n>', n being int(row['SequenceNum'])."""
    sequence=int(row['SequenceNum'])
    return 'VISIT_{}'.format(sequence)
def getStudyLabel(row,participantField='PatientId'):
    """Return '<patient label>-<visit label>' identifying one study of *row*."""
    return '{}-{}'.format(getPatientLabel(row,participantField),
                          getVisitLabel(row))
def updateRow(project,dataset,row,imageResampledField,gzFileNames,
              participantField='PatientId',db=None,schema='study'):
    """Fill in the derived columns of *row* and update it in the database.

    Sets row['patientCode'] and row['visitCode'] (see getPatientLabel and
    getVisitLabel). For every key im of *imageResampledField* it stores
    gzFileNames[im] in column imageResampledField[im]. *row* is modified in
    place and then sent with db.modifyRows('update', ...).

    Args:
        project: LabKey project passed to modifyRows.
        dataset: query/dataset name to update.
        row: database row (dict); modified in place.
        imageResampledField: mapping image key -> column name.
        gzFileNames: mapping image key -> value to store in that column.
        participantField: column holding the participant id.
        db: database browser providing modifyRows. When omitted, a
            module-level ``db`` is used if a caller has injected one.
        schema: schema of *dataset*; defaults to 'study' as before.

    Raises:
        NameError: if *db* is not given and no module-level ``db`` exists.
            This is the same error type the previous implementation raised
            in that situation.
    """
    if db is None:
        # The previous body referenced a global ``db``. No such global is
        # visible in this file (main() keeps its db local), so only an
        # injected one can work here.
        db=globals().get('db')
        if db is None:
            raise NameError("updateRow: no database browser; pass db=... "
                            "(no module-level 'db' is defined)")
    row['patientCode']=getPatientLabel(row,participantField)
    row['visitCode']=getVisitLabel(row)
    for im,field in imageResampledField.items():
        row[field]=gzFileNames[im]
    db.modifyRows('update',project,schema,dataset,[row])
-
def replacePatterns(infile,outfile,replacePatterns):
    """Copy *infile* to *outfile*, applying regex substitutions on the way.

    Args:
        infile: path of the template text file.
        outfile: path of the file to write; may be the same path as *infile*.
        replacePatterns: mapping regex pattern -> replacement, applied with
            re.sub in the mapping's iteration order. Replacements are re.sub
            templates, so backslashes in them are interpreted (escapes and
            group references).
    """
    # Read everything before opening the output. Truncating outfile first
    # (as before) destroyed the data when infile == outfile, and the output
    # handle leaked if reading or substituting raised.
    with open(infile,'r') as f:
        data=f.read()
    for pattern,value in replacePatterns.items():
        data=re.sub(pattern,value,data)
    with open(outfile,'w') as of:
        of.write(data)
-
def valueSubstitution(pars,val):
    """Expand the ``__home__`` placeholder in *val* to the user's home dir.

    Args:
        pars: unused; kept for call compatibility.
        val: string that may contain ``__home__``; every occurrence is
            replaced.

    Returns:
        The substituted string, or *val* unchanged when there is no
        placeholder.
    """
    # Literal replacement. The previous re.sub() used the home path as a
    # regex replacement template, where backslashes (e.g. a Windows home)
    # are parsed as escapes. It then returned the undefined name ``path``,
    # so every call raised NameError.
    return val.replace('__home__',os.path.expanduser('~'))
def getCroppedImagePath(tempFile,crop):
    """Return *tempFile* with *crop* inserted in front of its suffixes.

    For example, 'ct.nii.gz' with crop '_top' gives 'ct_top.nii.gz'. A name
    without a suffix simply gets *crop* appended. The result is the
    pathlib-normalised string form of the path.
    """
    p=pathlib.Path(tempFile)
    sfx=''.join(p.suffixes)
    path=str(p)
    # The suffix is the tail of the file name, so insert at its last literal
    # occurrence. The previous re.sub() had three problems: it treated the
    # suffix as a regex ('.' matched any character), it rewrote every match
    # anywhere in the path, and for an empty suffix it inserted crop between
    # every character.
    cut=path.rfind(sfx)
    if cut<0:
        cut=len(path)
    return path[:cut]+crop+path[cut:]
def getSuffix(tempFile):
    """Return all suffixes of the file name in *tempFile* joined together.

    For example, 'dir/ct.nii.gz' gives '.nii.gz', and a name without a dot
    gives ''.
    """
    parts=pathlib.PurePath(tempFile).suffixes
    return ''.join(part for part in parts)
def getSegmImagePath(tempFile):
    """Insert '_Segm' in front of the suffixes of *tempFile*.

    For example, 'a.nii.gz' gives 'a_Segm.nii.gz'. A name without a suffix
    gets '_Segm' appended.
    """
    sfx=getSuffix(tempFile)
    # Literal insertion at the trailing suffix. The previous re.sub()
    # treated the suffix as a regex, rewrote every occurrence, and broke on
    # an empty suffix.
    cut=tempFile.rfind(sfx)
    if cut<0:
        cut=len(tempFile)
    return tempFile[:cut]+'_Segm'+tempFile[cut:]
def addVersion(tempFile,version):
    """Insert '_<version>' in front of the suffixes of *tempFile*.

    For example, addVersion('a_Segm.nii.gz', 'v2') gives 'a_Segm_v2.nii.gz'.
    A name without a suffix gets '_<version>' appended.
    """
    sfx=getSuffix(tempFile)
    # Literal insertion at the trailing suffix. The previous re.sub() had
    # these problems:
    #   - the suffix was treated as a regex;
    #   - every occurrence was rewritten;
    #   - backslashes in version were parsed as template escapes;
    #   - an empty suffix broke the result.
    cut=tempFile.rfind(sfx)
    if cut<0:
        cut=len(tempFile)
    return tempFile[:cut]+'_'+version+tempFile[cut:]
def normalizeCT(ctFile,maskFile):
    """Z-score normalise the image in *ctFile* inside *maskFile*, in place.

    Voxels where the mask is > 0 become (value - mean) / std, with mean and
    std taken over those voxels. Voxels where the mask is 0 are set to 0;
    voxels with a negative mask value are left unchanged. The result
    overwrites *ctFile* and keeps its origin, spacing and direction.
    Integer input is converted to float32 first, so the written file is
    floating point.

    Assumes the mask is on the same voxel grid as the CT (TODO confirm for
    all callers).

    Args:
        ctFile: path of the CT image; read and overwritten.
        maskFile: path of the mask image.
    """
    im=SimpleITK.ReadImage(ctFile)
    mask=SimpleITK.ReadImage(maskFile)
    nm=SimpleITK.GetArrayFromImage(im)
    nmask=SimpleITK.GetArrayFromImage(mask)
    # z-scores do not fit an integer array. Assigning them back into, e.g.,
    # an int16 CT array silently truncates them, so work in floating point.
    if not numpy.issubdtype(nm.dtype,numpy.floating):
        nm=nm.astype(numpy.float32)
    inside=nmask>0
    # An empty mask has no statistics; only the zeroing below applies.
    if numpy.any(inside):
        values=nm[inside]
        mu=numpy.mean(values)
        st=numpy.std(values)
        # A constant region has st == 0: centre it instead of dividing by 0.
        nm[inside]=(values-mu)/st if st>0 else values-mu
    nm[nmask==0]=0
    im1=SimpleITK.GetImageFromArray(nm)
    im1.SetOrigin(im.GetOrigin())
    im1.SetSpacing(im.GetSpacing())
    im1.SetDirection(im.GetDirection())
    SimpleITK.WriteImage(im1,ctFile)
def cropImage(tempFile,crop, cropData):
    """Crop an image to a fractional range along one axis and save it.

    The result is written to getCroppedImagePath(tempFile, crop), and that
    path is printed.

    Args:
        tempFile: path of the image to crop.
        crop: crop name, inserted in front of the output file's suffix.
        cropData: crop description with these keys:
            'axis': SimpleITK index axis to crop along.
            'range': start and stop as fractions of the axis length.
            'n': expected axis length. When it is "NONE" it is set, in
                place, to this image's size along the axis. Otherwise a
                mismatch is printed, and the image's own size is still used
                for the crop.
    """
    image=SimpleITK.ReadImage(tempFile)
    size=image.GetSize()
    axis=int(cropData['axis'])
    fractions=[float(v) for v in cropData['range']]
    if cropData['n']=="NONE":
        cropData['n']=size[axis]
    if size[axis]!=cropData['n']:
        print('Size mismatch {}:{}'.format(size[axis],cropData['n']))
    length=size[axis]
    bounds=[int(fraction*length) for fraction in fractions]
    window=[slice(None)]*len(size)
    window[axis]=slice(bounds[0],bounds[1])
    outPath=getCroppedImagePath(tempFile,crop)
    SimpleITK.WriteImage(image[window],outPath)
    print("Written {}".format(outPath))
def runDeepMedic(setup,pars):
    """Run DeepMedic on the CPU with the generated model and test configs.

    The runner script setup['paths']['deepMedicRun'] is started with the
    interpreter <setup['paths']['deepMedicVE']>/bin/python. The argument
    list and the captured stdout (bytes) are printed. A non-zero exit raises
    subprocess.CalledProcessError.
    """
    paths=setup['paths']
    interpreter=os.path.join(paths['deepMedicVE'],'bin','python')
    runner=paths['deepMedicRun']
    configs=pars['deepmedic']['config']
    command=[
        interpreter,
        runner,
        '-model',configs['model']['out'],
        '-test',configs['test']['out'],
        '-dev','cpu',
    ]
    print(command)
    result=subprocess.run(command,check=True,stdout=subprocess.PIPE)
    print(result.stdout)
def getSegmentationFile(pars,crop):
    """Return the path where DeepMedic's prediction for *crop* is stored.

    The path is built from the 'segmentations' temp file name. It is placed
    under <tempBase>/output/predictions/currentSession/predictions, given
    the crop name (getCroppedImagePath) and then the '_Segm' tag
    (getSegmImagePath).
    """
    # this is how deep medic stores files
    predictionDir=os.path.join(pars['tempBase'],'output','predictions',
                               'currentSession','predictions')
    segName=pars['images']['images']['segmentations']['tempFile']
    croppedPath=getCroppedImagePath(os.path.join(predictionDir,segName),crop)
    return getSegmImagePath(croppedPath)
def getWeight(x,w):
    """Evaluate a piecewise-linear weight profile at position *x*.

    *w* is a sequence of segments. Each segment is a dict with these keys:
      'range': [lo, hi], with both ends inclusive;
      'n': offset;
      'k': optional slope.
    The first segment whose range contains *x* decides the result: k*x + n,
    or just n when the segment has no 'k'. Returns 0 when no segment
    matches.
    """
    for segment in w:
        bounds=[float(b) for b in segment['range']]
        if x>bounds[1] or x<bounds[0]:
            continue
        offset=float(segment['n'])
        if 'k' not in segment:
            return offset
        return float(segment['k'])*x+offset
    return 0
def runSegmentation(fb,row,pars,setup):
    """Segment one study row with DeepMedic and return the merged result.

    Steps:
      1. Download every image that has a 'queryField' from the row's
         <imageDir>/<patient>/<visit> folder into its 'tempFile'.
      2. Z-score normalise the CT inside the patient mask (normalizeCT).
      3. Crop every existing image into the configured crops (cropImage).
         One cropped path per crop is written to each image's 'fileList'.
      4. Re-normalise each cropped CT inside its cropped mask.
      5. Run DeepMedic and merge the per-crop predictions
         (mergeSegmentations).

    Args:
        fb: LabKey file browser; formatPathURL and readFileToFile are used.
        row: row of the target query. The participant field and
            'SequenceNum' locate the images, and each image's 'queryField'
            names the column holding its file name.
        pars: parameter dictionary. pars['images']['crop'][*]['n'] is reset
            here and refilled by cropImage.
        setup: setup dictionary passed on to runDeepMedic.

    Returns:
        Path of the merged segmentation, i.e.
        os.path.join(pars['tempBase'], <'segmentations' tempFile>).
    """
    project=pars['project']
    images=pars['images']['images']
    participantField=pars['participantField']
    baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+
                             getPatientLabel(row,participantField)+'/'+
                             getVisitLabel(row))
    cropData=pars['images']['crop']
    # Forget the slice count of the previous row. The first image cropped
    # below stores the new one (see cropImage), and mergeSegmentations
    # reads it.
    for crop in cropData:
        cropData[crop]['n']="NONE"

    # 1. download
    for im in images:
        if 'queryField' in images[im]:
            fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],
                              images[im]['tempFile'])

    # 2. normalise the full CT
    normalizeCT(images['CT']['tempFile'],images['patientmask']['tempFile'])

    # 3. crop, and list the cropped file names
    for im in images:
        tmpFile=images[im]['tempFile']
        haveFile=os.path.isfile(tmpFile)
        with open(images[im]['fileList'],'w') as f:
            for crop in cropData:
                if haveFile:
                    cropImage(tmpFile,crop,cropData[crop])
                # Listed even when tmpFile does not exist, e.g. the
                # 'segmentations' entry. getSegmentationFile later derives
                # the prediction paths from those names, so presumably
                # DeepMedic needs them as output names. Confirm against the
                # DeepMedic test config.
                f.write(getCroppedImagePath(tmpFile,crop)+'\n')

    # 4. re-normalise every CT crop inside its cropped mask
    for crop in cropData:
        normalizeCT(getCroppedImagePath(images['CT']['tempFile'],crop),
                    getCroppedImagePath(images['patientmask']['tempFile'],crop))

    # 5. DeepMedic inference and merge of the crops
    runDeepMedic(setup,pars)
    outImg=mergeSegmentations(pars)
    segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
    SimpleITK.WriteImage(outImg,segFile)
    return segFile
def mergeSegmentations(pars):
    """Merge the per-crop DeepMedic predictions into one full-size image.

    For each crop in pars['images']['crop']:
      - The prediction (located with getSegmentationFile) is padded back to
        the full extent along the crop axis, filling with -1. The full
        extent is the slice count 'n' that cropImage stored in the crop
        entry.
      - Along that axis, the slice at relative position (k+0.5)/n gets the
        weight getWeight(..., crop['w']).
    A later crop replaces earlier labels only where its weight is strictly
    larger, so ties keep the earlier crop.

    Args:
        pars: parameter dictionary. Reads pars['images']['crop']; each entry
            has 'range', 'n', 'w' and optionally 'axis' (SimpleITK index
            axis, default 2 as in the previous hard-coded padding). Also
            reads whatever getSegmentationFile needs.

    Returns:
        SimpleITK image with the merged labels. Its origin, spacing and
        direction come from the last padded prediction.

    Raises:
        ValueError: if no crops are configured, or if a padded prediction
            does not span exactly n slices along its crop axis.
    """
    cropData=pars['images']['crop']
    if not cropData:
        raise ValueError("mergeSegmentations: pars['images']['crop'] is empty")
    nout=None   # merged label array (NumPy axis order)
    w0=None     # best weight seen so far, per voxel
    img=None    # last padded image; source of the output geometry
    for c in cropData:
        crop=cropData[c]
        si=SimpleITK.ReadImage(getSegmentationFile(pars,c))
        rng=[float(v) for v in crop['range']]
        n=crop['n']
        # Pad along the same axis cropImage cropped along (was fixed to 2).
        ax=int(crop.get('axis',2))
        dim=si.GetDimension()
        lower=[0]*dim
        upper=[0]*dim
        lower[ax]=int(rng[0]*n)
        upper[ax]=n-int(rng[1]*n)
        img=SimpleITK.ConstantPad(si,lower,upper,-1)
        ni=SimpleITK.GetArrayFromImage(img)
        # SimpleITK indexes (x, y, z, ...); NumPy arrays use reversed order.
        npAxis=dim-1-ax
        if ni.shape[npAxis]!=n:
            raise ValueError('mergeSegmentations: crop {} spans {} slices '
                             'after padding, expected {}'.format(
                                 c,ni.shape[npAxis],n))
        # Weights vary only along the crop axis: build one profile and
        # broadcast it instead of filling a full array slice by slice.
        profile=numpy.array([getWeight((k+0.5)/n,crop['w'])
                             for k in range(n)],dtype=float)
        shape=[1]*ni.ndim
        shape[npAxis]=n
        w1=numpy.broadcast_to(profile.reshape(shape),ni.shape)
        if nout is None:
            nout=ni
            w0=numpy.array(w1)  # writable full-size copy
            continue
        better=w1>w0
        nout[better]=ni[better]
        w0=numpy.maximum(w0,w1)
    iout=SimpleITK.GetImageFromArray(nout)
    iout.SetDirection(img.GetDirection())
    iout.SetOrigin(img.GetOrigin())
    iout.SetSpacing(img.GetSpacing())
    return iout
-
def _registerSegmentation(db,pars,row,outName):
    """Record *outName* in the segmentation query for the study in *row*.

    Looks up the segmentation-query row that matches *row* on the
    participant field, 'SequenceNum', 'patientCode' and 'visitCode'. The
    first match is updated; if there is none, a new row is inserted. In
    both cases column pars['version'] is set to *outName*.
    """
    copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
    filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
    ds=db.selectRows(pars['project'],pars['segmentationSchema'],
                     pars['segmentationQuery'],filters)
    if ds['rows']:
        mode='update'
        outRow=ds['rows'][0]
    else:
        mode='insert'
        outRow={v:row[v] for v in copyFields}
    outRow[pars['version']]=outName
    db.modifyRows(mode,pars['project'],pars['segmentationSchema'],
                  pars['segmentationQuery'],[outRow])


def main(parameterFile,maxRows=1):
    """Segment the studies of the target query and register the results.

    Reads two files: ~/.labkey/setup.json (paths of the helper packages and
    of DeepMedic) and ~/.labkey/network.json (LabKey connection). Then it
    loads *parameterFile* through parseConfig and instantiates the DeepMedic
    config files from their templates.

    For each visited row of pars['targetQuery'], the output file is
    <imageDir>/<patientCode>/<visitCode>/<study label>_Segm_<version><suffix>.
    The row is segmented unless that file already exists on the server; the
    result is uploaded, and its name is recorded in pars['segmentationQuery'].

    Args:
        parameterFile: path of the JSON parameter file.
        maxRows: number of target-query rows to visit. The default 1 keeps
            the previous behaviour of stopping after the first row; pass
            None to visit every row.
    """
    fhome=os.path.expanduser('~')
    with open(os.path.join(fhome,".labkey","setup.json")) as f:
        setup=json.load(f)
    # project-local helper packages are located through setup.json
    sys.path.insert(0,setup["paths"]["labkeyInterface"])
    import labkeyInterface
    import labkeyDatabaseBrowser
    import labkeyFileBrowser
    sys.path.append(setup['paths']['parseConfig'])
    import parseConfig

    net=labkeyInterface.labkeyInterface()
    net.init(os.path.join(fhome,'.labkey','network.json'))
    db=labkeyDatabaseBrowser.labkeyDB(net)
    fb=labkeyFileBrowser.labkeyFileBrowser(net)

    with open(parameterFile) as f:
        pars=json.load(f)
    pars=parseConfig.convert(pars)
    pars=parseConfig.convertValues(pars)
    print(pars)

    # instantiate the DeepMedic config files from their templates
    cfg=pars['deepmedic']['config']
    for c in cfg:
        replacePatterns(cfg[c]['template'],cfg[c]['out'],pars['replacePattern'])

    project=pars['project']
    os.makedirs(pars['tempBase'],exist_ok=True)
    # all images from the database
    ds=db.selectRows(project,pars['targetSchema'],pars['targetQuery'],[])

    # Suffix of DeepMedic prediction files; the crop name is irrelevant here.
    segSuffix=getSuffix(getSegmentationFile(pars,'XX'))
    for i,row in enumerate(ds["rows"]):
        if maxRows is not None and i>=maxRows:
            break
        outpath=fb.buildPathURL(project,[pars['imageDir'],row['patientCode'],row['visitCode']])
        outName=addVersion(
            getSegmImagePath(getStudyLabel(row,pars['participantField'])+segSuffix),
            pars['version'])
        outFile=outpath+'/'+outName
        if not fb.entryExists(outFile):
            segFile=runSegmentation(fb,row,pars,setup)
            # The target query is not updated because it holds previously set
            # images; the result is uploaded to LabKey as a file instead.
            fb.writeFileToFile(segFile,outFile)
        # NOTE(review): the segmentation query is written for every visited
        # row, also when the file already existed. Confirm this is intended.
        _registerSegmentation(db,pars,row,outName)
    print("Done")
if __name__ == '__main__':
    # Importing the module runs nothing; as a script it expects the parameter
    # JSON file as its only argument.
    if len(sys.argv)<2:
        sys.exit('Usage: python {} <parameterFile.json>'.format(
            os.path.basename(sys.argv[0])))
    main(sys.argv[1])
|