123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- import os
- import json
- import re
- import subprocess
- import nibabel
- import shutil
- import sys
- import pathlib
- import SimpleITK
- import numpy
- # Importing this module only defines functions; processing runs only when the file is executed as a script.
def getPatientLabel(row,participantField='PatientId'):
    """Return the participant id of *row* made safe for use as a path component.

    Every '/' in ``row[participantField]`` is turned into '_'.
    """
    patientId = row[participantField]
    return '_'.join(patientId.split('/'))
def getVisitLabel(row):
    """Return the visit label 'VISIT_<n>', where n is ``int(row['SequenceNum'])``."""
    visitNumber = int(row['SequenceNum'])
    return 'VISIT_{}'.format(visitNumber)
def getStudyLabel(row,participantField='PatientId'):
    """Return '<patient label>-<visit label>' for *row* (see getPatientLabel / getVisitLabel)."""
    parts = (getPatientLabel(row,participantField), getVisitLabel(row))
    return '-'.join(parts)
def updateRow(project,dataset,row,imageResampledField,gzFileNames,
              participantField='PatientId',db=None):
    """Fill patient/visit codes and image file names into *row* and push it to LabKey.

    Sets ``row['patientCode']`` and ``row['visitCode']`` from the participant
    and visit labels, stores ``gzFileNames[key]`` in the column named
    ``imageResampledField[key]`` for every key of *imageResampledField*, then
    sends *row* as an 'update' to query *dataset* of schema 'study' in *project*.

    Args:
        project: LabKey project path.
        dataset: query (dataset) name inside the 'study' schema.
        row: row dict; modified in place.
        imageResampledField: mapping key -> destination column name.
        gzFileNames: mapping key -> file name to store in that column.
        participantField: column holding the participant id.
        db: object exposing ``modifyRows`` (e.g. a labkeyDB). If omitted, a
            module-level ``db`` is used when one has been assigned; this module
            never defines one itself, so callers should pass it explicitly.

    Raises:
        NameError: if *db* is not given and no module-level ``db`` exists.
    """
    if db is None:
        db = globals().get('db')
        if db is None:
            # fail before touching *row* instead of after mutating it
            raise NameError("updateRow: pass db=<labkeyDB instance>; "
                            "no module-level 'db' is defined")
    row['patientCode'] = getPatientLabel(row,participantField)
    row['visitCode'] = getVisitLabel(row)
    for key in imageResampledField:
        row[imageResampledField[key]] = gzFileNames[key]
    db.modifyRows('update',project,'study',dataset,[row])
-
def replacePatterns(infile,outfile,replacePatterns):
    """Copy text file *infile* to *outfile*, applying regex substitutions.

    Each key of *replacePatterns* is a regular expression. Its value is the
    replacement passed to ``re.sub``, so backreferences such as ``\\1`` in a
    value are expanded. Substitutions run in the mapping's iteration order,
    each one applied to the result of the previous.

    The input is read completely before *outfile* is opened. A missing
    *infile* therefore leaves no truncated *outfile* behind, and
    ``infile == outfile`` rewrites the file in place instead of emptying it.
    """
    with open(infile,'r') as f:
        data = f.read()
    for pattern in replacePatterns:
        data = re.sub(pattern,replacePatterns[pattern],data)
    with open(outfile,'w') as of:
        of.write(data)
-
def valueSubstitution(pars,val):
    """Return *val* with every '__home__' placeholder replaced by the user's home directory.

    *pars* is accepted for interface compatibility and is not used.

    ``str.replace`` inserts the home path literally. ``re.sub`` would parse it
    as a replacement template and fail on paths containing backslashes.
    """
    return val.replace('__home__',os.path.expanduser('~'))
def getSuffix(tempFile):
    """Return all suffixes of the file name in *tempFile*, joined.

    e.g. 'dir/ct.nii.gz' -> '.nii.gz'; a name without a dot gives ''.
    """
    return ''.join(pathlib.Path(tempFile).suffixes)


def _insertBeforeSuffix(tempFile,tag):
    """Insert *tag* literally right before the file-name suffix of *tempFile*.

    The suffix is matched as plain text, not as a regex. Only its last
    occurrence is touched, which is the one belonging to the file name, so
    directory names are never altered. A name without a suffix gets *tag*
    appended.
    """
    sfx = getSuffix(tempFile)
    if not sfx:
        return tempFile + tag
    head, _, tail = tempFile.rpartition(sfx)
    return head + tag + sfx + tail


def getSegmImagePath(tempFile):
    """Return *tempFile* with '_Segm' inserted before its suffix ('ct.nii.gz' -> 'ct_Segm.nii.gz')."""
    return _insertBeforeSuffix(tempFile,'_Segm')


def addVersion(tempFile,version):
    """Return *tempFile* with '_<version>' inserted before its suffix ('ct.nii.gz', 'v1' -> 'ct_v1.nii.gz')."""
    return _insertBeforeSuffix(tempFile,'_' + str(version))
def normalizeCT(ctFile,maskFile):
    """Z-score normalize the CT inside a mask and overwrite *ctFile* with the result.

    Voxels where the mask is > 0 become (value - mean) / std, with mean and
    std computed over those voxels only. Voxels where the mask == 0 are set
    to 0. Origin, spacing and direction are copied from the input CT.

    If the CT has a non-floating pixel type it is converted to float32 first.
    Otherwise, assigning the z-scores back into the integer array would
    silently truncate them. Floating images keep their dtype. A masked region
    with zero standard deviation maps to 0 instead of NaN. An empty mask
    yields an all-zero image without numpy empty-slice warnings.

    Args:
        ctFile: path of the CT image; read, then overwritten.
        maskFile: path of the mask image. NOTE(review): assumed to share the
            CT voxel grid; this is not checked here.
    """
    im = SimpleITK.ReadImage(ctFile)
    mask = SimpleITK.ReadImage(maskFile)
    nm = SimpleITK.GetArrayFromImage(im)
    if not numpy.issubdtype(nm.dtype,numpy.floating):
        nm = nm.astype(numpy.float32)
    nmask = SimpleITK.GetArrayFromImage(mask)
    inside = nmask > 0
    if inside.any():
        values = nm[inside]
        mu = numpy.mean(values)
        st = numpy.std(values)
        if st == 0:
            # constant region: map it to 0 instead of dividing by zero
            st = 1
        nm[inside] = (values - mu) / st
    nm[nmask == 0] = 0
    im1 = SimpleITK.GetImageFromArray(nm)
    im1.SetOrigin(im.GetOrigin())
    im1.SetSpacing(im.GetSpacing())
    im1.SetDirection(im.GetDirection())
    SimpleITK.WriteImage(im1,ctFile)
def runDeepMedic(setup,pars):
    """Run the local DeepMedic executable on CPU with the generated configs.

    The command is ``[setup['paths']['deepMedicRun'], '-model', <model cfg>,
    '-test', <test cfg>, '-dev', 'cpu']``. The config paths are the 'out'
    files listed under ``pars['deepmedic']['config']``. The command is
    printed and run, and its captured stdout (bytes) is printed. A non-zero
    exit status raises subprocess.CalledProcessError.
    """
    executable = setup['paths']['deepMedicRun']
    configs = pars['deepmedic']['config']
    command = [
        executable,
        '-model', configs['model']['out'],
        '-test', configs['test']['out'],
        '-dev', 'cpu',
    ]
    print(command)
    completed = subprocess.run(command,check=True,stdout=subprocess.PIPE)
    print(completed.stdout)
def runDeepMedicDocker(setup,pars):
    """Start the DeepMedic segmentation through docker-compose.

    Runs ``docker-compose -f <pars['deepmedic']['segmentationdmYAML']> up``.
    The command and its captured stdout (bytes) are printed. A non-zero exit
    status raises subprocess.CalledProcessError. *setup* is unused.
    """
    composeFile = pars['deepmedic']['segmentationdmYAML']
    command = ['docker-compose', '-f', composeFile, 'up']
    print(command)
    completed = subprocess.run(command,check=True,stdout=subprocess.PIPE)
    print(completed.stdout)
def getSegmentationFile(pars):
    """Return the path of the segmentation file written by DeepMedic.

    The file sits in ``<tempBase>/output/predictions/currentSession/predictions``.
    Its name is ``pars['images']['segmentation']['tempFile']`` with '_Segm'
    inserted before the suffix (via getSegmImagePath).
    """
    # directory layout DeepMedic uses for its prediction output
    predictionDir = os.path.join(pars['tempBase'],'output','predictions',
                                 'currentSession','predictions')
    baseName = pars['images']['segmentation']['tempFile']
    return getSegmImagePath(os.path.join(predictionDir,baseName))
def runSegmentation(fb,row,pars,setup):
    """Download one study's images, normalize the CT and run DeepMedic on it.

    For every entry of ``pars['images']`` that has a 'queryField', the file
    named by ``row[queryField]`` is downloaded from
    ``<imageDir>/<patient label>/<visit label>`` in the LabKey project to
    ``<tempBase>/<tempFile>``. The downloaded 'CT' image is then normalized
    in place inside the 'patientmask' image, and DeepMedic is run locally.

    Returns:
        The segmentation path reported by getSegmentationFile(pars).
    """
    project = pars['project']
    images = pars['images']
    participantField = pars['participantField']
    relativeDir = '/'.join([pars['imageDir'],
                            getPatientLabel(row,participantField),
                            getVisitLabel(row)])
    baseDir = fb.formatPathURL(project,relativeDir)

    # local destination for every configured image, computed before any download
    localFiles = {}
    for key in images:
        localFiles[key] = os.path.join(pars['tempBase'],images[key]['tempFile'])

    for key in images:
        spec = images[key]
        if 'queryField' not in spec:
            continue
        fb.readFileToFile(baseDir + '/' + row[spec['queryField']],localFiles[key])
        print('Loaded {}'.format(localFiles[key]))

    normalizeCT(localFiles['CT'],localFiles['patientmask'])
    runDeepMedic(setup,pars)
    return getSegmentationFile(pars)
-
def main(parameterFile):
    """Debug helper: load the setup and the parameter file and print the parsed parameters.

    Reads ~/.labkey/setup.json and prepends ``setup['paths']['nixWrapper']``
    to sys.path. It then loads the labkeyInterface and parseConfig libraries
    through nixWrapper, parses *parameterFile* (JSON) with
    ``parseConfig.convert`` / ``convertValues`` and prints the result.
    """
    fhome = os.path.expanduser('~')
    with open(os.path.join(fhome,".labkey","setup.json")) as f:
        setup = json.load(f)
    sys.path.insert(0,setup["paths"]["nixWrapper"])

    import nixWrapper

    nixWrapper.loadLibrary("labkeyInterface")
    # NOTE(review): the three labkey modules are unused in this function
    import labkeyInterface
    import labkeyDatabaseBrowser
    import labkeyFileBrowser
    nixWrapper.loadLibrary("parseConfig")
    import parseConfig
    with open(parameterFile) as f:
        pars = json.load(f)

    # pass setup exactly as doSegmentation does, so both entry points parse
    # the parameter file the same way
    pars = parseConfig.convert(pars,setup)
    pars = parseConfig.convertValues(pars)
    print(pars)
-
def _registerSegmentation(db,row,pars,outName):
    """Insert or update the entry for *row*'s segmentation file *outName* in the segmentation query.

    Adds ``0.001 * versionNumber`` to ``row['SequenceNum']`` in place. It then
    looks for an entry in ``segmentationSchema/segmentationQuery`` that
    matches the participant, SequenceNum, patientCode, visitCode and Version.
    That entry is updated, or a new one is inserted, with Version and
    Segmentation set. The modifyRows response is printed.
    """
    copyFields = [pars['participantField'],'SequenceNum','patientCode','visitCode']
    # presumably gives each segmentation version its own SequenceNum
    row['SequenceNum'] += 0.001 * float(pars['versionNumber'])
    filters = [{'variable': v,'value': str(row[v]),'oper': 'eq'} for v in copyFields]
    filters.append({'variable': 'Version','value': pars['version'],'oper': 'eq'})
    ds1 = db.selectRows(pars['project'],pars['segmentationSchema'],
                        pars['segmentationQuery'],filters)
    if ds1['rows']:
        mode = 'update'
        outRow = ds1['rows'][0]
    else:
        mode = 'insert'
        outRow = {v: row[v] for v in copyFields}
    outRow['Version'] = pars['version']
    outRow['Segmentation'] = outName
    print(db.modifyRows(mode,pars['project'],pars['segmentationSchema'],
                        pars['segmentationQuery'],[outRow]))


def doSegmentation(parameterFile):
    """Segment every selected study row with DeepMedic and register the results in LabKey.

    1. Load ~/.labkey/setup.json, make nixWrapper importable and connect to
       LabKey with ~/.labkey/network.json.
    2. Parse *parameterFile* (JSON) with parseConfig.
    3. Render each DeepMedic config template in ``pars['deepmedic']['config']``
       using ``pars['replacePattern']``.
    4. Select the rows of targetSchema/targetQuery matching entryFilter.
       For each row, run the segmentation and upload it unless the versioned
       output file already exists in LabKey. In both cases, insert or update
       the row's entry in segmentationSchema/segmentationQuery.

    If ``pars['debug']`` is true only the first row is processed. A missing
    'debug' key is treated as false.
    """
    fhome = os.path.expanduser('~')
    with open(os.path.join(fhome,".labkey","setup.json")) as f:
        setup = json.load(f)
    sys.path.insert(0,setup["paths"]["nixWrapper"])

    import nixWrapper

    nixWrapper.loadLibrary("labkeyInterface")
    import labkeyInterface
    import labkeyDatabaseBrowser
    import labkeyFileBrowser
    nixWrapper.loadLibrary("parseConfig")
    import parseConfig
    fconfig = os.path.join(fhome,'.labkey','network.json')
    net = labkeyInterface.labkeyInterface()
    net.init(fconfig)
    db = labkeyDatabaseBrowser.labkeyDB(net)
    fb = labkeyFileBrowser.labkeyFileBrowser(net)
    with open(parameterFile) as f:
        pars = json.load(f)
    pars = parseConfig.convert(pars,setup)
    pars = parseConfig.convertValues(pars)
    print(pars)

    # render the DeepMedic config files from their templates
    cfg = pars['deepmedic']['config']
    for c in cfg:
        replacePatterns(cfg[c]['template'],cfg[c]['out'],pars['replacePattern'])
    project = pars['project']
    dataset = pars['targetQuery']
    schema = pars['targetSchema']
    os.makedirs(pars['tempBase'],exist_ok=True)

    ds = db.selectRows(project,schema,dataset,pars['entryFilter'])
    print('Got {} rows'.format(len(ds['rows'])))

    for row in ds["rows"]:
        # getSegmentationFile is only used here for its suffix
        tf = getSegmentationFile(pars)
        outpath = fb.buildPathURL(pars['project'],
                                  [pars['imageDir'],row['patientCode'],row['visitCode']])
        outName = addVersion(
            getSegmImagePath(getStudyLabel(row,pars['participantField']) + getSuffix(tf)),
            pars['version'])
        outFile = outpath + '/' + outName
        if not fb.entryExists(outFile):
            segFile = runSegmentation(fb,row,pars,setup)
            # upload to LabKey; targetQuery itself is not updated because it
            # holds previously set images
            fb.writeFileToFile(segFile,outFile)
        else:
            print('File {} available'.format(outFile))
        _registerSegmentation(db,row,pars,outName)
        # debug mode: stop after the first row
        if pars.get('debug',False):
            break
    print("Done")
if __name__ == '__main__':
    # usage: python <this script> PARAMETER_FILE
    if len(sys.argv) < 2:
        sys.exit('usage: {} PARAMETER_FILE'.format(sys.argv[0]))
    doSegmentation(sys.argv[1])
|