import os
import pathlib
import re
import sys
from collections.abc import Sequence

import numpy as np
import pydicom

class parseDicom:
    """Read NM (phase/cycle organised dynamic) and CT DICOM series.

    Files are taken from a local directory, or, when ``remote`` is true,
    through the file-browser object given to setFileBrowser (it must provide
    ``listRemoteDir`` and ``readFileToFile``); remote files are copied into
    the directory given to setTempBase before being read.

    For NM data, readMasterDirectory (or readMasterFile) must run before
    readNMDirectory / readNMFile, because the master file supplies the frame
    bookkeeping the NM files are placed into. CT slices are assembled into a
    volume by readCTDirectory.
    """

    def __init__(self):
        # Axis permutation applied to NM pixel arrays; readMasterDirectory
        # sets the same value.
        self.axisShift = (2, 1, 0)
        # True once the first NM / CT file of a series has fixed the
        # reference geometry that later files are compared against
        # (reset by readMasterFile / readCTDirectory).
        self._nm_reference_set = False
        self._ct_reference_set = False

    def setFileBrowser(self, fb):
        """Set the remote file browser (needs listRemoteDir/readFileToFile)."""
        self.fb = fb

    def setTempBase(self, tb):
        """Set the local directory that remote files are copied into."""
        self.tempBase = tb

    def filelist(self, mypath, remote=True):
        """Return the paths of the files in directory *mypath*.

        With ``remote`` the listing comes from ``self.fb.listRemoteDir``;
        if that reports failure an error is printed and [] is returned.
        Otherwise the local directory is listed, skipping sub-directories.
        """
        if remote:
            print("Using labkey")
            ok, files = self.fb.listRemoteDir(mypath)
            if not ok:
                print("Error accessing path")
                return []
            return files

        print("Using local files")
        files = [os.path.join(mypath, f) for f in os.listdir(mypath)]
        return [f for f in files if os.path.isfile(f)]

    def getfile(self, f, remote=True):
        """Open file *f* for binary reading.

        Returns ``[file_object, 1]`` on success; the caller owns the handle
        and must close it. With ``remote``, *f* is first copied to
        ``self.tempBase`` via ``self.fb.readFileToFile``; any failure there
        returns ``['NULL', 0]``. Local open errors propagate.
        """
        if not remote:
            print("Using local directory")
            return [open(f, 'rb'), 1]

        try:
            print("Using labkey")
            localPath = os.path.join(self.tempBase, pathlib.Path(f).name)
            self.fb.readFileToFile(f, localPath)
            return [open(localPath, 'rb'), 1]
        except Exception:
            # Best effort: any failure of the remote copy/open is reported
            # to the caller through the ok flag.
            print('Could not access labkey. Exiting')
            return ['NULL', 0]

    @staticmethod
    def _check_reference(stored, read, label, first, tol=1e-3):
        """Initialise or cross-check reference geometry values in place.

        When *first* is true the values in *read* are copied (as floats)
        into the leading entries of *stored*. Otherwise every component of
        *read* is compared with *stored* and a warning prefixed with *label*
        is printed when they differ by more than *tol*; *stored* is left
        unchanged.
        """
        for i, value in enumerate(read):
            if first:
                stored[i] = float(value)
            elif abs(stored[i] - value) > tol:
                print('{} mismatch {:.2f}/{:.2f}'.format(
                    label, float(stored[i]), float(value)))

    def readMasterFile(self, g):
        """Read the frame bookkeeping from the NM "master" file *g*.

        The master file describes the frames of all other files: their
        distribution over phases (and cycles within a phase) and each
        frame's start, stop and duration. On success this sets
        ``nframe`` (cumulative frame count at the end of each phase),
        ``frame_start``, ``frame_stop``, ``frame_duration``, allocates
        ``frame_time`` and a placeholder ``frame_data``, and resets the
        reference geometry filled by readNMFile.

        Returns True on success, False (after printing why) otherwise.
        """
        try:
            plan = pydicom.dcmread(g)
        except Exception:
            print("{}: Not a dicom file".format(g))
            return False

        try:
            counts = plan[0x0019, 0x10a5].value
        except KeyError:
            print("{}: Not a master file".format(g))
            return False

        # Frame counts per phase must be multi-valued; scalars and strings
        # identify a non-master file.
        if isinstance(counts, (str, bytes)) or not isinstance(counts, Sequence):
            print("nframe not a list")
            return False
        if len(counts) == 0:
            print("nframe is empty")
            return False

        try:
            frame_start = plan[0x0019, 0x10a7].value
            frame_stop = plan[0x0019, 0x10a8].value
            frame_duration = plan[0x0019, 0x10a9].value
        except KeyError:
            print("{}: Not a master file".format(g))
            return False

        # nframe[i] = total number of frames collected up to the end of
        # phase i+1 (running sum of the per-phase counts).
        running = 0
        self.nframe = []
        for count in counts:
            running += count
            self.nframe.append(running)

        self.frame_start = frame_start
        self.frame_stop = frame_stop
        self.frame_duration = frame_duration

        total = self.nframe[-1]
        self.frame_time = np.zeros(total)
        # Placeholder; readNMFile allocates the real array from the first
        # image it reads.
        self.frame_data = np.empty([1, 1, 1, total])
        self.center = [0, 0, 0]
        self.pixel_size = [0, 0, 0]
        self.frame_orientation = [0, 0, 0, 0, 0, 0]
        self._nm_reference_set = False
        return True

    def readNMFile(self, g):
        """Store the image of NM file *g* in its frame slot of ``frame_data``.

        The (phase, cycle) pair of the file is converted into a frame index
        using the master-file bookkeeping; the frame time is set to the
        midpoint of the frame's start and stop. The first file read fixes
        the reference pixel size, centre and orientation; later files are
        checked against them and mismatches are printed.

        Returns True if the frame was stored, False (after printing why)
        if *g* is not a usable NM file.
        """
        try:
            plan = pydicom.dcmread(g)
        except Exception:
            print("{}: Not a dicom file".format(g))
            return False

        # (0018,5020) is only checked for presence, as the NM-file marker.
        if (0x0018, 0x5020) not in plan:
            print("Not a NM file. Exiting")
            return False

        try:
            phase = plan[0x0035, 0x1005].value
            cycle = plan[0x0035, 0x1004].value
        except KeyError:
            print("Missing phase/cycle values")
            return False

        # Phases and cycles are 1-based: frames of phase p occupy indices
        # nframe[p-2] .. nframe[p-1]-1 (starting at 0 for the first phase).
        if not 1 <= phase <= len(self.nframe):
            print("Phase {} out of range".format(phase))
            return False
        off = self.nframe[phase - 2] if phase > 1 else 0
        ifi = off + cycle - 1
        if not off <= ifi < self.nframe[phase - 1]:
            print("Cycle {} out of range for phase {}".format(cycle, phase))
            return False

        # Frame time is the midpoint between frame start and stop (in ms).
        self.frame_time[ifi] = 0.5 * (self.frame_start[ifi] + self.frame_stop[ifi])

        print("({},{}) converted to {} at {} for {}".format(
            phase, cycle, ifi, self.frame_time[ifi], self.frame_duration[ifi]))

        image = np.transpose(plan.pixel_array, self.axisShift)
        first = not self._nm_reference_set
        if first:
            # Size the frame array from the first image: image axes plus
            # one slot per frame.
            sh = list(image.shape)
            sh.append(self.nframe[-1])
            self.frame_data = np.empty(sh)
            print(" Setting frame_data to {}".format(sh))

        pixel_size_read = [plan.PixelSpacing[0], plan.PixelSpacing[1],
                           plan.SliceThickness]
        self._check_reference(self.pixel_size, pixel_size_read,
                              'Pixel size', first)

        detector = plan.DetectorInformationSequence[0]
        center_read = detector.ImagePositionPatient
        print("Stored center at ({0},{1},{2})".format(self.center[0], self.center[1], self.center[2]))
        print("Read   center at ({0},{1},{2})".format(center_read[0], center_read[1], center_read[2]))
        self._check_reference(self.center, center_read, 'Image center', first)

        self._check_reference(self.frame_orientation,
                              detector.ImageOrientationPatient,
                              'Image orientation', first)

        self.frame_data[:, :, :, ifi] = image
        self._nm_reference_set = True
        return True

    def readCTFile(self, g):
        """Append the spiral CT slice in file *g* to the CT accumulators.

        Stores the modality-LUT-converted pixels, instance number and z
        position; the first slice fixes the reference pixel size, x/y
        position and orientation, later slices are checked against them.
        ``ct_center[2]`` tracks the lowest slice z position.

        Returns True if the slice was added, False (after printing why)
        if *g* is not a spiral CT DICOM file.
        """
        try:
            plan = pydicom.dcmread(g)
        except Exception:
            print("{}: Not a dicom file".format(g))
            return False

        if getattr(plan, 'Modality', None) != 'CT':
            print('{}: Not a CT file'.format(g))
            return False

        # A previous filter on SeriesDescription matching "AC" was removed
        # because it does not work with the 2019 data version.
        try:
            iType = plan.ImageType
        except AttributeError:
            print("Image type not found")
            return False

        if not iType or len(iType) < 4 or "SPI" not in iType[3]:
            print("Not a spiral image")
            return False

        first = not self._ct_reference_set
        self.ct_data.append(
            pydicom.pixel_data_handlers.util.apply_modality_lut(
                plan.pixel_array, plan))

        position = plan.ImagePositionPatient
        self.ct_idx.append(plan.InstanceNumber)
        self.ct_z.append(position[2])

        pixel_size_read = [plan.PixelSpacing[0], plan.PixelSpacing[1],
                           plan.SliceThickness]
        self._check_reference(self.ct_pixel_size, pixel_size_read,
                              'Pixel size', first)

        # Only x and y are cross-checked; z differs per slice.
        self._check_reference(self.ct_center, list(position)[:2],
                              'Image center', first)

        # z is the minimum (not the average) slice position because
        # readCTDirectory uses it as the origin of the slice index.
        if position[2] < self.ct_center[2]:
            self.ct_center[2] = position[2]

        self._check_reference(self.ct_orientation,
                              plan.ImageOrientationPatient,
                              'Image orientation', first)

        self._ct_reference_set = True
        return True

    def readMasterDirectory(self, mypath, remote=True):
        """Scan *mypath* for the NM master file and read it.

        Stops at the first file readMasterFile accepts. Returns True when a
        master file was read, False if a file could not be fetched or no
        master file was found.
        """
        self.axisShift = (2, 1, 0)

        print("Reading master from {}".format(mypath))

        for f in self.filelist(mypath, remote):
            print('{}:'.format(f))

            g, ok = self.getfile(f, remote)
            if not ok:
                return False

            with g:
                found = self.readMasterFile(g)
            if found:
                return True

        print("No master file found in {}".format(mypath))
        return False

    def readNMDirectory(self, mypath, remote=True):
        """Read every NM file in *mypath* into the frame array.

        Requires a prior readMasterDirectory/readMasterFile. Files that
        cannot be fetched or are not NM files are skipped.

        Returns ``[frame_data, frame_time, frame_duration, center,
        pixel_size, frame_orientation]``.
        """
        for f in self.filelist(mypath, remote):
            g, ok = self.getfile(f, remote)
            if not ok:
                continue

            with g:
                self.readNMFile(g)

        return [self.frame_data, self.frame_time, self.frame_duration, self.center,
                self.pixel_size, self.frame_orientation]

    def readCTDirectory(self, mypath, remote=True):
        """Assemble the spiral CT slices in *mypath* into a volume.

        Returns ``(data_array, ct_center, ct_pixel_size, ct_orientation)``
        where ``data_array`` has shape (cols, rows, n_slices) and
        ``ct_center[2]`` is the lowest slice z position. Returns None if a
        file cannot be fetched; raises IndexError if no CT slice was read.
        """
        onlyfiles = self.filelist(mypath, remote)

        self.ct_data = []
        self.ct_idx = []
        self.ct_z = []
        self.ct_pixel_size = [0, 0, 0]
        # z starts huge so the first slice's position always replaces it.
        self.ct_center = [0, 0, 1e30]
        self.ct_orientation = [0, 0, 0, 0, 0, 0]
        self._ct_reference_set = False

        for f in onlyfiles:
            print('{}:'.format(f))

            g, ok = self.getfile(f, remote)
            if not ok:
                return

            with g:
                self.readCTFile(g)

        nz = len(self.ct_idx)
        # Slices are stored transposed, so the volume axes are the
        # transposed image axes plus the slice axis.
        sh_list = list(self.ct_data[-1].T.shape)
        sh_list.append(nz)
        data_array = np.zeros(sh_list)

        for z, image in zip(self.ct_z, self.ct_data):
            # Slice index relative to the lowest slice, using SliceThickness
            # as the slice spacing.
            # NOTE(review): if the spacing between slices differs from
            # SliceThickness, indices can collide or exceed nz - confirm.
            kp = int(np.round((z - self.ct_center[2]) / self.ct_pixel_size[2]))
            data_array[:, :, kp] = np.transpose(image)

        return data_array, self.ct_center, self.ct_pixel_size, self.ct_orientation