Browse Source

Recent updates

Andrej Studen/Merlin 3 years ago
parent
commit
d449f0095c
1 changed files with 231 additions and 23 deletions
  1. 231 23
      pythonScripts/runSegmentation.py

+ 231 - 23
pythonScripts/runSegmentation.py

@@ -5,6 +5,9 @@ import subprocess
 import nibabel
 import nibabel
 import shutil
 import shutil
 import sys
 import sys
+import pathlib
+import SimpleITK
+import numpy
 
 
 #nothing gets done if you do import
 #nothing gets done if you do import
 
 
@@ -42,6 +45,186 @@ def valueSubstitution(pars,val):
 
 
     return path
     return path
 
 
+def getCroppedImagePath(tempFile,crop):
+    p=pathlib.Path(tempFile)
+    sfx=''.join(p.suffixes)
+    return re.sub(sfx,crop+sfx,str(p))
+
+def getSuffix(tempFile):
+    p=pathlib.Path(tempFile)
+    return ''.join(p.suffixes)
+
+def getSegmImagePath(tempFile):
+    sfx=getSuffix(tempFile)
+    return re.sub(sfx,'_Segm'+sfx,tempFile)
+
+def addVersion(tempFile,version):
+    sfx=getSuffix(tempFile)
+    return re.sub(sfx,'_'+version+sfx,tempFile)
+
+def normalizeCT(ctFile,maskFile):
+    
+    im=SimpleITK.ReadImage(ctFile)
+    mask=SimpleITK.ReadImage(maskFile)
+    nm=SimpleITK.GetArrayFromImage(im)
+    nmask=SimpleITK.GetArrayFromImage(mask)
+    mu=numpy.mean(nm[nmask>0])
+    st=numpy.std(nm[nmask>0])
+    nm[nmask>0]=(nm[nmask>0]-mu)/st
+    nm[nmask==0]=0
+    im1=SimpleITK.GetImageFromArray(nm)
+    im1.SetOrigin(im.GetOrigin())
+    im1.SetSpacing(im.GetSpacing())
+    im1.SetDirection(im.GetDirection())
+    SimpleITK.WriteImage(im1,ctFile)
+
+def cropImage(tempFile,crop, cropData):
+
+    im=SimpleITK.ReadImage(tempFile)
+    sz=im.GetSize()
+    ax=int(cropData['axis'])
+    rng=[float(v) for v in cropData['range']]
+    #update cropData['n']
+    if cropData['n']=="NONE":
+        cropData['n']=sz[ax]
+    if not sz[ax]==cropData['n']:
+        print('Size mismatch {}:{}'.format(sz[ax],cropData['n']))
+    n=sz[ax]
+    ii=[int(x*n) for x in rng]
+    slc=[slice(None) for v in sz]
+    slc[ax]=slice(ii[0],ii[1])
+    im1=im[slc]
+    #im1=im.take(indices=range(i1,i2),axis=cropData['axis'])
+    SimpleITK.WriteImage(im1,getCroppedImagePath(tempFile,crop))
+    print("Written {}".format(getCroppedImagePath(tempFile,crop)))
+
+
+
+def runDeepMedic(setup,pars):
+    args=[]
+    args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
+    args.append(setup['paths']['deepMedicRun'])
+    args.append('-model')
+    args.append(pars['deepmedic']['config']['model']['out'])
+    args.append('-test')
+    args.append(pars['deepmedic']['config']['test']['out'])
+    args.append('-dev')
+    args.append('cpu')
+    print(args) 
+    print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
+
+def getSegmentationFile(pars,crop):
+    #this is how deep medic stores files
+    return getSegmImagePath(\
+            getCroppedImagePath(\
+            os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
+                pars['images']['images']['segmentations']['tempFile']),crop)
+            )
+
+def getWeight(x,w):
+    for r in w:
+        fw=[float(v) for v in r['range']]
+        if x>fw[1]:
+            continue
+        if x<fw[0]:
+            continue
+        n=float(r['n'])
+        if not 'k' in r:
+            return n 
+
+        k=float(r['k'])
+        return k*x+n
+    return 0
+
+def runSegmentation(fb,row,pars,setup):
+    
+    if False:
+        images=pars['images']['images']
+        outImg=mergeSegmentations(pars)
+        segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
+        SimpleITK.WriteImage(outImg,segFile)
+        return segFile
+     
+    #download to temp file (could be a fixed name)
+    project=pars['project']
+    images=pars['images']['images']
+    participantField=pars['participantField']
+    baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
+        getPatientLabel(row,participantField)+'/'+\
+        getVisitLabel(row))
+    cropData=pars['images']['crop']
+    #reset n
+    for crop in cropData:
+        cropData[crop]['n']="NONE"
+    
+    #download 
+    for im in images:
+        tmpFile=images[im]['tempFile']
+        if 'queryField' in images[im]:
+            fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],tmpFile)
+    #normalize 
+    normalizeCT(images['CT']['tempFile'],images['patientmask']['tempFile'])
+
+    #crop and store file names
+    for im in images:
+        tmpFile=images[im]['tempFile']
+            
+        with open(images[im]['fileList'],'w') as f:
+            for crop in cropData:
+                print('n={}'.format(cropData[crop]['n']))
+                if os.path.isfile(tmpFile):
+                    cropImage(tmpFile,crop,cropData[crop])
+                print('n={}'.format(cropData[crop]['n']))
+                f.write(getCroppedImagePath(tmpFile,crop)+'\n')
+
+    #normalize crops
+    for crop in cropData:
+        normalizeCT(getCroppedImagePath(images['CT']['tempFile'],crop),
+                getCroppedImagePath(images['patientmask']['tempFile'],crop))
+    
+    #run deep medic
+    runDeepMedic(setup,pars)
+
+    #merge segmentations
+    outImg=mergeSegmentations(pars)
+    segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
+    SimpleITK.WriteImage(outImg,segFile)
+    return segFile
+        #
+
+def mergeSegmentations(pars):
+    
+    cropData=pars['images']['crop']
+    start=True
+    for c in cropData:
+        segFile=getSegmentationFile(pars,c)
+        si=SimpleITK.ReadImage(segFile)
+        rng=[float(v) for v in cropData[c]['range']]
+        n=cropData[c]['n']
+        print(n)
+        img=SimpleITK.ConstantPad(si,[0,0,int(rng[0]*n)],[0,0,n-int(rng[1]*n)],-1)
+        print(img.GetSize())
+        ni=SimpleITK.GetArrayFromImage(img)
+        print(ni.shape)
+        w1=numpy.zeros(ni.shape)
+        aw=[getWeight((x+0.5)/n,cropData[c]['w']) for x in numpy.arange(n)]
+        for k in numpy.arange(len(aw)):
+            w1[k,:,:]=aw[k]
+        if start:
+            w0=w1
+            imgTmpl=img
+            nout=ni
+            start=False
+            continue
+        nout[w1>w0]=ni[w1>w0]
+        w0[w1>w0]=w1[w1>w0]
+    iout=SimpleITK.GetImageFromArray(nout)
+    iout.SetDirection(img.GetDirection())
+    iout.SetOrigin(img.GetOrigin())
+    iout.SetSpacing(img.GetSpacing())
+    return iout
+
+        
 def main(parameterFile):
 def main(parameterFile):
     
     
     fhome=os.path.expanduser('~')
     fhome=os.path.expanduser('~')
@@ -69,8 +252,16 @@ def main(parameterFile):
 
 
     pars=parseConfig.convert(pars)
     pars=parseConfig.convert(pars)
     pars=parseConfig.convertValues(pars)
     pars=parseConfig.convertValues(pars)
+    print(pars)
+    
+    #update the config
+    cfg=pars['deepmedic']['config']
+    for c in cfg:
+        replacePatterns(cfg[c]['template'],\
+                cfg[c]['out'],\
+                pars['replacePattern'])
+
 
 
-    hi=0
     project=pars['project']
     project=pars['project']
     dataset=pars['targetQuery']
     dataset=pars['targetQuery']
     schema=pars['targetSchema']
     schema=pars['targetSchema']
@@ -81,40 +272,57 @@ def main(parameterFile):
         os.makedirs(tempBase)
         os.makedirs(tempBase)
 
 
 
 
-    participantField=pars['participantField']
-
     #all images from database
     #all images from database
     ds=db.selectRows(project,schema,dataset,[])
     ds=db.selectRows(project,schema,dataset,[])
 
 
     
     
-    #imageSelector={"CT":"CT","PET":"PETWB_orthancId"}
     #input
     #input
-    images=pars['images']
     #use webdav to transfer file (even though it is localhost)
     #use webdav to transfer file (even though it is localhost)
 
 
-    tempNames={im:os.path.join(tempBase,images[im]['tempFile']) for im in images}
  
  
-
-    #update the config
-    cfg=pars['deepmedic']['config']
-    for c in cfg:
-        replacePatterns(cfg[c]['template'],\
-                cfg[c]['out'],\
-                pars['replacePattern'])
     i=0
     i=0
     for row in ds["rows"]:
     for row in ds["rows"]:
+       
+
+        #check if file is already there
+        #dummy tf to get the suffix
+        tf=getSegmentationFile(pars,'XX')
+        outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
+        outName=addVersion(\
+                getSegmImagePath(\
+                    getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
+                pars['version'])
+
+        outFile=outpath+'/'+outName
+
+        #check if file is there
+        if not fb.entryExists(outFile):
+            segFile=runSegmentation(fb,row,pars,setup)
+            #copy file to file
+            #normally I would update the targetQuery, but it contains previously set images
+            #copy to labkey
+            fb.writeFileToFile(segFile,outFile)
+
         
         
-        #download to temp file (could be a fixed name)
-        baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
-            getPatientLabel(row,participantField)+'/'+\
-            getVisitLabel(row))
-        for im in images:
-            fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],
-                os.path.join(tempBase,images[im]['tempFile']))
-            
-        break        
-        i=i+1
 
 
+        #separate script (set version!)
+
+        #update database
+        copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
+        filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
+        ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
+        if len(ds1['rows'])>0:
+            mode='update'
+            outRow=ds1['rows'][0]
+        else:
+            mode='insert'
+            outRow={v:row[v] for v in copyFields}
+        outRow[pars['version']]= outName
+        db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow])
+        #pull results back to LabKey
+        i+=1
+        if i==1:
+            break
     print("Done")
     print("Done")