import types

import numpy
import scipy.interpolate

def getGeometry(img):
   """Collect the spatial geometry (origin, spacing, direction) of an image.

   Parameters
   ----------
   img : object
       Anything exposing GetOrigin(), GetSpacing() and GetDirection()
       (presumably a SimpleITK.Image -- confirm against callers).

   Returns
   -------
   types.SimpleNamespace
       ``origin`` and ``spacing`` as 1-D float arrays of length ``d``, and
       ``direction`` as a ``(d, d)`` float array, where ``d`` is the
       length of the origin (the image dimension).  The flat tuple returned
       by GetDirection() is reshaped in row-major (C) order.
   """
   origin=numpy.array(img.GetOrigin(),dtype=float)
   spacing=numpy.array(img.GetSpacing(),dtype=float)
   dim=origin.size
   # GetDirection() yields a flat tuple of dim*dim entries; size the matrix
   # from the image dimension instead of assuming a 3-D image
   direction=numpy.array(img.GetDirection(),dtype=float).reshape(dim,dim)
   # plain attribute bag -- no need to create a new class object per call
   return types.SimpleNamespace(origin=origin,spacing=spacing,direction=direction)

def pixelToVector(geometry,pixel):
   """Map a numpy voxel index to a point in physical space.

   ``pixel`` is in numpy axis order, which is the reverse of the SimpleITK
   image's axis order.  It is flipped, scaled by ``geometry.spacing``,
   multiplied by ``geometry.direction`` and finally offset by
   ``geometry.origin``.
   """
   # numpy indexes the array with its axes reversed relative to SimpleITK
   sitkIndex=numpy.flip(pixel)
   scaled=sitkIndex*geometry.spacing
   rotated=numpy.dot(geometry.direction,scaled)
   return rotated+geometry.origin

def vectorToPixel(geometry,vector):
   """Map a physical-space point to a continuous numpy voxel index.

   This inverts pixelToVector.  It subtracts ``geometry.origin``, undoes
   ``geometry.direction``, divides by ``geometry.spacing`` and flips the
   result into numpy axis order.  The result is fractional; no rounding
   is applied.

   The direction matrix is inverted with numpy.linalg.solve instead of
   being transposed.  The transpose equals the inverse only for an
   orthonormal matrix, whereas solve is correct for any invertible one
   and gives the same answer in the orthonormal case.

   Raises
   ------
   numpy.linalg.LinAlgError
       If ``geometry.direction`` is singular.
   """
   local=numpy.linalg.solve(geometry.direction,numpy.asarray(vector)-geometry.origin)
   #accounts for reverse order of coordinates in numpy array relative to SimpleITK image
   return numpy.flip(local/geometry.spacing)

def toSpace2(spect,gSPECT,ct,gCT,method='linear'):
   """Resample ``spect`` onto the voxel grid of ``ct``.

   Each CT voxel index is mapped to physical space using geometry ``gCT``.
   That point is mapped to a continuous index into ``spect`` using
   geometry ``gSPECT``, and ``spect`` is interpolated there.  The mapping
   is vectorised over all voxels instead of looping in Python.

   Parameters
   ----------
   spect : array_like
       Source volume, in numpy axis order.
   gSPECT, gCT : geometry objects
       Objects with ``origin``, ``spacing`` and ``direction`` array
       attributes, as produced by getGeometry.
   ct : array_like
       Only its shape is used; it defines the output grid.
   method : str
       Interpolation method forwarded to ``interpolate``.

   Returns
   -------
   numpy.ndarray
       Float array with the same shape as ``ct``.  Voxels that fall
       outside ``spect`` receive interpolate's fill value.
   """
   shape=numpy.shape(ct)
   ndim=len(shape)
   # every CT voxel index as an (N, ndim) array, in C order (the order of
   # nested i/j/k loops), so the flat result reshapes straight back
   idx=numpy.indices(shape).reshape(ndim,-1).T
   # CT index -> physical point; [:, ::-1] reverses numpy axis order into
   # SimpleITK axis order before applying spacing/direction/origin
   phys=numpy.dot(idx[:,::-1]*gCT.spacing,numpy.transpose(gCT.direction))+gCT.origin
   # physical point -> continuous SPECT index; solve() inverts the direction
   # matrix exactly (equal to its transpose only when it is orthonormal)
   local=numpy.linalg.solve(gSPECT.direction,numpy.transpose(phys-gSPECT.origin)).T
   pixels=(local/gSPECT.spacing)[:,::-1]
   print('Interpolating {} pixels'.format(len(pixels)))
   outs=interpolate(spect,pixels,method=method)
   print('Done')
   return numpy.asarray(outs,dtype=float).reshape(shape)

def interpolate(ar,c,method='linear',fill_value=-1):
   """Sample the regular-grid array ``ar`` at continuous indices ``c``.

   Grid node k along every axis is placed at coordinate k, so an integer
   index returns the stored value exactly.  This matches the continuous
   indices produced by vectorToPixel, where integer k is the centre of
   voxel k.

   Parameters
   ----------
   ar : array_like
       Values on a regular grid; any number of dimensions.
   c : array_like, shape (..., ar.ndim)
       Points to sample, in the same axis order as ``ar``.
   method : str
       Interpolation method passed to scipy.interpolate.interpn
       (e.g. 'linear', 'nearest').
   fill_value : float or None
       Value returned for points outside the grid (default -1).  None
       makes interpn extrapolate instead.

   Returns
   -------
   numpy.ndarray
       Interpolated values, one per point in ``c``.
   """
   ar=numpy.asarray(ar)
   # nodes at 0..n-1, not 0.5..n-0.5: the half-voxel offset shifted every
   # sample and flagged the first half voxel as out of bounds
   points=tuple(numpy.arange(n,dtype=float) for n in ar.shape)
   return scipy.interpolate.interpn(points,ar,c,method=method,fill_value=fill_value,bounds_error=False)