Browse Source

Merge branch 'master' of wiscigt.powertheword.com:oil/iraemm

NIX User 3 years ago
parent
commit
e18a635685
2 changed files with 312 additions and 52 deletions
  1. 231 23
      pythonScripts/runSegmentation.py
  2. 81 29
      slicerModules/iraemmBrowser.py

+ 231 - 23
pythonScripts/runSegmentation.py

@@ -5,6 +5,9 @@ import subprocess
 import nibabel
 import shutil
 import sys
+import pathlib
+import SimpleITK
+import numpy
 
 #nothing gets done if you do import
 
@@ -42,6 +45,186 @@ def valueSubstitution(pars,val):
 
     return path
 
+def getCroppedImagePath(tempFile,crop):
+    p=pathlib.Path(tempFile)
+    sfx=''.join(p.suffixes)
+    return re.sub(sfx,crop+sfx,str(p))
+
+def getSuffix(tempFile):
+    p=pathlib.Path(tempFile)
+    return ''.join(p.suffixes)
+
+def getSegmImagePath(tempFile):
+    sfx=getSuffix(tempFile)
+    return re.sub(sfx,'_Segm'+sfx,tempFile)
+
+def addVersion(tempFile,version):
+    sfx=getSuffix(tempFile)
+    return re.sub(sfx,'_'+version+sfx,tempFile)
+
+def normalizeCT(ctFile,maskFile):
+    
+    im=SimpleITK.ReadImage(ctFile)
+    mask=SimpleITK.ReadImage(maskFile)
+    nm=SimpleITK.GetArrayFromImage(im)
+    nmask=SimpleITK.GetArrayFromImage(mask)
+    mu=numpy.mean(nm[nmask>0])
+    st=numpy.std(nm[nmask>0])
+    nm[nmask>0]=(nm[nmask>0]-mu)/st
+    nm[nmask==0]=0
+    im1=SimpleITK.GetImageFromArray(nm)
+    im1.SetOrigin(im.GetOrigin())
+    im1.SetSpacing(im.GetSpacing())
+    im1.SetDirection(im.GetDirection())
+    SimpleITK.WriteImage(im1,ctFile)
+
+def cropImage(tempFile,crop, cropData):
+
+    im=SimpleITK.ReadImage(tempFile)
+    sz=im.GetSize()
+    ax=int(cropData['axis'])
+    rng=[float(v) for v in cropData['range']]
+    #update cropData['n']
+    if cropData['n']=="NONE":
+        cropData['n']=sz[ax]
+    if not sz[ax]==cropData['n']:
+        print('Size mismatch {}:{}'.format(sz[ax],cropData['n']))
+    n=sz[ax]
+    ii=[int(x*n) for x in rng]
+    slc=[slice(None) for v in sz]
+    slc[ax]=slice(ii[0],ii[1])
+    im1=im[slc]
+    #im1=im.take(indices=range(i1,i2),axis=cropData['axis'])
+    SimpleITK.WriteImage(im1,getCroppedImagePath(tempFile,crop))
+    print("Written {}".format(getCroppedImagePath(tempFile,crop)))
+
+
+
+def runDeepMedic(setup,pars):
+    args=[]
+    args.append(os.path.join(setup['paths']['deepMedicVE'],'bin','python'))
+    args.append(setup['paths']['deepMedicRun'])
+    args.append('-model')
+    args.append(pars['deepmedic']['config']['model']['out'])
+    args.append('-test')
+    args.append(pars['deepmedic']['config']['test']['out'])
+    args.append('-dev')
+    args.append('cpu')
+    print(args) 
+    print(subprocess.run(args,check=True,stdout=subprocess.PIPE).stdout)
+
+def getSegmentationFile(pars,crop):
+    #this is how deep medic stores files
+    return getSegmImagePath(\
+            getCroppedImagePath(\
+            os.path.join(pars['tempBase'],'output','predictions','currentSession','predictions',\
+                pars['images']['images']['segmentations']['tempFile']),crop)
+            )
+
+def getWeight(x,w):
+    for r in w:
+        fw=[float(v) for v in r['range']]
+        if x>fw[1]:
+            continue
+        if x<fw[0]:
+            continue
+        n=float(r['n'])
+        if not 'k' in r:
+            return n 
+
+        k=float(r['k'])
+        return k*x+n
+    return 0
+
+def runSegmentation(fb,row,pars,setup):
+    
+    if False:
+        images=pars['images']['images']
+        outImg=mergeSegmentations(pars)
+        segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
+        SimpleITK.WriteImage(outImg,segFile)
+        return segFile
+     
+    #download to temp file (could be a fixed name)
+    project=pars['project']
+    images=pars['images']['images']
+    participantField=pars['participantField']
+    baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
+        getPatientLabel(row,participantField)+'/'+\
+        getVisitLabel(row))
+    cropData=pars['images']['crop']
+    #reset n
+    for crop in cropData:
+        cropData[crop]['n']="NONE"
+    
+    #download 
+    for im in images:
+        tmpFile=images[im]['tempFile']
+        if 'queryField' in images[im]:
+            fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],tmpFile)
+    #normalize 
+    normalizeCT(images['CT']['tempFile'],images['patientmask']['tempFile'])
+
+    #crop and store file names
+    for im in images:
+        tmpFile=images[im]['tempFile']
+            
+        with open(images[im]['fileList'],'w') as f:
+            for crop in cropData:
+                print('n={}'.format(cropData[crop]['n']))
+                if os.path.isfile(tmpFile):
+                    cropImage(tmpFile,crop,cropData[crop])
+                print('n={}'.format(cropData[crop]['n']))
+                f.write(getCroppedImagePath(tmpFile,crop)+'\n')
+
+    #normalize crops
+    for crop in cropData:
+        normalizeCT(getCroppedImagePath(images['CT']['tempFile'],crop),
+                getCroppedImagePath(images['patientmask']['tempFile'],crop))
+    
+    #run deep medic
+    runDeepMedic(setup,pars)
+
+    #merge segmentations
+    outImg=mergeSegmentations(pars)
+    segFile=os.path.join(pars['tempBase'],images['segmentations']['tempFile'])
+    SimpleITK.WriteImage(outImg,segFile)
+    return segFile
+        #
+
+def mergeSegmentations(pars):
+    
+    cropData=pars['images']['crop']
+    start=True
+    for c in cropData:
+        segFile=getSegmentationFile(pars,c)
+        si=SimpleITK.ReadImage(segFile)
+        rng=[float(v) for v in cropData[c]['range']]
+        n=cropData[c]['n']
+        print(n)
+        img=SimpleITK.ConstantPad(si,[0,0,int(rng[0]*n)],[0,0,n-int(rng[1]*n)],-1)
+        print(img.GetSize())
+        ni=SimpleITK.GetArrayFromImage(img)
+        print(ni.shape)
+        w1=numpy.zeros(ni.shape)
+        aw=[getWeight((x+0.5)/n,cropData[c]['w']) for x in numpy.arange(n)]
+        for k in numpy.arange(len(aw)):
+            w1[k,:,:]=aw[k]
+        if start:
+            w0=w1
+            imgTmpl=img
+            nout=ni
+            start=False
+            continue
+        nout[w1>w0]=ni[w1>w0]
+        w0[w1>w0]=w1[w1>w0]
+    iout=SimpleITK.GetImageFromArray(nout)
+    iout.SetDirection(img.GetDirection())
+    iout.SetOrigin(img.GetOrigin())
+    iout.SetSpacing(img.GetSpacing())
+    return iout
+
+        
 def main(parameterFile):
     
     fhome=os.path.expanduser('~')
@@ -69,8 +252,16 @@ def main(parameterFile):
 
     pars=parseConfig.convert(pars)
     pars=parseConfig.convertValues(pars)
+    print(pars)
+    
+    #update the config
+    cfg=pars['deepmedic']['config']
+    for c in cfg:
+        replacePatterns(cfg[c]['template'],\
+                cfg[c]['out'],\
+                pars['replacePattern'])
+
 
-    hi=0
     project=pars['project']
     dataset=pars['targetQuery']
     schema=pars['targetSchema']
@@ -81,40 +272,57 @@ def main(parameterFile):
         os.makedirs(tempBase)
 
 
-    participantField=pars['participantField']
-
     #all images from database
     ds=db.selectRows(project,schema,dataset,[])
 
     
-    #imageSelector={"CT":"CT","PET":"PETWB_orthancId"}
     #input
-    images=pars['images']
     #use webdav to transfer file (even though it is localhost)
 
-    tempNames={im:os.path.join(tempBase,images[im]['tempFile']) for im in images}
  
-
-    #update the config
-    cfg=pars['deepmedic']['config']
-    for c in cfg:
-        replacePatterns(cfg[c]['template'],\
-                cfg[c]['out'],\
-                pars['replacePattern'])
     i=0
     for row in ds["rows"]:
+       
+
+        #check if file is already there
+        #dummy tf to get the suffix
+        tf=getSegmentationFile(pars,'XX')
+        outpath=fb.buildPathURL(pars['project'],[pars['imageDir'],row['patientCode'],row['visitCode']])
+        outName=addVersion(\
+                getSegmImagePath(\
+                    getStudyLabel(row,pars['participantField'])+getSuffix(tf)),\
+                pars['version'])
+
+        outFile=outpath+'/'+outName
+
+        #check if file is there
+        if not fb.entryExists(outFile):
+            segFile=runSegmentation(fb,row,pars,setup)
+            #copy file to file
+            #normally I would update the targetQuery, but it contains previously set images
+            #copy to labkey
+            fb.writeFileToFile(segFile,outFile)
+
         
-        #download to temp file (could be a fixed name)
-        baseDir=fb.formatPathURL(project,pars['imageDir']+'/'+\
-            getPatientLabel(row,participantField)+'/'+\
-            getVisitLabel(row))
-        for im in images:
-            fb.readFileToFile(baseDir+'/'+row[images[im]['queryField']],
-                os.path.join(tempBase,images[im]['tempFile']))
-            
-        break        
-        i=i+1
 
+        #separate script (set version!)
+
+        #update database
+        copyFields=[pars['participantField'],'SequenceNum','patientCode','visitCode']
+        filters=[{'variable':v,'value':str(row[v]),'oper':'eq'} for v in copyFields]
+        ds1=db.selectRows(pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],filters)
+        if len(ds1['rows'])>0:
+            mode='update'
+            outRow=ds1['rows'][0]
+        else:
+            mode='insert'
+            outRow={v:row[v] for v in copyFields}
+        outRow[pars['version']]= outName
+        db.modifyRows(mode,pars['project'],pars['segmentationSchema'],pars['segmentationQuery'],[outRow])
+        #pull results back to LabKey
+        i+=1
+        if i==1:
+            break
     print("Done")
 
 

+ 81 - 29
slicerModule/iraemmBrowser.py → slicerModules/iraemmBrowser.py

@@ -2,10 +2,9 @@ import os
 import unittest
 from __main__ import vtk, qt, ctk, slicer
 from slicer.ScriptedLoadableModule import *
-import slicerNetwork
-import loadDicom
 import json
 import datetime
+import sys
 
 #
 # labkeySlicerPythonExtension
@@ -42,12 +41,41 @@ class iraemmBrowserWidget(ScriptedLoadableModuleWidget):
     print("Setting up iraemmBrowserWidget")
     ScriptedLoadableModuleWidget.setup(self)
     # Instantiate and connect widgets ...
-    self.network=slicerNetwork.labkeyURIHandler()
+    
+    fhome=os.path.expanduser('~')
+    fsetup=os.path.join(fhome,'.labkey','setup.json')
+    try:
+        with open(fsetup) as f:
+            self.setup=json.load(f)
+    except FileNotFoundError:
+        self.setup={}
+
+    try:
+        pt=self.setup['paths']
+    except KeyError:
+        self.setup['paths']={}
+
+    try:    
+        sys.path.append(self.setup['paths']['labkeyInterface'])
+    except KeyError:
+        self.setup['paths']['labkeyInterface']=loadLibrary('labkeyInterface')
+        with open(fsetup,'w') as f:
+            json.dump(self.setup,f,indent='\t')
+    
+    import labkeyInterface
+    import labkeyDatabaseBrowser
+    import labkeyFileBrowser
+
+
+    self.network=labkeyInterface.labkeyInterface() 
 
     fconfig=os.path.join(os.path.expanduser('~'),'.labkey','network.json')
-    self.network.parseConfig(fconfig)
-    self.network.initRemote()
+    self.network.init(fconfig)
+
+    self.db=labkeyDatabaseBrowser.labkeyDB(self.network)
+    self.fb=labkeyFileBrowser.labkeyFileBrowser(self.network)
     self.project="iPNUMMretro/Study"
+    self.schema='study'
     self.dataset="Imaging1"
     self.reviewDataset="ImageReview"
     self.aeDataset="PET"
@@ -55,7 +83,7 @@ class iraemmBrowserWidget(ScriptedLoadableModuleWidget):
 
     
 
-    ds=self.network.filterDataset(self.project,self.dataset,[])
+    ds=self.db.selectRows(self.project,self.schema,self.dataset,[])
     ids=[row['PatientId'] for row in ds['rows']]
     ids=list(set(ids))
 
@@ -80,7 +108,10 @@ class iraemmBrowserWidget(ScriptedLoadableModuleWidget):
 
     self.segmentationField=qt.QLabel("Segmentation")
     setupFormLayout.addRow("Data field (Segmentation):",self.segmentationField)
-    
+   
+
+    self.idField=qt.QLabel(self.network.getUserId()['displayName'])
+    setupFormLayout.addRow("ID",self.idField)
     self.logic=iraemmBrowserLogic(self)
 
 
@@ -194,7 +225,7 @@ class iraemmBrowserWidget(ScriptedLoadableModuleWidget):
 
   def onPatientListChanged(self,i):
       idFilter={'variable':'PatientId','value':self.patientList.currentText,'oper':'eq'}
-      ds=self.network.filterDataset(self.project,self.dataset, [idFilter])
+      ds=self.db.selectRows(self.project,self.schema,self.dataset, [idFilter])
       seq=[int(row['SequenceNum']) for row in ds['rows']]
       self.visitList.clear()  
             
@@ -211,7 +242,7 @@ class iraemmBrowserWidget(ScriptedLoadableModuleWidget):
       idFilter={'variable':'PatientId',\
               'value':self.patientList.currentText,'oper':'eq'}
       sFilter={'variable':'SequenceNum','value':s,'oper':'eq'}
-      ds=self.network.filterDataset(self.project,self.dataset,[idFilter,sFilter])
+      ds=self.db.selectRows(self.project,self.schema,self.dataset,[idFilter,sFilter])
       if not len(ds['rows'])==1:
           print("Found incorrect number {} of matches for [{}]/[{}]".\
                   format(len(ds['rows']),\
@@ -374,7 +405,10 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
           #assume parent has the network set up
           self.parent=parent
           self.net=parent.network
+          self.db=parent.db
+          self.fb=parent.fb
           self.project=parent.project
+          self.schema=parent.schema
           self.participantField=parent.participantField.text
           self.segmentList=parent.segmentList
 
@@ -396,33 +430,44 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
 
   def loadImage(self,row,keepCached):
       
+      tempDir=os.path.join(os.path.expanduser('~'),'temp')
+      if not os.path.isdir(tempDir):
+          os.mkdir(tempDir)
       
       #fields={'ctResampled':True,'petResampled':False}
       fields={"CT":self.parent.ctField.text,\
               "PET":self.parent.petField.text,\
               "Segmentation":self.parent.segmentationField.text}
 
-      relativePaths={x:self.project+'/@files/preprocessedImages/'\
-             +row['patientCode']+'/'+row['visitCode']+'/'+row[y]\
+      relativePaths={x:['preprocessedImages',row['patientCode'],row['visitCode'],row[y]]\
              for (x,y) in fields.items()}
 
       self.volumeNode={}
       for f in relativePaths:
+
           p=relativePaths[f]
-          labkeyPath=self.net.GetLabkeyPathFromRelativePath(p)
-          rp=self.net.head(labkeyPath)
-          if not slicerNetwork.labkeyURIHandler.HTTPStatus(rp):
-              print("Failed to get {}".format(labkeyPath))
-              continue
+          
+          localPath=os.path.join(tempDir,p[-1])
 
-          #pushes it to background
+          if not os.path.isfile(localPath):
+              #download from server
+              remotePath=self.fb.formatPathURL(self.project,'/'.join(p))
+              if not self.fb.entryExists(remotePath):
+                  print("Failed to get {}".format(remotePath))
+                  continue
+
+              self.fb.readFileToFile(remotePath,localPath) 
+          
           properties={}
           #make sure segmentation gets loaded as a labelmap
           if f=="Segmentation":
               properties["labelmap"]=1
-
-          self.volumeNode[f]=self.net.loadNode(p,'VolumeFile',\
-                  properties=properties,returnNode=True,keepCached=keepCached)
+              
+          
+          self.volumeNode[f]=slicer.util.loadNodeFromFile(localPath,filetype='VolumeFile',properties=properties)
+          
+          if not keepCached:
+              os.remove(localPath)
 
       #mimic abdominalCT standardized window setting
       self.volumeNode['CT'].GetScalarVolumeDisplayNode().\
@@ -535,14 +580,15 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
           filters.append({'variable':f,'value':str(inputRow[f]),'oper':'eq'})
       
       
-      ds=self.net.filterDataset(project,dataset,filters)
+      ds=self.db.selectRows(project,self.schema,dataset,filters)
       return ds['rows']
 
   def loadReview(self,currentRow):
 
       #see if we have already done a review
+      currentRow['ModifiedBy']=self.net.getUserId()['id']
       rows=self.getUniqueRows(self.parent.project,self.parent.reviewDataset,\
-              [self.participantField,'visitCode','Segmentation'],currentRow)
+              [self.participantField,'visitCode','Segmentation','ModifiedBy'],currentRow)
 
       if len(rows)==0:
           return
@@ -559,7 +605,8 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
 
   #submit review to labkey
   def submitReview(self,currentRow,comment):
-      fields=[self.participantField,'visitCode','Segmentation']
+      currentRow['ModifiedBy']=self.net.getUserId()['id']
+      fields=[self.participantField,'visitCode','Segmentation','ModifiedBy']
       rows=self.getUniqueRows(self.parent.project,self.parent.reviewDataset,\
               fields,currentRow)
 
@@ -590,7 +637,7 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
 
       row['reviewComment']=comment
       row['Date']=datetime.datetime.now().ctime()
-      self.net.modifyDataset(mode,self.parent.project,\
+      self.db.modifyRows(mode,self.parent.project,self.parent.schema,\
               self.parent.reviewDataset,[row])
       
       print("review submitted")
@@ -617,7 +664,8 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
 
 
   def loadAE(self,currentRow):
-      fields=[self.participantField,'petResampled']
+      currentRow['ModifiedBy']=self.net.getUserId()['id']
+      fields=[self.participantField,'petResampled','ModifiedBy']
       rows=self.getUniqueRows(self.parent.project,self.parent.aeDataset,\
               fields,currentRow)
       
@@ -626,7 +674,7 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
 
       print("Found {} rows".format(len(rows)))
       row=rows[0]
-      
+     
       for seg in self.segmentList:
           name=seg+'AE'
           try:
@@ -637,7 +685,8 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
               continue
 
   def submitAE(self,currentRow):
-      fields=[self.participantField,'petResampled']
+      currentRow['ModifiedBy']=self.net.getUserId()['id']
+      fields=[self.participantField,'petResampled','ModifiedBy']
       rows=self.getUniqueRows(self.parent.project,self.parent.aeDataset,\
               fields,currentRow)
       
@@ -648,7 +697,10 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
           for f in fields:
               row[f]=currentRow[f]
           
-          row['SequenceNum']=currentRow['SequenceNum']
+          frows=self.getUniqueRows(self.parent.project,self.parent.aeDataset,\
+                  [self.participantField,'visitCode'],currentRow)
+
+          row['SequenceNum']=currentRow['SequenceNum']+0.01*len(frows)
       
       else:
           mode='update'
@@ -658,7 +710,7 @@ class iraemmBrowserLogic(ScriptedLoadableModuleLogic):
           row[seg+'AE']=self.getAEResult(seg) 
 
       row['Date']=datetime.datetime.now().ctime()
-      resp=self.net.modifyDataset(mode,self.parent.project,\
+      resp=self.db.modifyRows(mode,self.parent.project,self.parent.schema,\
               self.parent.aeDataset,[row])
       print("Response {}".format(resp))
       print("AE submitted")