Bladeren bron

Removing references to labkeySlicerWidget

Andrej Studen 6 jaren geleden
bovenliggende
commit
2c8c6dd7a6
2 gewijzigde bestanden met toevoegingen van 267 en 233 verwijderingen
  1. 47 21
      cardiacSPECT/cardiacSPECT.py
  2. 220 212
      cardiacSPECT/parseDicom.py

+ 47 - 21
cardiacSPECT/cardiacSPECT.py

@@ -4,11 +4,12 @@ import unittest
 import vtk, qt, ctk, slicer
 from slicer.ScriptedLoadableModule import *
 import logging
-import parseDicom as pd
+import parseDicom
 import vtkInterface as vi
 import fileIO
 import slicer
 import numpy as np
+import slicerNetwork
 #
 # cardiacSPECT
 #
@@ -46,8 +47,15 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
   def setup(self):
     ScriptedLoadableModuleWidget.setup(self)
 
+
     self.selectRemote=fileIO.remoteFileSelector()
-    self.network=slicer.modules.labkeySlicerPythonExtensionWidget.network
+    try:
+        self.network=slicer.modules.labkeySlicerPythonExtensionWidget.network
+    except:
+        self.network=slicerNetwork.labkeyURIHandler()
+
+    self.logic=cardiacSPECTLogic(self)
+    self.logic.setURIHandler(self.network)
     self.selectRemote.setMaster(self)
 
     # Instantiate and connect widgets ...
@@ -197,8 +205,6 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
     # Add vertical spacer
     self.layout.addStretch(1)
 
-    self.logic=cardiacSPECTLogic()
-
     self.resetPosition=1
 
   def cleanup(self):
@@ -260,7 +266,7 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
         self.meanROIResult.setText(str(s))
 
   def onDrawTimePlotClicked(self):
-        n=self.time_frame_select.maximum
+        n=self.time_frame_select.maximum+1
         ft=self.logic.frame_time
 
         #find number of segments
@@ -270,9 +276,9 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
         cn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartNode())
 
         for j in range(0,ns):
-            segment="Segment_"+str(j+1)
             #add node for data
             dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
+            dn.SetName(self.logic.getSegmentName(j))
             a = dn.GetArray()
             a.SetNumberOfTuples(n)
             dt=0;
@@ -289,7 +295,9 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
                 a.SetComponent(i, 2, 0)
                 print("[{0} at {1:.2f}:{2:.2f}]".format(vol,fx,fy))
 
-            cn.AddArray(segment, dn.GetID())
+
+            #fish the number of the segment
+            cn.AddArray(self.logic.getSegmentName(j), dn.GetID())
 
         cn.SetProperty('default', 'title', 'ROI time plot')
         cn.SetProperty('default', 'xAxisLabel', 'time [ms]')
@@ -297,8 +305,10 @@ class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
 
         #update the chart node
         cvns = slicer.mrmlScene.GetNodesByClass('vtkMRMLChartViewNode')
-        cvns.InitTraversal()
-        cvn = cvns.GetNextItemAsObject()
+        if cvns.GetNumberOfItems() == 0:
+            cvn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartViewNode())
+        else:
+            cvn = cvns.GetItemAsObject(0)
         cvn.SetChartNodeID(cn.GetID())
 
   def onCountSegmentsClicked(self):
@@ -324,15 +334,22 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
   Uses ScriptedLoadableModuleLogic base class, available at:
   https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
   """
+  def __init__(self,parent):
+      ScriptedLoadableModuleLogic.__init__(self, parent)
+      self.pd=parseDicom.parseDicomLogic(self)
+
+  def setURIHandler(self,net):
+      self.net=net
+      self.pd.setURIHandler(net)
 
   def loadData(self,widget):
 
     inputDir=str(widget.dataPath.text)
     self.frame_data, self.frame_time, self.frame_origin, \
-        self.frame_pixel_size, self.frame_orientation=pd.read_dynamic_SPECT(inputDir)
+        self.frame_pixel_size, self.frame_orientation=self.pd.read_dynamic_SPECT(inputDir)
 
     self.ct_data,self.ct_origin,self.ct_pixel_size, \
-        self.ct_orientation=pd.read_CT(inputDir)
+        self.ct_orientation=self.pd.read_CT(inputDir)
 
     self.ct_orientation=vi.completeOrientation(self.ct_orientation)
     self.frame_orientation=vi.completeOrientation(self.frame_orientation)
@@ -420,7 +437,7 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
     #    return 0
 
     #edit here to change for more segments
-    segment=segNode.GetSegmentation().GetNthSegmentID(i)
+    segment=segNode.GetSegmentation().GetNthSegmentID(int(i))
     mask = segNode.GetBinaryLabelmapRepresentation(segment)
     if mask==None:
         print("Segment {} not found".format(segment))
@@ -474,16 +491,25 @@ class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
     return s/n
 
   def countSegments(self):
-    fNode=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode").GetItemAsObject(0)
+    segNodeList=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
+    if segNodeList.GetNumberOfItems()==0:
+        return 0
+    fNode=segNodeList.GetItemAsObject(0)
     segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
-    i=1
-    while 1:
-        segName="Segment_"+str(i)
-        mask = segNode.GetBinaryLabelmapRepresentation(segName)
-        if mask==None:
-           break
-        i+=1
-    return i-1
+    if fNode==None:
+        return 0
+    return segNode.GetSegmentation().GetNumberOfSegments()
+
+  def getSegmentName(self,i):
+      segNodeList=slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
+      if segNodeList.GetNumberOfItems()==0:
+          return "NONE"
+      fNode=segNodeList.GetItemAsObject(0)
+      segNode=slicer.vtkMRMLSegmentationNode.SafeDownCast(fNode)
+      if fNode==None:
+          return "NONE"
+      return segNode.GetSegmentation().GetSegment(segNode.GetSegmentation().GetNthSegmentID(i)).GetName()
+
 
 class cardiacSPECTTest(ScriptedLoadableModuleTest):
   """

+ 220 - 212
cardiacSPECT/parseDicom.py

@@ -4,6 +4,7 @@ import dicom
 import numpy as np
 import re
 import slicer
+from slicer.ScriptedLoadableModule import *
 
 #rom os import listdir
 #from os.path import isfile, join
@@ -14,9 +15,10 @@ import slicer
 #root = tk.Tk()
 #root.withdraw()
 #file_path = filedialog.askopenfilename()
-class parseDicom:
+class parseDicom(ScriptedLoadableModule):
   def __init__(self, parent):
-    parent.title = "parse dicom"
+    ScriptedLoadableModule.__init__(self, parent)
+    parent.title = "parseDicom"
     parent.categories = ["Examples"]
     parent.dependencies = []
     parent.contributors = ["Andrej Studen (FMF/JSI)"] # replace with "Firstname Lastname (Org)"
@@ -30,279 +32,285 @@ class parseDicom:
     """ # replace with organization, grant and thanks.
     self.parent = parent
 
-def filelist(mypath):
-#mypath=os.environ['PWD']
-#list files
-    if mypath.find('labkey://')==0:
-        print("Using labkey")
-        labkeyPath=re.sub('labkey://','',mypath)
-        #not sure if labkey is available, so try it
-        net=slicer.modules.labkeySlicerPythonExtensionWidget.network
-        print("Found network")
-        #url=slicer.modules.labkeySlicerPythonExtensionWidget.serverURL.text
-        #print("Seting url={}".format(url))
-        ok, files=net.listRemoteDir(labkeyPath)
-        if not ok:
-            print "Error accessing path"
-            return []
-
-    if mypath.find('file://')==0:
-        print("Using local files")
-        localPath=re.sub('file://','',mypath)
-        files = [os.path.join(localPath,f) for f in os.listdir(localPath)
-            if os.path.isfile(os.path.join(localPath, f))]
-
-    return files
-
-def getfile(origin,f):
-
-    if origin.find('labkey')==0:
-        try:
-            #not sure if labkey is available, but try it
-            net=slicer.modules.labkeySlicerPythonExtensionWidget.network
-            print("Using labkey")
-            url=slicer.modules.labkeySlicerPythonExtensionWidget.serverURL.text
-            print("Sever:{0}, file:{1}".format(url,f))
-            return [net.readFile(str(url),f),1]
-        except:
-            print('Could not access labkey. Exiting')
-            return ['NULL',0]
-
-    if origin.find('file')==0:
-        print("Using local directory")
-        return [f,1]
+class parseDicomWidget(ScriptedLoadableModuleWidget):
+    def setup(self):
+        ScriptedLoadableModuleWidget.setup(self)
+        self.logic=parseDicomLogic(self)
 
-    return ['NULL',0]
+class parseDicomLogic(ScriptedLoadableModuleLogic):
 
-def read_dynamic_SPECT(mypath):
-    axisShift=(2,1,0)
+    def __init__(self,parent):
+        ScriptedLoadableModuleLogic.__init__(self, parent)
 
-    origin=re.sub('([^:/])://(.*)$',r'\1',mypath)
-    onlyfiles=filelist(mypath)
-    for f in onlyfiles:
-        print '{}:'.format(f)
+    def setURIHandler(self,net):
+        self.net=net
 
-        g,ok=getfile(origin,f)
-        if not(ok):
+    def filelist(self,mypath):
+        if mypath.find('labkey://')==0:
+            print("Using labkey")
+            labkeyPath=re.sub('labkey://','',mypath)
+            #not sure if labkey is available, so try it
+            #url=slicer.modules.labkeySlicerPythonExtensionWidget.serverURL.text
+            #print("Seting url={}".format(url))
+            ok, files=self.net.listRemoteDir(labkeyPath)
+            if not ok:
+                print "Error accessing path"
+                return []
+
+            if mypath.find('file://')==0:
+                print("Using local files")
+                localPath=re.sub('file://','',mypath)
+                files = [os.path.join(localPath,f) for f in os.listdir(localPath)
+                    if os.path.isfile(os.path.join(localPath, f))]
+            return files
+
+    def getfile(self,origin,f):
+
+        if origin.find('labkey')==0:
+            try:
+                #not sure if labkey is available, but try it
+                print("Using labkey")
+                url=self.net.GetHostName()
+                print("Sever:{0}, file:{1}".format(url,f))
+                return [self.net.readFile(str(url),f),1]
+            except:
+                print('Could not access labkey. Exiting')
+                return ['NULL',0]
+
+        if origin.find('file')==0:
+            print("Using local directory")
+            return [f,1]
+
+        return ['NULL',0]
+
+    def read_dynamic_SPECT(self,mypath):
+        axisShift=(2,1,0)
+
+        origin=re.sub('([^:/])://(.*)$',r'\1',mypath)
+        onlyfiles=self.filelist(mypath)
+        for f in onlyfiles:
+            print '{}:'.format(f)
+
+            g,ok=self.getfile(origin,f)
+            if not(ok):
                 return
 
-        try:
-            plan = dicom.read_file(g)
-        except:
-            print ("Not a dicom file")
-            continue
-        try:
-            nframe=plan[0x0019,0x10a5].value;
-        except:
-            print ("Tag not found;")
-            continue
-        if not (type(nframe) is list) :
-            print("nframe not a list")
-            continue
+            try:
+                plan = dicom.read_file(g)
+            except:
+                print ("Not a dicom file")
+                continue
+            try:
+                nframe=plan[0x0019,0x10a5].value;
+            except:
+                print ("Tag not found;")
+                continue
+            if not (type(nframe) is list) :
+                print("nframe not a list")
+                continue
 
     #this is the "master" file where data on other files can be had
     #here we found out the duration of the frame and their distribution through
     #phases and cycles
-        print('Found master file')
+            print('Found master file')
 
-        for i in range(1,len(nframe)):
-            nframe[i]+=nframe[i-1]
+            for i in range(1,len(nframe)):
+                nframe[i]+=nframe[i-1]
 
-        print(nframe)
+            print(nframe)
 
     #nframe now holds for index i total number of frames collected up
     #to the end of each phase
 
-        frame_start=plan[0x0019,0x10a7].value
-        frame_stop=plan[0x0019,0x10a8].value
-        frame_duration=plan[0x0019,0x10a9].value
-        break
+            frame_start=plan[0x0019,0x10a7].value
+            frame_stop=plan[0x0019,0x10a8].value
+            frame_duration=plan[0x0019,0x10a9].value
+            break
     #print "rep [{}] start [{}] stop [{}] duration [{}]".format(
     #len(rep),len(rep_start),len(rep_stop),len(rep_duration))
 
 #select AC reconstructed data
-    frame_time=np.zeros(nframe[-1]);
-    frame_data=np.empty([1,1,1,nframe[-1]])
-    center = [0,0,0]
-    pixel_size =[0,0,0]
-    frame_orientation=[0,0,0,0,0,0]
-    for f in onlyfiles:
-
-        g,ok=getfile(origin,f)
-        if not(ok):
+        frame_time=np.zeros(nframe[-1]);
+        frame_data=np.empty([1,1,1,nframe[-1]])
+        center = [0,0,0]
+        pixel_size =[0,0,0]
+        frame_orientation=[0,0,0,0,0,0]
+        for f in onlyfiles:
+
+            g,ok=self.getfile(origin,f)
+            if not(ok):
+                continue
+
+            try:
+                plan = dicom.read_file(g)
+            except:
+                print ("Not a dicom file")
+                continue
+
+            try:
+                pf=plan[0x0018,0x5020]
+            except:
+                print ("ProcessingFunction not found")
                 continue
 
-        try:
-            plan = dicom.read_file(g)
-        except:
-            print ("Not a dicom file")
-            continue
-
-        try:
-            pf=plan[0x0018,0x5020]
-        except:
-            print ("ProcessingFunction not found")
-            continue
-
-        try:
-            phase=plan[0x0035,0x1005].value
-            cycle=plan[0x0035,0x1004].value
-        except:
-            print ("Phase/Cycle tag not found")
-            continue
-
-        #convert phase/cycle to frame index
-        off=0
-        if phase > 1:
-            off=nframe[phase-2]
-        ifi=off+cycle-1
+            try:
+                phase=plan[0x0035,0x1005].value
+                cycle=plan[0x0035,0x1004].value
+            except:
+                print ("Phase/Cycle tag not found")
+                continue
+
+            #convert phase/cycle to frame index
+            off=0
+            if phase > 1:
+                off=nframe[phase-2]
+            ifi=off+cycle-1
 
     #from values in the master file determine frame time
     #(as the mid point between starting and ending the frame)
-        frame_time[ifi]=0.5*(frame_start[ifi]+frame_stop[ifi]); #in ms
+            frame_time[ifi]=0.5*(frame_start[ifi]+frame_stop[ifi]); #in ms
 
-        print "({},{}) converted to {} at {} for {}".format(
-            phase,cycle,ifi,frame_time[ifi],frame_duration[ifi])
+            print "({},{}) converted to {} at {} for {}".format(
+                phase,cycle,ifi,frame_time[ifi],frame_duration[ifi])
 
     #play with pixel data
-        if frame_data.shape[0] == 1:
-            sh=np.transpose(plan.pixel_array,axisShift).shape;
-            sh=list(sh)
-            sh.append(nframe[-1])
-            frame_data=np.empty(sh)
-            print "Setting frame_data to",sh
+            if frame_data.shape[0] == 1:
+                sh=np.transpose(plan.pixel_array,axisShift).shape;
+                sh=list(sh)
+                sh.append(nframe[-1])
+                frame_data=np.empty(sh)
+                print "Setting frame_data to",sh
 
             #check & update pixel size
-        pixel_size_read=[plan.PixelSpacing[0],plan.PixelSpacing[1],
+            pixel_size_read=[plan.PixelSpacing[0],plan.PixelSpacing[1],
                         plan.SliceThickness]
 
-        for i in range(0,3):
-            if pixel_size[i] == 0:
-                pixel_size[i] = float(pixel_size_read[i])
-            if abs(pixel_size[i]-pixel_size_read[i]) > 1e-3:
-                print 'Pixel size mismatch {.2f}/{.2f}'.format(pixel_size[i],
-                pixel_size_read[i])
-
-        center_read=plan.DetectorInformationSequence[0].ImagePositionPatient
-        print "Stored center at ({0},{1},{2})".format(center[0],center[1],center[2])
-        print "Read   center at ({0},{1},{2})".format(center_read[0],center_read[1],center_read[2])
-        for i in range(0,3):
-            if center[i] == 0:
-                center[i] = float(center_read[i])
-            if abs(center[i]-center_read[i]) > 1e-3:
-                        print 'Image center mismatch {.2f}/{.2f}'.format(center[i],
+            for i in range(0,3):
+                if pixel_size[i] == 0:
+                    pixel_size[i] = float(pixel_size_read[i])
+                if abs(pixel_size[i]-pixel_size_read[i]) > 1e-3:
+                    print 'Pixel size mismatch {.2f}/{.2f}'.format(pixel_size[i],
+                    pixel_size_read[i])
+
+            center_read=plan.DetectorInformationSequence[0].ImagePositionPatient
+            print "Stored center at ({0},{1},{2})".format(center[0],center[1],center[2])
+            print "Read   center at ({0},{1},{2})".format(center_read[0],center_read[1],center_read[2])
+            for i in range(0,3):
+                if center[i] == 0:
+                    center[i] = float(center_read[i])
+                if abs(center[i]-center_read[i]) > 1e-3:
+                    print 'Image center mismatch {.2f}/{.2f}'.format(center[i],
                         center_read[i])
 
-        frame_orientation_read=plan.DetectorInformationSequence[0].ImageOrientationPatient
-        for i in range(0,6):
-            if frame_orientation[i] == 0:
-                frame_orientation[i] = float(frame_orientation_read[i])
-            if abs(frame_orientation[i]-frame_orientation_read[i]) > 1e-3:
-                        print 'Image orientation mismatch {.2f}/{.2f}'.format(
+            frame_orientation_read=plan.DetectorInformationSequence[0].ImageOrientationPatient
+            for i in range(0,6):
+                if frame_orientation[i] == 0:
+                    frame_orientation[i] = float(frame_orientation_read[i])
+                if abs(frame_orientation[i]-frame_orientation_read[i]) > 1e-3:
+                    print 'Image orientation mismatch {.2f}/{.2f}'.format(
                         frame_rotation[i], frame_orientation_read[i])
 
 
 
 
-        frame_data[:,:,:,ifi]=np.transpose(plan.pixel_array,axisShift)
+            frame_data[:,:,:,ifi]=np.transpose(plan.pixel_array,axisShift)
 
     #print('Orientation: ({0:.2f},{1:.2f},{2:.2f}),({3:.2f},{4:.2f},{5:.2f})').format( \
     #    frame_orientation[0],frame_orientation[1],frame_orientation[2], \
     #    frame_orientation[3],frame_orientation[4],frame_orientation[5])
 
-    return [frame_data,frame_time,center,pixel_size,frame_orientation]
+        return [frame_data,frame_time,center,pixel_size,frame_orientation]
 
-def read_CT(mypath):
-    onlyfiles=filelist(mypath)
-    origin=re.sub('([^:/])://(.*)$',r'\1',mypath)
+    def read_CT(self,mypath):
+        onlyfiles=self.filelist(mypath)
+        origin=re.sub('([^:/])://(.*)$',r'\1',mypath)
 
-    ct_data = []
-    ct_idx = []
-    ct_z = []
-    ct_pixel_size = [0,0,0]
-    ct_center = [0,0,0]
-    ct_center[2]=1e30
-    ct_orientation=[0,0,0,0,0,0]
-    for f in onlyfiles:
-        print '{}:'.format(f)
+        ct_data = []
+        ct_idx = []
+        ct_z = []
+        ct_pixel_size = [0,0,0]
+        ct_center = [0,0,0]
+        ct_center[2]=1e30
+        ct_orientation=[0,0,0,0,0,0]
+        for f in onlyfiles:
+            print '{}:'.format(f)
 
-        g,ok=getfile(origin,f)
-        if not(ok):
+            g,ok=self.getfile(origin,f)
+            if not(ok):
                 return
 
-        try:
-            plan = dicom.read_file(g)
-        except:
-            print ("Not a dicom file")
-            continue
+            try:
+                plan = dicom.read_file(g)
+            except:
+                print ("Not a dicom file")
+                continue
 
-        if plan.Modality != 'CT':
-            print ('Not a CT file')
-            continue
+            if plan.Modality != 'CT':
+                print ('Not a CT file')
+                continue
 
         #this doesn't work in 2019 data version
         #if re.match("AC",plan.SeriesDescription) == None:
         #    print (plan.SeriesDescription)
         #    print ('Not a AC file')
         #    continue
-        try:
-            iType=plan.ImageType
-        except:
-            print "Image type not found"
-            continue;
-
-        if iType[3].find("SPI")<0:
-            print "Not a spiral image"
-            continue;
-
-
-        #a slice of pure CT
-        print '.',
-        ct_data.append(plan.pixel_array)
-        ct_idx.append(plan.InstanceNumber)
-        ct_z.append(plan.ImagePositionPatient[2])
-        #ct_center.append(plan.ImagePositionPatient)
-
-        pixel_size_read=[plan.PixelSpacing[0],plan.PixelSpacing[1],
-            plan.SliceThickness]
-
-
-        for i in range(0,3):
-            if ct_pixel_size[i] == 0:
-                ct_pixel_size[i] = float(pixel_size_read[i])
-            if abs(ct_pixel_size[i]-pixel_size_read[i]) > 1e-3:
-                print 'Pixel size mismatch {.2f}/{.2f}'.format(ct_pixel_size[i],
-                pixel_size_read[i])
-
-        for i in range(0,2):
-            if ct_center[i] == 0:
-                ct_center[i] = float(plan.ImagePositionPatient[i])
-            if abs(ct_center[i]-plan.ImagePositionPatient[i]) > 1e-3:
+            try:
+                iType=plan.ImageType
+            except:
+                print "Image type not found"
+                continue;
+
+            if iType[3].find("SPI")<0:
+                print "Not a spiral image"
+                continue;
+
+
+    #a slice of pure CT
+            print '.',
+            ct_data.append(plan.pixel_array)
+            ct_idx.append(plan.InstanceNumber)
+            ct_z.append(plan.ImagePositionPatient[2])
+
+            pixel_size_read=[plan.PixelSpacing[0],plan.PixelSpacing[1],
+                plan.SliceThickness]
+
+
+            for i in range(0,3):
+                if ct_pixel_size[i] == 0:
+                    ct_pixel_size[i] = float(pixel_size_read[i])
+                if abs(ct_pixel_size[i]-pixel_size_read[i]) > 1e-3:
+                    print 'Pixel size mismatch {.2f}/{.2f}'.format(ct_pixel_size[i],
+                        pixel_size_read[i])
+
+            for i in range(0,2):
+                if ct_center[i] == 0:
+                    ct_center[i] = float(plan.ImagePositionPatient[i])
+                if abs(ct_center[i]-plan.ImagePositionPatient[i]) > 1e-3:
                         print 'Image center mismatch {.2f}/{.2f}'.format(ct_center[i],
                         plan.ImagePositionPatient[i])
         #not average, but minimum (!) why??
 
-        if plan.ImagePositionPatient[2]<ct_center[2]:
-            ct_center[2]=plan.ImagePositionPatient[2]
+            if plan.ImagePositionPatient[2]<ct_center[2]:
+                ct_center[2]=plan.ImagePositionPatient[2]
 
-        for i in range(0,6):
-            if ct_orientation[i] == 0:
-                ct_orientation[i] = float(plan.ImageOrientationPatient[i])
-            if abs(ct_orientation[i]-plan.ImageOrientationPatient[i]) > 1e-3:
-                print 'Image orientation mismatch {0:.2f}/{1:.2f}'.format(ct_orientation[i],
-                plan.ImageOrientationPatient[i])
+            for i in range(0,6):
+                if ct_orientation[i] == 0:
+                    ct_orientation[i] = float(plan.ImageOrientationPatient[i])
+                if abs(ct_orientation[i]-plan.ImageOrientationPatient[i]) > 1e-3:
+                    print 'Image orientation mismatch {0:.2f}/{1:.2f}'.format(ct_orientation[i],
+                    plan.ImageOrientationPatient[i])
 
-    print
-    nz=len(ct_idx)
+        print
+        nz=len(ct_idx)
     #not average, again
     #ct_center[2]/=nz
-    sh=ct_data[-1].shape
-    sh_list=list(sh)
-    sh_list.append(nz)
-    data_array=np.zeros(sh_list)
+        sh=ct_data[-1].shape
+        sh_list=list(sh)
+        sh_list.append(nz)
+        data_array=np.zeros(sh_list)
 
-    for k in range(0,nz):
-        kp=int(np.round((ct_z[k]-ct_center[2])/ct_pixel_size[2]))
-        data_array[:,:,kp]=np.transpose(ct_data[k])
+        for k in range(0,nz):
+            kp=int(np.round((ct_z[k]-ct_center[2])/ct_pixel_size[2]))
+            data_array[:,:,kp]=np.transpose(ct_data[k])
 
-    return data_array,ct_center,ct_pixel_size,ct_orientation
+        return data_array,ct_center,ct_pixel_size,ct_orientation