import os
import sys
import unittest
import vtk, qt, ctk, slicer
from slicer.ScriptedLoadableModule import *
import logging
import slicer
import numpy as np
import json
import re
#
# cardiacSPECT
#

class cardiacSPECT(ScriptedLoadableModule):
  """Module descriptor: registers "Cardiac SPECT" in the dynamicSPECT category.

  Uses ScriptedLoadableModule base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """
  def __init__(self, parent):
    ScriptedLoadableModule.__init__(self, parent)

    helpText = """
    Load dynamic cardiac SPECT data to Slicer
    """
    acknowledgementText = """
    This module was developed within the frame of the ARRS sponsored medical
    physics research programe to investigate quantitative measurements of cardiac
    function using sestamibi-like tracers
    """

    # metadata shown in Slicer's module selector and help panel
    parent.title = "Cardiac SPECT"
    parent.categories = ["dynamicSPECT"]
    parent.dependencies = []
    parent.contributors = ["Andrej Studen (FMF/JSI)"]
    # the help text is followed by the standard module documentation link
    parent.helpText = helpText + self.getDefaultModuleDocumentationLink()
    parent.acknowledgementText = acknowledgementText
    self.parent = parent

#
# cardiacSPECTWidget
#

class cardiacSPECTWidget(ScriptedLoadableModuleWidget):
  """Qt front end of the cardiac SPECT module.

  setup() reads ~/.cardiacSPECT/cardiacSPECT.json, lists the patients of the
  configured LabKey query and builds two collapsible sections:

  * "Data": patient / reference patient selection and buttons that load,
    save and transform patient data through cardiacSPECTLogic;
  * "Parameters": a time-frame slider plus ROI statistics and plotting.

  Uses ScriptedLoadableModuleWidget base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """

  def setup(self):
    """Load the configuration, query the patient list and build the GUI.

    Errors from opening/parsing the configuration file or from the LabKey
    query propagate to the caller.
    """
    ScriptedLoadableModuleWidget.setup(self)

    # module configuration: LabKey project/schema/query names, remote flag...
    configFile = os.path.join(os.path.expanduser('~'),
                              '.cardiacSPECT', 'cardiacSPECT.json')
    with open(configFile) as f:
      self.cfg = json.load(f)

    # The logic opens the LabKey connection itself; reuse it instead of
    # opening a second identical one. net/db remain available on the widget.
    self.logic = cardiacSPECTLogic(self.cfg)
    self.net = self.logic.net
    self.db = self.logic.db

    ds = self.db.selectRows(self.cfg['project'], self.cfg['schemaName'],
                            self.cfg['queryName'], [])
    # unique patient ids, sorted so the combo boxes have a stable order
    # (plain set iteration order is arbitrary)
    patients = sorted(set(row['aliasID'] for row in ds['rows']))

    self._setupDataSection(patients)
    self._setupParametersSection()

    # a single trailing spacer keeps both sections packed at the top
    self.layout.addStretch(1)

    # onTimeFrameSelect fits the slice views to the CT only while this is 1
    self.resetPosition = 1

  @staticmethod
  def _addButton(formLayout, label, text, toolTip, slot):
    """Add a labelled push button row to formLayout and connect its click."""
    button = qt.QPushButton(text)
    button.toolTip = toolTip
    formLayout.addRow(label, button)
    button.clicked.connect(slot)
    return button

  def _setupDataSection(self, patients):
    """Build the "Data" section: patient selectors and load/save buttons."""
    dataButton = ctk.ctkCollapsibleButton()
    dataButton.text = "Data"
    self.layout.addWidget(dataButton)
    dataFormLayout = qt.QFormLayout(dataButton)

    self.patientId = qt.QComboBox()
    self.patientId.insertItems(0, patients)
    dataFormLayout.addRow('Patient ID', self.patientId)

    self.refPatientId = qt.QComboBox()
    self.refPatientId.insertItems(0, patients)
    dataFormLayout.addRow('Reference Patient ID', self.refPatientId)

    self._addButton(dataFormLayout, "Patient", "Load",
                    "Load data from DICOM", self.onPatientLoadButtonClicked)
    self._addButton(dataFormLayout, "Patient", "Load NRRD",
                    "Load data from NRRD", self.onPatientLoadNRRDButtonClicked)
    self._addButton(dataFormLayout, "Segmentation", "Load",
                    "Load segmentation from server",
                    self.onLoadSegmentationButtonClicked)

    self.modelParameter = qt.QLineEdit('k1')
    dataFormLayout.addRow('Model Parameter', self.modelParameter)
    self._addButton(dataFormLayout, "Model", "Load",
                    "Load model parameters from server",
                    self.onLoadModelButtonClicked)

    self._addButton(dataFormLayout, "Volume", "Save",
                    "Save volume to NRRD", self.onSaveVolumeButtonClicked)
    self._addButton(dataFormLayout, "Segmentation", "Save",
                    "Save segmentation to NRRD",
                    self.onSaveSegmentationButtonClicked)
    self._addButton(dataFormLayout, "Transformation", "Save",
                    "Save transformation to NRRD",
                    self.onSaveTransformationButtonClicked)
    self._addButton(dataFormLayout, "InputFunction", "Save",
                    "Save InputFunction to NRRD",
                    self.onSaveInputFunctionButtonClicked)
    self._addButton(dataFormLayout, "Transform Nodes", "Transform Nodes",
                    "Transform node with patient based transform",
                    self.onTransformNodeButtonClicked)

  def _setupParametersSection(self):
    """Build the "Parameters" section: frame slider, ROI tools, Apply."""
    parametersCollapsibleButton = ctk.ctkCollapsibleButton()
    parametersCollapsibleButton.text = "Parameters"
    self.layout.addWidget(parametersCollapsibleButton)
    parametersFormLayout = qt.QFormLayout(parametersCollapsibleButton)

    # row 1: time frame slider; its maximum is set once data is loaded
    hbox1 = qt.QHBoxLayout()
    frameLabel = qt.QLabel()
    frameLabel.setText("Select frame")
    hbox1.addWidget(frameLabel)

    self.time_frame_select = qt.QSlider(qt.Qt.Horizontal)
    self.time_frame_select.setMinimum(0)
    self.time_frame_select.setMaximum(0)
    self.time_frame_select.setValue(0)
    self.time_frame_select.setTickPosition(qt.QSlider.TicksBelow)
    self.time_frame_select.setTickInterval(5)
    self.time_frame_select.toolTip = "Select the time frame"
    self.time_frame_select.valueChanged.connect(self.onTimeFrameSelect)
    hbox1.addWidget(self.time_frame_select)
    parametersFormLayout.addRow(hbox1)

    # row 2: mean value of a named volume inside a segment
    hbox2 = qt.QHBoxLayout()
    meanROILabel = qt.QLabel()
    meanROILabel.setText("MeanROI")
    hbox2.addWidget(meanROILabel)

    self.meanROIVolume = qt.QLineEdit()
    self.meanROIVolume.setText("testVolume15")
    hbox2.addWidget(self.meanROIVolume)

    self.meanROISegment = qt.QLineEdit()
    self.meanROISegment.setText("Segment_1")
    hbox2.addWidget(self.meanROISegment)

    computeMeanROI = qt.QPushButton("Compute mean ROI")
    computeMeanROI.clicked.connect(self.onComputeMeanROIClicked)
    hbox2.addWidget(computeMeanROI)

    self.meanROIResult = qt.QLineEdit()
    self.meanROIResult.setText("0")
    hbox2.addWidget(self.meanROIResult)
    parametersFormLayout.addRow(hbox2)

    # row 3: per-segment ROI time plot
    hbox3 = qt.QHBoxLayout()
    drawTimePlot = qt.QPushButton("Draw ROI time plot")
    drawTimePlot.clicked.connect(self.onDrawTimePlotClicked)
    hbox3.addWidget(drawTimePlot)
    parametersFormLayout.addRow(hbox3)

    # row 4: number of segments reported by the logic
    hbox4 = qt.QHBoxLayout()
    countSegments = qt.QPushButton("Count segmentation segments")
    countSegments.clicked.connect(self.onCountSegmentsClicked)
    hbox4.addWidget(countSegments)
    self.countSegmentsDisplay = qt.QLineEdit()
    self.countSegmentsDisplay.setText("0")
    hbox4.addWidget(self.countSegmentsDisplay)
    parametersFormLayout.addRow(hbox4)

    # Apply button: placeholder, created disabled
    self.applyButton = qt.QPushButton("Apply")
    self.applyButton.toolTip = "Run the algorithm."
    self.applyButton.enabled = False
    parametersFormLayout.addRow(self.applyButton)
    self.applyButton.clicked.connect(self.onApplyButton)

  def cleanup(self):
    pass

  def onApplyButton(self):
    """Placeholder slot for the (disabled) Apply button."""
    pass

  # NOTE(review): the four handlers below use self.dataPath / self.selectRemote,
  # which setup() never creates, and setup() does not connect them.
  # They are kept only for backward compatibility.
  def onBrowseButtonClicked(self):
    """Pick a local DICOM directory and store it as a file:// URL."""
    startDir = self.dataPath.text
    inputDir = qt.QFileDialog.getExistingDirectory(None,
        'Select DICOM directory', startDir)
    self.dataPath.setText("file://" + inputDir)

  def onRemoteBrowseButtonClicked(self):
    self.selectRemote.show()

  def onDataLoadButtonClicked(self):
    self.logic.loadData(self)

  def onRemotePathTextChanged(self, str):
    self.dataPath.setText('labkey://' + str)

  def onTimeFrameSelect(self):
    """Show the CT as background and the selected SPECT frame as foreground.

    The views are fitted to the CT only the first time (resetPosition == 1).
    Nothing happens if the CT or the frame volume is not in the scene yet.
    """
    it = self.time_frame_select.value
    patientId = self.patientId.currentText
    appLogic = slicer.app.applicationLogic()
    selectionNode = appLogic.GetSelectionNode()

    ctNode = slicer.mrmlScene.GetFirstNodeByName(patientId + 'CT')
    frameNode = slicer.mrmlScene.GetFirstNodeByName(
        patientId + 'Volume' + str(it))
    if ctNode is None or frameNode is None:
      print("Volumes for {} frame {} not loaded".format(patientId, it))
      return

    print("Propagating CT volume")
    selectionNode.SetReferenceActiveVolumeID(ctNode.GetID())
    if self.resetPosition == 1:
      self.resetPosition = 0
      appLogic.PropagateVolumeSelection(1)
    else:
      appLogic.PropagateVolumeSelection(0)

    print("Propagating SPECT volume")
    selectionNode.SetSecondaryVolumeID(frameNode.GetID())
    appLogic.PropagateForegroundVolumeSelection(0)
    frameNode.GetDisplayNode().SetAndObserveColorNodeID(
        'vtkMRMLColorTableNodeRed')

    # blend SPECT over CT in all three slice views
    lm = slicer.app.layoutManager()
    for s in ['Red', 'Yellow', 'Green']:
      compositeNode = lm.sliceWidget(s).sliceLogic().GetSliceCompositeNode()
      compositeNode.SetForegroundOpacity(0.5)

    print("Done")

  def onComputeMeanROIClicked(self):
    """Display the mean of the named volume inside the named segment."""
    s = self.logic.meanROI(self.meanROIVolume.text, self.meanROISegment.text)
    self.meanROIResult.setText(str(s))

  def onDrawTimePlotClicked(self):
    """Chart the per-frame ROI mean / frame duration for every segment.

    One double-array node '<patientId>_<segment name>' is created per
    segment, with frame time in column 0 and the normalised mean in
    column 1. All arrays go into a new chart node shown in the first
    chart view.
    """
    n = self.time_frame_select.maximum + 1
    ft = self.logic.frame_time
    # use the real frame lengths (the same ones calculateInputFunction uses)
    # instead of reconstructing them from the frame times
    dt = self.logic.frame_duration
    patientId = self.patientId.currentText

    ns = self.logic.countSegments()
    cn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartNode())

    for j in range(ns):
      segmentName = self.logic.getSegmentName(j)
      dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
      dn.SetSize(n)
      dn.SetName(patientId + '_' + segmentName)

      for i in range(n):
        vol = patientId + "Volume" + str(i)
        fx = ft[i]
        fy = self.logic.meanROI(vol, j)
        dn.SetValue(i, 0, fx)
        dn.SetValue(i, 1, fy / dt[i])
        dn.SetValue(i, 2, 0)
        print("[{0} at {1:.2f}:{2:.2f}]".format(vol, fx, fy))

      cn.AddArray(segmentName, dn.GetID())

    cn.SetProperty('default', 'title', 'ROI time plot')
    cn.SetProperty('default', 'xAxisLabel', 'time [ms]')
    cn.SetProperty('default', 'yAxisLabel', 'Activity (arb)')

    # show the chart in the first chart view (created if none exists)
    cvns = slicer.mrmlScene.GetNodesByClass('vtkMRMLChartViewNode')
    if cvns.GetNumberOfItems() == 0:
      cvn = slicer.mrmlScene.AddNode(slicer.vtkMRMLChartViewNode())
    else:
      cvn = cvns.GetItemAsObject(0)
    cvn.SetChartNodeID(cn.GetID())

  def onCountSegmentsClicked(self):
    """Display the number of segments (the method must be called, as text)."""
    self.countSegmentsDisplay.setText(str(self.logic.countSegments()))

  def onPatientLoadButtonClicked(self):
    self.logic.loadPatient(self.patientId.currentText)
    self.time_frame_select.setMaximum(self.logic.frame_data.shape[3] - 1)

  def onPatientLoadNRRDButtonClicked(self):
    self.logic.loadPatientNRRD(self.patientId.currentText)
    # the slider holds frame indices, so the last valid value is len-1
    self.time_frame_select.setMaximum(max(len(self.logic.frame_time) - 1, 0))

  def onLoadSegmentationButtonClicked(self):
    self.logic.loadSegmentation(self.patientId.currentText)

  def onLoadModelButtonClicked(self):
    self.logic.loadModelVolume(self.patientId.currentText,
                               self.modelParameter.text)

  def onSaveVolumeButtonClicked(self):
    self.logic.storeVolumeNodes(self.patientId.currentText,
        self.time_frame_select.minimum, self.time_frame_select.maximum)

  def onSaveSegmentationButtonClicked(self):
    self.logic.storeSegmentation(self.patientId.currentText)

  def onSaveTransformationButtonClicked(self):
    self.logic.storeTransformation(self.patientId.currentText)

  def onSaveInputFunctionButtonClicked(self):
    self.logic.storeInputFunction(self.patientId.currentText)

  def onTransformNodeButtonClicked(self):
    self.logic.applyTransform(self.patientId.currentText,
        self.refPatientId.currentText,
        self.time_frame_select.minimum, self.time_frame_select.maximum)


#def onAddFrameButtonClicked(self):
#      it=int(self.time_frame_select.text)
#      self.logic.addFrame(it)

 # def onAddCTButtonClicked(self):
#      self.logic.addCT()
#
#
# cardiacSPECTLogic
#

class cardiacSPECTLogic(ScriptedLoadableModuleLogic):
  """Data handling and computation for the cardiac SPECT module.

  Loads patient SPECT frames and CT either from DICOM (via parseDicom) or
  from NRRD files kept on a LabKey server. Stores volumes, segmentations,
  transforms and input functions back. Applies per-patient transforms and
  computes ROI means from the first segmentation node in the scene.

  Volumes are named '<patientId>Volume<i>' and '<patientId>CT'; frame timing
  is kept in self.frame_time / self.frame_duration, indexed by frame.

  Uses ScriptedLoadableModuleLogic base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """
  def __init__(self, config):
    """Connect to LabKey and load the helper packages.

    :param config: parsed cardiacSPECT.json (project, schemaName, queryName,
      transfer, remote, ... entries)

    Helper package paths (labkeyInterface, parseDicom, resample,
    vtkInterface) are read from ~/.labkey/setup.json and appended to
    sys.path. The LabKey connection is configured from
    ~/.labkey/network.json. Downloads are staged in ~/temp/cardiacSPECT,
    which is created if missing.
    """
    ScriptedLoadableModuleLogic.__init__(self)
    sconfig = os.path.join(os.path.expanduser('~'), '.labkey', 'setup.json')
    with open(sconfig) as f:
      setup = json.load(f)
    sys.path.append(setup['paths']['labkeyInterface'])
    import labkeyInterface
    import labkeyDatabaseBrowser
    import labkeyFileBrowser

    self.net = labkeyInterface.labkeyInterface()
    fconfig = os.path.join(os.path.expanduser('~'), '.labkey', 'network.json')
    self.net.init(fconfig)
    self.db = labkeyDatabaseBrowser.labkeyDB(self.net)
    self.fb = labkeyFileBrowser.labkeyFileBrowser(self.net)

    sys.path.append(setup['paths']['parseDicom'])
    import parseDicom
    self.pd = parseDicom.parseDicom
    self.pd.setFileBrowser(self.fb)

    self.tempPath = os.path.join(os.path.expanduser('~'),
                                 'temp', 'cardiacSPECT')
    if not os.path.isdir(self.tempPath):
      os.makedirs(self.tempPath)
    self.pd.setTempBase(self.tempPath)

    sys.path.append(setup['paths']['resample'])
    import resample
    self.resampler = resample.resampleLogic(None)

    # vtkInterface is imported lazily by the methods that need it
    sys.path.append(setup['paths']['vtkInterface'])

    self.cfg = config

  def _readDicomData(self, masterPath, nmPath, ctPath):
    """Read NM frames and CT via parseDicom; complete both orientations."""
    import vtkInterface
    self.pd.readMasterDirectory(masterPath)
    (self.frame_data, self.frame_time, self.frame_duration, self.frame_origin,
     self.frame_pixel_size, self.frame_orientation) = \
        self.pd.readNMDirectory(nmPath)
    (self.ct_data, self.ct_origin, self.ct_pixel_size,
     self.ct_orientation) = self.pd.readCTDirectory(ctPath)
    self.ct_orientation = vtkInterface.completeOrientation(self.ct_orientation)
    self.frame_orientation = vtkInterface.completeOrientation(
        self.frame_orientation)

  def loadData(self, widget):
    """Load NM frames and CT from the directory typed into the widget.

    NOTE(review): expects widget.dataPath and widget.time_frame_select; the
    module widget's setup() does not create dataPath.
    """
    inputDir = str(widget.dataPath.text)
    self._readDicomData(inputDir, inputDir, inputDir)

    self.addCT('test')
    self.addFrames('test')

    widget.time_frame_select.setMaximum(self.frame_data.shape[3] - 1)

    # additional message via qt
    qt.QMessageBox.information(
        slicer.util.mainWindow(),
        'Slicer Python', 'Data loaded')

  def getFilespecPath(self, r):
    """Return the location of the series described by DICOM row r.

    When cfg["remote"] is True this is a LabKey URL under
    cfg['dicomDir']/<Study>/<Series>; otherwise it is a local path under
    cfg["dicomPath"].
    """
    if self.cfg["remote"] == True:
      return self.fb.formatPathURL(self.cfg['transfer']['project'],
          '/'.join([self.cfg['dicomDir'], r['Study'], r['Series']]))
    return os.path.join(self.cfg["dicomPath"], r["Study"], r["Series"])

  def loadPatient(self, patientId):
    """Load a patient's NM frames and CT from DICOM and add them to the scene.

    The visit row (matched on aliasID) supplies the sequence numbers of the
    master, corrected NM and CT series. Those are matched against the
    transfer query's SequenceNum.

    :raises RuntimeError: if the visit or one of the three series is missing
    """
    print("Loading {}".format(patientId))
    idFilter = {'variable': 'aliasID', 'value': patientId, 'oper': 'eq'}
    ds = self.db.selectRows(self.cfg['project'], self.cfg['schemaName'],
                            self.cfg['queryName'], [idFilter])
    if not ds['rows']:
      raise RuntimeError("No visit found for patient {}".format(patientId))
    visit = ds['rows'][0]
    print(visit)

    idFilter = {'variable': 'PatientId', 'value': visit['aliasID'],
                'oper': 'eq'}
    dicoms = self.db.selectRows(self.cfg['transfer']['project'],
                                self.cfg['transfer']['schemaName'],
                                self.cfg['transfer']['queryName'],
                                [idFilter])

    masterPath = nmPath = ctPath = None
    for r in dicoms['rows']:
      # sequence numbers are floats; match with a tolerance
      if abs(r['SequenceNum'] - float(visit['nmMaster'])) < 0.1:
        masterPath = self.getFilespecPath(r)
      if abs(r['SequenceNum'] - float(visit['nmCorrData'])) < 0.1:
        nmPath = self.getFilespecPath(r)
      if abs(r['SequenceNum'] - float(visit['ctData'])) < 0.1:
        ctPath = self.getFilespecPath(r)

    if masterPath is None or nmPath is None or ctPath is None:
      raise RuntimeError(
          "Missing DICOM series for {}: master={} nm={} ct={}".format(
              patientId, masterPath, nmPath, ctPath))

    self._readDicomData(masterPath, nmPath, ctPath)
    self.addCT(patientId)
    self.addFrames(patientId)

  def loadPatientNRRD(self, patientId):
    """Load frames, CT and segmentation of a patient from stored NRRD files.

    Frame times/durations come from the '<patientId>_Dummy' double array
    (column 0 = time, column 1 = duration), which also fixes the number of
    frames. Returns early if that array cannot be loaded.
    """
    print("Loading NRRD {}".format(patientId))

    self.loadDummyInputFunction(patientId)
    dnsNode = slicer.util.getFirstNodeByName(patientId + '_Dummy')
    if dnsNode is None:
      print("Could not find dummy double array node")
      return

    n = dnsNode.GetSize()
    self.frame_time = np.zeros(n)
    self.frame_duration = np.zeros(n)

    a = vtk.reference(1)
    for i in range(n):
      self.loadVolume(patientId, i)
      self.frame_time[i] = dnsNode.GetValue(i, 0, a)
      self.frame_duration[i] = dnsNode.GetValue(i, 1, a)

    self.loadCTVolume(patientId)
    self.loadSegmentation(patientId)

  def loadDummyInputFunction(self, patientId):
    self.loadNode(patientId, patientId + '_Dummy', 'DoubleArrayFile', '.mcsv')

  def loadVolume(self, patientId, i):
    self.loadNode(patientId, patientId + 'Volume' + str(i), 'VolumeFile')

  def loadCTVolume(self, patientId):
    self.loadNode(patientId, patientId + 'CT', 'VolumeFile')

  def loadModelVolume(self, patientId, name):
    """Load model parameter map 'name' and rename it '<patientId>_<name>'."""
    node = self.loadNode(patientId, name, 'VolumeFile')
    if node:
      node.SetName(patientId + '_' + name)

  def loadSegmentation(self, patientId):
    self.loadNode(patientId, 'Heart', 'SegmentationFile')

  def loadNode(self, patientId, fName, type, suffix='.nrrd'):
    """Download <project>/<patientId>/<fName><suffix> and load it into Slicer.

    :param type: Slicer file type passed to slicer.util.loadNodeFromFile
    :returns: the loaded node, or None if loadNodeFromFile raised RuntimeError
    The temporary local copy is removed in either case.
    """
    remotePath = self.fb.formatPathURL(self.cfg['project'],
                                       '/'.join([patientId]))
    labkeyFile = remotePath + '/' + fName + suffix
    localFile = os.path.join(self.tempPath, fName + suffix)
    self.fb.readFileToFile(labkeyFile, localFile)
    print("Remote: {}".format(labkeyFile))
    try:
      node = slicer.util.loadNodeFromFile(localFile, type)
    except RuntimeError:
      node = None
    finally:
      # previously the download was left behind when loading failed
      if os.path.isfile(localFile):
        os.remove(localFile)
    return node

  def addNode(self, nodeName, v, lpsOrigin, pixel_size,
              lpsOrientation, dataType):
    """Wrap vtkImageData v in a new scalar volume node and add it to the scene.

    v is reset to origin 0 / spacing 1; geometry is carried by the node's
    IJK-to-RAS matrix instead. That matrix is built from the LPS origin,
    voxel size and orientation; LPS->RAS flips the sign of the first two
    components of every 3-vector.

    dataType (0 = CT, 1 = SPECT per the callers) is accepted for
    bookkeeping. NOTE(review): it is not used inside this method.
    NOTE(review): element (i, j) is scaled by pixel_size[i]; the conventional
    form scales column j by spacing[j]. The two agree for axis-aligned
    orientations; confirm for oblique data.
    """
    newNode = slicer.vtkMRMLScalarVolumeNode()
    newNode.SetName(nodeName)

    v.SetOrigin([0, 0, 0])
    v.SetSpacing([1, 1, 1])
    ijkToRAS = vtk.vtkMatrix4x4()
    rasOrientation = [-lpsOrientation[i] if (i % 3 < 2) else lpsOrientation[i]
                      for i in range(len(lpsOrientation))]
    rasOrigin = [-lpsOrigin[i] if (i % 3 < 2) else lpsOrigin[i]
                 for i in range(len(lpsOrigin))]

    for i in range(3):
      for j in range(3):
        ijkToRAS.SetElement(i, j, pixel_size[i] * rasOrientation[3 * j + i])
      ijkToRAS.SetElement(i, 3, rasOrigin[i])

    newNode.SetIJKToRASMatrix(ijkToRAS)
    newNode.SetAndObserveImageData(v)
    slicer.mrmlScene.AddNode(newNode)

  def addFrames(self, patientId):
    """Add every NM frame (4th axis of frame_data) as '<patientId>Volume<it>'."""
    import vtkInterface
    print("NFrames: {}".format(self.frame_data.shape[3]))
    for it in range(self.frame_data.shape[3]):
      frame_data = self.frame_data[:, :, :, it]
      nodeName = patientId + 'Volume' + str(it)
      self.addNode(nodeName,
                   vtkInterface.numpyToVTK(frame_data, frame_data.shape),
                   self.frame_origin,
                   self.frame_pixel_size,
                   self.frame_orientation, 1)

  def addCT(self, patientId):
    """Add the CT data as volume '<patientId>CT'."""
    import vtkInterface
    nodeName = patientId + 'CT'
    self.addNode(nodeName,
                 vtkInterface.numpyToVTK(self.ct_data, self.ct_data.shape),
                 self.ct_origin, self.ct_pixel_size,
                 self.ct_orientation, 0)

  @staticmethod
  def rFromI(i, volumeNode):
    """Return the RAS position (mm) of point number i of volumeNode's image.

    NOTE(review): treats the image point coordinates as IJK, which holds
    when the image data has origin 0 / spacing 1 (as addNode sets).
    """
    ijkToRas = vtk.vtkMatrix4x4()
    volumeNode.GetIJKToRASMatrix(ijkToRas)
    vImage = volumeNode.GetImageData()
    i1 = list(vImage.GetPoint(i))
    # homogeneous coordinate; list.append returns None, so don't reassign
    i1.append(1)
    position_ras = ijkToRas.MultiplyPoint(i1)
    return position_ras[0:3]

  @staticmethod
  def IfromR(pos, volumeNode):
    """Return the id of the image point of volumeNode nearest to RAS pos.

    FindPoint returns -1 when pos lies outside the image. The same
    origin-0/spacing-1 assumption as rFromI applies.
    """
    fM = vtk.vtkMatrix4x4()
    volumeNode.GetRASToIJKMatrix(fM)
    ijk = fM.MultiplyPoint(list(pos[0:3]) + [1])
    vImage = volumeNode.GetImageData()
    # nearest neighbor
    return vImage.FindPoint(ijk[0:3])

  def getMaskPos(self, mask, i, imageToWorld=None):
    """Return the world (RAS, mm) position of point i of oriented image mask.

    :param imageToWorld: optional precomputed mask image-to-world matrix;
      pass it when calling in a loop to avoid rebuilding it per point.
    """
    maskIJK = mask.GetPoint(i)
    # point coordinates -> extent index, which ImageToWorldMatrix expects
    maskIJK = [r - c for r, c in zip(maskIJK, mask.GetOrigin())]
    maskIJK = [r / s for r, s in zip(maskIJK, mask.GetSpacing())]
    # homogeneous coordinate for vtkMatrix4x4
    maskIJK.append(1)

    if imageToWorld is None:
      imageToWorld = vtk.vtkMatrix4x4()
      mask.GetImageToWorldMatrix(imageToWorld)

    maskPos = imageToWorld.MultiplyPoint(maskIJK)
    return maskPos[0:3]

  def _getSegmentationNode(self):
    """Return the first segmentation node in the scene, or None."""
    segNodeList = slicer.mrmlScene.GetNodesByClass("vtkMRMLSegmentationNode")
    if segNodeList.GetNumberOfItems() == 0:
      return None
    return slicer.vtkMRMLSegmentationNode.SafeDownCast(
        segNodeList.GetItemAsObject(0))

  @staticmethod
  def _resolveSegmentID(segmentation, i):
    """Map a segment index (int or numeric text), ID or name to a segment ID.

    Returns '' when no segment matches.
    """
    try:
      return segmentation.GetNthSegmentID(int(i))
    except (TypeError, ValueError):
      pass
    if segmentation.GetSegment(i) is not None:
      return i
    return segmentation.GetSegmentIdBySegmentName(i)

  def meanROI(self, volName1, i):
    """Return the mean of volume volName1 over segment i.

    i may be a segment index, a segment ID or a segment name. The first
    segmentation node is used. Mask voxels are mapped into the volume
    through any transform on the volume node. The volume is sampled with
    degree-5 B-spline interpolation. Returns 0 if the segmentation, segment,
    volume or mask voxels are missing.
    """
    s = 0

    segNode = self._getSegmentationNode()
    if segNode is None:
      print("No segmentation node found")
      return s
    print("Found segmentation node: {}".format(segNode.GetName()))

    segment = self._resolveSegmentID(segNode.GetSegmentation(), i)
    if not segment:
      print("Segment {} not found".format(i))
      return s
    print('Computing for segment {}'.format(segment))

    mask = slicer.vtkOrientedImageData()
    segNode.GetBinaryLabelmapRepresentation(segment, mask)
    print("Got mask for segment {}, npts {}".format(
        segment, mask.GetNumberOfPoints()))

    dataNode = slicer.mrmlScene.GetFirstNodeByName(volName1)
    if dataNode is None:
      print("Volume {} not found".format(volName1))
      return s
    dataImage = dataNode.GetImageData()
    dataRAStoIJK = vtk.vtkMatrix4x4()
    dataNode.GetRASToIJKMatrix(dataRAStoIJK)

    # allow for interpolation in segmentation pixels
    coeff = vtk.vtkImageBSplineCoefficients()
    coeff.SetInputData(dataImage)
    coeff.SetBorderMode(vtk.VTK_IMAGE_BORDER_CLAMP)
    # between 3 and 5
    coeff.SetSplineDegree(5)
    coeff.Update()

    # hoisted out of the voxel loop (was rebuilt for every voxel)
    maskImageToWorldMatrix = vtk.vtkMatrix4x4()
    mask.GetImageToWorldMatrix(maskImageToWorldMatrix)

    ns = 0
    maskScalars = mask.GetPointData().GetScalars()
    for k in range(mask.GetNumberOfPoints()):
      # skip voxels outside the segment
      if maskScalars.GetTuple1(k) == 0:
        continue

      maskPos = self.getMaskPos(mask, k, maskImageToWorldMatrix)

      # account for a transform possibly applied to dataNode
      dataPos = [0, 0, 0]
      dataNode.TransformPointFromWorld(maskPos, dataPos)
      dataPos.append(1)
      dataIJK = dataRAStoIJK.MultiplyPoint(dataPos)[0:3]

      s += coeff.Evaluate(dataIJK)
      ns += 1

    if ns == 0:
      print("Segment {} has no voxels".format(segment))
      return s
    return s / ns

  def countSegments(self):
    """Return the number of segments of the first segmentation node (0 if none)."""
    segNode = self._getSegmentationNode()
    if segNode is None:
      return 0
    return segNode.GetSegmentation().GetNumberOfSegments()

  def getSegmentName(self, i):
    """Return the name of the i-th segment, or "NONE" if it does not exist."""
    segNode = self._getSegmentationNode()
    if segNode is None:
      return "NONE"
    segmentation = segNode.GetSegmentation()
    segment = segmentation.GetSegment(segmentation.GetNthSegmentID(i))
    if segment is None:
      return "NONE"
    return segment.GetName()

  def storeNodeRemote(self, pathList, nodeName):
    """Save node nodeName to tempPath and, if cfg["remote"], upload it.

    The suffix follows the node type: .mcsv for double arrays, .h5 for any
    transform node (IsA also matches subclasses: linear, grid, b-spline),
    .nrrd otherwise.
    """
    node = slicer.mrmlScene.GetFirstNodeByName(nodeName)
    if node is None:
      print("Node {} not found".format(nodeName))
      return

    if node.IsA("vtkMRMLDoubleArrayNode"):
      suffix = ".mcsv"
    elif node.IsA("vtkMRMLTransformNode"):
      suffix = ".h5"
    else:
      suffix = ".nrrd"

    fileName = nodeName + suffix
    localPath = os.path.join(self.tempPath, fileName)

    slicer.util.saveNode(node, localPath)
    print("Stored to: {}".format(localPath))
    if self.cfg["remote"]:
      labkeyPath = self.fb.buildPathURL(self.cfg['project'], pathList)
      print("Remote: {}".format(labkeyPath))
      remoteFile = labkeyPath + '/' + fileName
      self.fb.writeFileToFile(localPath, remoteFile)

  def storeVolumeNodes(self, patientId, n1, n2):
    """Store the CT, frames n1..n2 (inclusive) and the dummy input function.

    Resampled copies ('<name>_RS') are stored as well when they exist.
    """
    pathList = [patientId]

    print("Store CT")
    nodeName = patientId + 'CT'
    self.storeNodeRemote(pathList, nodeName)
    if slicer.util.getFirstNodeByName(nodeName + "_RS"):
      self.storeNodeRemote(pathList, nodeName + "_RS")

    print("Storing NM from {} to {}".format(n1, n2))
    for it in range(n1, n2 + 1):
      nodeName = patientId + 'Volume' + str(it)
      self.storeNodeRemote(pathList, nodeName)
      if slicer.util.getFirstNodeByName(nodeName + "_RS"):
        self.storeNodeRemote(pathList, nodeName + "_RS")

    self.storeDummyInputFunction(patientId)

  def storeSegmentation(self, patientId):
    self.storeNodeRemote([patientId], "Heart")

  def storeInputFunction(self, patientId):
    self.calculateInputFunction(patientId)
    self.storeNodeRemote([patientId], patientId + '_Ventricle')

  def storeDummyInputFunction(self, patientId):
    self.calculateDummyInputFunction(patientId)
    self.storeNodeRemote([patientId], patientId + '_Dummy')

  def storeTransformation(self, patientId):
    self.storeNodeRemote([patientId], patientId + "_DF")

  def applyTransform(self, patientId, refPatientId, n1, n2):
    """Resample frames n1..n2 and the CT through transform '<patientId>_DF'.

    Each existing node is temporarily placed under the transform and passed
    to resampler.rebinNode together with the matching reference-patient node
    (when that exists). The transform is then detached again.
    """
    if patientId == refPatientId:
      print("Transform [{}] and reference [{}] are the same".format(
          patientId, refPatientId))
      return
    transformNodeName = patientId + "_DF"
    transformNode = slicer.util.getFirstNodeByName(transformNodeName)
    if transformNode is None:
      print("Transform node [{}] not found".format(transformNodeName))
      return

    for it in range(n1, n2 + 1):
      node = slicer.util.getFirstNodeByName(patientId + 'Volume' + str(it))
      if node is None:
        continue
      node.SetAndObserveTransformNodeID(transformNode.GetID())
      refNode = slicer.util.getFirstNodeByName(
          refPatientId + 'Volume' + str(it))
      if refNode is not None:
        self.resampler.rebinNode(node, refNode)
      print("Completed transformation {}".format(it))
      # None detaches; the string 'NONE' would leave a dangling reference
      node.SetAndObserveTransformNodeID(None)

    node = slicer.util.getFirstNodeByName(patientId + 'CT')
    if node is not None:
      node.SetAndObserveTransformNodeID(transformNode.GetID())
      refNode = slicer.util.getFirstNodeByName(refPatientId + 'CT')
      if refNode is not None:
        self.resampler.rebinNode(node, refNode)
      node.SetAndObserveTransformNodeID(None)

  def _getOrCreateDoubleArrayNode(self, name):
    """Return the double-array node called name, adding one if needed."""
    dns = slicer.mrmlScene.GetNodesByClassByName('vtkMRMLDoubleArrayNode', name)
    if dns.GetNumberOfItems() == 0:
      dn = slicer.mrmlScene.AddNode(slicer.vtkMRMLDoubleArrayNode())
      dn.SetName(name)
      return dn
    return dns.GetItemAsObject(0)

  def calculateInputFunction(self, patientId):
    """Fill double array '<patientId>_Ventricle' with the ventricle curve.

    Column 0 = frame time, column 1 = mean in the 'Ventricle' segment divided
    by the frame duration, column 2 = 0. If there is no segmentation or no
    'Ventricle' segment, the array is left sized but unfilled.
    """
    debug = True
    n = len(self.frame_time)

    dn = self._getOrCreateDoubleArrayNode(patientId + '_Ventricle')
    dn.SetSize(n)

    segNode = self._getSegmentationNode()
    if segNode is None:
      return
    segmentation = segNode.GetSegmentation()

    juse = -1
    for j in range(segmentation.GetNumberOfSegments()):
      segment = segmentation.GetSegment(segmentation.GetNthSegmentID(j))
      if segment.GetName() == 'Ventricle':
        juse = j
        break

    if juse < 0:
      print('Failed to find Ventricle segment')
      return

    t0 = 0
    ft = self.frame_time
    dt = self.frame_duration
    for i in range(n):
      vol = patientId + "Volume" + str(i)
      fx = ft[i]
      fy = self.meanROI(vol, juse)
      if debug:
        # print this frame's duration, not the whole array
        print('{}: t0={} tp={} dt={}'.format(i, t0, fx, dt[i]))
      t0 += dt[i]

      dn.SetValue(i, 0, fx)
      dn.SetValue(i, 1, fy / dt[i])
      dn.SetValue(i, 2, 0)
      print("[{0} at {1:.2f}:{2:.2f}]".format(vol, fx, fy))

  def calculateDummyInputFunction(self, patientId):
    """Fill double array '<patientId>_Dummy' with frame time and duration.

    Uses len(frame_time) for the frame count, so it also works after an
    NRRD load, which never sets frame_data.
    """
    n = len(self.frame_time)

    dn = self._getOrCreateDoubleArrayNode(patientId + '_Dummy')
    dn.SetSize(n)

    ft = self.frame_time
    dt = self.frame_duration
    for i in range(n):
      dn.SetValue(i, 0, ft[i])
      dn.SetValue(i, 1, dt[i])
      dn.SetValue(i, 2, 0)



class cardiacSPECTTest(ScriptedLoadableModuleTest):
  """Self-test for the cardiac SPECT module; currently only a smoke test.

  Uses ScriptedLoadableModuleTest base class, available at:
  https://github.com/Slicer/Slicer/blob/master/Base/Python/slicer/ScriptedLoadableModule.py
  """

  def setUp(self):
    """Start every test case from an empty MRML scene."""
    slicer.mrmlScene.Clear(0)

  def runTest(self):
    """Reset the scene, then run each test case in turn."""
    self.setUp()
    self.test_cardiacSPECT1()

  def test_cardiacSPECT1(self):
    """Smoke test: announce the start and the success; no data is exercised.

    TODO: load a small patient data set and check the logic's results.
    """
    for message in ("Starting the test", 'Test passed!'):
      self.delayDisplay(message)