import vtk, qt, ctk, slicer
import numpy as np
import SimpleITK as sitk

#Set of routines to convert images from one representation to another, most
#notably between numpy, vtk and (Simple)ITK in all combinations in between.
#Orientation, origin and spacing must be carried along between conversions.

class vtkInterface:
    """Module descriptor (Slicer scripted-module style) for these helpers.

    It only fills in metadata attributes on *parent* and keeps a reference to
    it; the conversion routines are the plain functions in this file.
    """

    def __init__(self, parent):
        self.parent = parent
        parent.title = "vtkInterface"
        parent.dependencies = []
        parent.contributors = ["Andrej Studen (FMF/JSI)"]
        parent.helpText = """
    Convert native numpy data structures to vtk
    """
        parent.acknowledgementText = """
    This module was developed within the frame of the ARRS sponsored medical
    physics research programe to investigate quantitative measurements of cardiac
    function using sestamibi-like tracers
    """


def numpyToVTK(numpy_array, shape, data_type=vtk.VTK_FLOAT):
    """Copy a numpy array into a new vtkImageData.

    The array is flattened in Fortran (column-major) order, so for a
    multi-dimensional array the first axis varies fastest (assumed to be x,
    matching VTK's x-fastest voxel layout). A 1-D array is taken as already
    flattened in that layout.

    Args:
        numpy_array: voxel values; deep-copied into the VTK object.
        shape: image dimensions passed to ``vtkImageData.SetDimensions``.
        data_type: VTK scalar type used to store the values (default
            ``vtk.VTK_FLOAT``).

    Returns:
        vtkImageData with origin (0, 0, 0) and spacing (1, 1, 1); callers
        must set any real geometry themselves.
    """
    # ``import vtk`` is not guaranteed to load the numpy_support submodule,
    # so import it explicitly instead of relying on someone else having done so.
    from vtk.util import numpy_support

    v = vtk.vtkImageData()
    v.GetPointData().SetScalars(
        numpy_support.numpy_to_vtk(
            np.ravel(numpy_array, order='F'), deep=True, array_type=data_type))
    v.SetOrigin(0, 0, 0)
    v.SetSpacing(1, 1, 1)
    v.SetDimensions(shape)
    return v


def completeOrientation(orientation):
    """Extend a 6-element orientation (two direction vectors) to 9 elements.

    ``orientation[0:3]`` and ``orientation[3:6]`` are two vectors u and v; the
    cross product u x v is appended as elements 6..8.

    A list argument is extended in place and returned (as before). Any other
    sequence, e.g. a tuple, is copied into a new list, which is extended and
    returned. Previously, non-list sequences raised AttributeError.
    """
    o = orientation if isinstance(orientation, list) else list(orientation)
    o.append(o[1] * o[5] - o[2] * o[4])  # 0,3
    o.append(o[2] * o[3] - o[0] * o[5])  # 1,4
    o.append(o[0] * o[4] - o[1] * o[3])  # 2,5
    return o


def ITK2VTK(img):
    """Convert a SimpleITK image into a vtkImageData (voxel values only).

    ``sitk.GetArrayFromImage`` returns an array indexed (z, y, x) whose
    C-order flattening has x varying fastest, which is the layout VTK expects.
    VTK dimensions are given as (x, y, z), hence the reversed shape.
    Geometry is not transferred; the result gets numpyToVTK's unit spacing
    and zero origin, and its scalars use numpyToVTK's default type
    (VTK_FLOAT).
    """
    voxels = sitk.GetArrayFromImage(img)
    vtkDims = list(voxels.shape[::-1])
    flat = voxels.ravel()
    return numpyToVTK(flat, vtkDims)

def VTK2ITK(v):
    """Convert a vtkImageData into a SimpleITK image (voxel values only).

    The point scalars are copied through numpy. VTK dimensions are (x, y, z)
    with x varying fastest. ``sitk.GetImageFromArray`` expects an array indexed
    (z, y, x), hence the reversed shape.

    Multi-component scalars (e.g. RGB) become a vector image with one
    component per scalar component. Single-component data behaves as before.
    Origin, spacing and direction are not transferred.
    """
    # ``import vtk`` is not guaranteed to load the numpy_support submodule.
    from vtk.util import numpy_support

    scalars = v.GetPointData().GetScalars()
    data = numpy_support.vtk_to_numpy(scalars)
    arrayShape = list(reversed(v.GetDimensions()))
    nComponents = scalars.GetNumberOfComponents()
    isVector = nComponents > 1
    if isVector:
        # vtk_to_numpy yields (nTuples, nComponents); keep components last.
        arrayShape.append(nComponents)
    return sitk.GetImageFromArray(np.reshape(data, arrayShape),
                                  isVector=isVector)

def ITKfromNode(nodeName):
    """Build a SimpleITK image from the scene node named *nodeName*.

    Voxel data come from ``node.GetImageData()``. Origin and spacing are
    copied from the node, and the direction from its IJK->RAS direction
    matrix.

    Returns:
        The SimpleITK image, or None (after printing a message) if no node
        with that name exists in the scene.

    NOTE(review): origin and direction are used in Slicer's RAS frame without
    an RAS<->LPS flip (ITK conventionally uses LPS). ITKtoNode in this file
    does the same, so round trips are consistent. Confirm this is intended.
    """
    node = slicer.mrmlScene.GetFirstNodeByName(nodeName)
    if node is None:
        print("Node {0} not available".format(nodeName))
        return None

    img = VTK2ITK(node.GetImageData())

    img.SetOrigin(node.GetOrigin())
    img.SetSpacing(node.GetSpacing())

    # Slicer's direction matrix holds the unit vector of IJK axis j in column
    # j. SimpleITK expects that same matrix flattened row-major
    # (direction[3*row + col]), not its transpose.
    m = vtk.vtkMatrix4x4()
    node.GetIJKToRASDirectionMatrix(m)
    orientation = [m.GetElement(row, col)
                   for row in range(3) for col in range(3)]
    img.SetDirection(orientation)
    return img



def ITKtoNode(img, nodeName):
    """Store a SimpleITK image in the scene volume node named *nodeName*.

    Useful for displaying outcomes of itk calculations. If no node with that
    name exists, a new ``vtkMRMLScalarVolumeNode`` is created, named, and added
    to the scene. Otherwise the existing node's image data is replaced.

    The image's spacing, direction and origin are written to the node through
    its IJK->RAS matrix.

    NOTE(review): values are used as-is, without an LPS->RAS flip. ITKfromNode
    in this file makes the same choice, so round trips are consistent. Confirm
    this is intended for images that come from ITK readers.
    """
    node = slicer.mrmlScene.GetFirstNodeByName(nodeName)
    if node is None:
        node = slicer.vtkMRMLScalarVolumeNode()
        node.SetName(nodeName)
        slicer.mrmlScene.AddNode(node)

    node.SetAndObserveImageData(ITK2VTK(img))

    # The vtkImageData built above carries no geometry, so take spacing,
    # direction and origin from the itk image and pass them to the node.
    spacing = img.GetSpacing()
    direction = img.GetDirection()  # 3x3 matrix flattened row-major
    origin = img.GetOrigin()

    # vtkMatrix4x4 starts as identity, so the bottom row stays (0, 0, 0, 1).
    ijkToRAS = vtk.vtkMatrix4x4()
    for row in range(3):
        for col in range(3):
            # Column `col` is the physical direction of index axis `col`,
            # scaled by that axis' spacing.
            ijkToRAS.SetElement(row, col, direction[3 * row + col] * spacing[col])
        ijkToRAS.SetElement(row, 3, origin[row])

    node.SetIJKToRASMatrix(ijkToRAS)