import numpy
import json
import os
import scipy.interpolate
#for partial function specializations
import functools
import function
import importlib
importlib.reload(function)

class model:
   """Linear compartment model with sensitivity-equation (SE) support.

   ``parse`` reads a model setup JSON and a parameter JSON and then calls
   ``build``, which assembles

   * ``fM``/``dM``: constant and time dependent (callable) entries of the
     system matrix, indexed [equation, variable] through ``lut``;
   * ``fu``/``du``: constant and time dependent inputs;
   * ``SS``/``qSS`` and ``fSu``: weighted parameter derivatives of the
     system matrix and of the inputs (see ``buildSE``).
   """

   #length of each supported time unit in seconds (see _convertRate)
   _SECONDS_PER_UNIT={'s':1,'min':60,'hour':3600,'day':86400}

   def __init__(self):
      #compartment name -> {'targets':..., 'sensTargets':..., optional 'input'}
      self.compartments={}
      #parName -> equation compartment -> variable compartment -> list of terms
      self.seJ={}
      #compartments whose equations are divided by the 'total' compartment
      self.scaled=[]

   def add_input(self,compartmentName,parameterName):
      """Attach the input parameter parameterName to compartmentName."""
      self.compartments[compartmentName]['input']=parameterName

   def add_compartment(self,compartmentName):
      """Register (or reset) compartmentName with empty target tables."""
      self.compartments[compartmentName]={'targets':{},'sensTargets':{}}

   def getTimeUnit(self):
      """Return the model time unit from the setup file ('s' if absent)."""
      return self.mod.get('timeUnit','s')

   def bind(self,src,target,qName,pcName):
      """Establish a flow from compartment src to compartment target.

      The rate is given by the parameter qName and the partition coefficient
      by pcName. The term -q/(v_src*pc) is added to the diagonal of the src
      equation, and +q/(v_target*pc) to the src column of the target
      equation. A target named 'dump' only removes material from src.
      """
      #source equation: the diagonal element (current is subtracted)
      srcVName=self.getVolumePar(src)
      pSrc=self.couplingObject(-1,qName,pcName,srcVName)
      #value and derivatives are registered together
      self.addValueObject(src,src,pSrc)

      #special target which is not part of the calculation
      if target=='dump':
         return

      #target equation: the current is added, the variable is src
      targetVName=self.getVolumePar(target)
      pTarget=self.couplingObject(1,qName,pcName,targetVName)
      self.addValueObject(target,src,pTarget)

   def addValueObject(self,targetName,srcName,cObject):
      """Register cObject in equation targetName for variable srcName.

      cObject["value"] goes to the compartment's 'targets' table. Each entry
      of cObject["derivatives"] goes to seJ[parameter][targetName][srcName].
      """
      targetList=self.compartments[targetName]['targets']
      addValue(targetList,srcName,cObject["value"])
      der=cObject["derivatives"]
      for d in der:
         targetSE=self.getSEJ_comp(d,targetName)
         addValue(targetSE,srcName,der[d])

   def couplingObject(self,sign,qParName,pcParName,vParName):
      """Return the coupling term sign*q/(v*pc) together with its derivatives.

      q, pc and v are the parameters named qParName, pcParName and vParName.
      If any of them is time dependent, the value and the derivatives are
      built as callables of t through the function module. Otherwise a
      numeric derived object is returned. The partial derivatives are
      df/dpc=-f/pc, df/dq=sign/(v*pc)=f/q and df/dv=-f/v.
      """
      qPar=self.get(qParName)
      pcPar=self.get(pcParName)
      vPar=self.get(vParName)
      q=qPar['value']
      pc=pcPar['value']
      v=vPar['value']
      DPC=pcPar['derivatives']
      DQ=qPar['derivatives']
      DV=vPar['derivatives']

      #NOTE(review): this looks for a 'function' key in the objects returned
      #by self.get, while build() uses function.isFunction -- confirm they agree
      if any('function' in p for p in (qPar,pcPar,vPar)):
         fq=function.to(q)
         fpc=function.to(pc)
         fv=function.to(v)
         #default arguments freeze the current objects inside the closures
         f=lambda t,q=fq,pc=fpc,v=fv,s=sign: s*q(t)/v(t)/pc(t)
         dfdPC=lambda t,f=f,pc=fpc: -f(t)/pc(t)
         dPC=function.generate(dfdPC,DPC)
         dfdQ=lambda t,f=f,q=fq: f(t)/q(t)
         dQ=function.generate(dfdQ,DQ)
         dfdV=lambda t,f=f,v=fv: -f(t)/v(t)
         dV=function.generate(dfdV,DV)
         return function.Object(f,[dPC,dQ,dV])

      f=sign*q/v/pc
      return function.derivedObject(f,
         [{'df':-f/pc,'D':DPC},
          {'df':sign/v/pc,'D':DQ},
          {'df':-f/v,'D':DV}])

   def getVolumePar(self,compartment):
      """Return the volume parameter name of compartment, or 'one' if none.

      'one' is a default parameter with value 1 (see add_default_parameters).
      """
      try:
         return self.mod["volumes"][compartment]
      except KeyError:
         #also covers a setup without any 'volumes' section
         return "one"

   def build(self):
      """Assemble system matrix, inputs and SE structures from the bindings.

      Constant entries go to fM/fu. Entries that contain functions go to
      dM/du as callables of t. The 'total' compartment receives the sum of
      all input rates. Finally buildSE is called.
      """
      comps=self.compartments
      self.n=len(comps)
      #numeric representation of the input
      self.fu=numpy.zeros((self.n))
      #input index -> time dependent input function object
      self.du={}
      self.lut={c:i for (i,c) in enumerate(comps)}
      #equation index -> variable index -> callable of t
      self.dM={}
      self.fM=numpy.zeros((self.n,self.n))
      self.inputDerivatives={}
      self.uTotal=[]
      #scaling caches refer to the matrices built here; drop stale ones
      self.originalK={}
      self.originalSS={}
      for c,comp in comps.items():
         i=self.lut[c]
         if 'input' in comp:
            qs=self.get(comp['input'])
            self.uTotal.append(qs["value"])
            qV=self.getVolumePar(c)
            #input is a quotient (amount per unit time per volume(mass)
            #of the input compartment)
            qs1=function.ratio(qs,self.get(qV))
            if function.isFunction(qs1):
               self.du[i]=qs1
            else:
               self.fu[i]=qs1['value']
               #let buildSE know it has to include these derivatives
               self.inputDerivatives[c]=qs1['derivatives']

         for t,arr in comp['targets'].items():
            j=self.lut[t]
            if function.contains(arr):
               self.dM.setdefault(i,{})[j]=function.sumArray(arr)
            else:
               self.fM[i,j]=sum(arr)

      #total input; derivatives are ignored, uTotal is just a scaling shorthand
      it=self.lut['total']
      if function.contains(self.uTotal):
         self.du[it]=function.Object(function.sumArray(self.uTotal),[])
      else:
         self.fu[it]=sum(self.uTotal)

      self.buildSE()

   def buildSE(self):
      """Build the weighted parameter derivatives used by the SE system.

      Only parameters for which calculateDerivative is true (they declare a
      'dist') are kept. Their order defines lutSE. Derivatives are scaled by
      sqrt(getWeight), i.e. by cv*|value|. This method sets SS (constant
      part), qSS (callables), fSu (constant input derivatives) and the
      block-diagonal fJ.
      """
      pars=self.parSetup['parameters']
      #candidates: parameters with Jacobi derivatives and with input derivatives
      parCandidates=list(self.seJ.keys())
      for D in self.inputDerivatives.values():
         parCandidates.extend(D.keys())
      for x in self.du.values():
         parCandidates.extend(x['derivatives'].keys())
      #drop duplicates but keep first-seen order, so lutSE is reproducible
      #between runs (set() order depends on string hash randomisation)
      parCandidates=list(dict.fromkeys(parCandidates))
      parList=[p for p in parCandidates if calculateDerivative(pars[p])]

      self.m=len(parList)
      self.lutSE={c:i for (i,c) in enumerate(parList)}
      w=numpy.sqrt(self.getWeights(self.lutSE))
      #weights are fixed after build; Su(t) reuses them
      self.wSE=w

      self.qSS={}
      self.SS=numpy.zeros((self.m,self.n,self.n))
      #elements of SS are w_p*dM_i,j/dp
      for parName in parList:
         if parName not in self.seJ:
            continue
         k=self.lutSE[parName]
         for compartment,targets in self.seJ[parName].items():
            for t,arr in targets.items():
               i=self.lut[compartment]
               j=self.lut[t]
               if not function.contains(arr):
                  self.SS[k,i,j]=w[k]*sum(arr)
                  continue
               ft=function.sumArray(arr,w[k])
               self.qSS.setdefault(k,{}).setdefault(i,{})[j]=ft

      #derivatives of constant inputs; time dependent ones are set in Su(t)
      self.fSu=numpy.zeros((self.m,self.n))
      for x,D in self.inputDerivatives.items():
         for p in D:
            if p in self.lutSE:
               k=self.lutSE[p]
               self.fSu[k,self.lut[x]]=D[p]*w[k]

      #static part of the Jacobian: m+1 copies of fM on the block diagonal
      N=self.n*(self.m+1)
      self.fJ=numpy.zeros((N,N))
      for b in range(self.m+1):
         s=slice(b*self.n,(b+1)*self.n)
         self.fJ[s,s]=self.fM

   def inspect(self):
      """Print compartments, flows, volumes and partition coefficients."""
      comps=self.compartments
      tu=self.getTimeUnit()
      print('Time unit: {}'.format(tu))
      print('Compartments')
      for c,comp in comps.items():
         print('{}/{}:'.format(c,self.lut[c]))
         if 'input' in comp:
            print('\tinput\n\t\t{}'.format(comp['input']))
         print('\ttargets')
         for t in comp['targets']:
            print('\t\t{}[{},{}]: {}'.format(t,self.lut[c],self.lut[t],
               comp['targets'][t]))
      print('Flows')
      for f,fName in self.flows.items():
         fParName=self.mod['flows'][fName]
         print('\t{}[{}]:{} [{}]'.format(f,fName,fParName,self.get(fParName)))

      print('Volumes')
      for v,vParName in self.mod.get('volumes',{}).items():
         print('\t{}:{} [{}]'.format(v,vParName,self.get(vParName)))

      print('Partition coefficients')
      for pc,pcParName in self.mod.get('partitionCoefficients',{}).items():
         print('\t{}:{} [{}]'.format(pc,pcParName,self.get(pcParName)))

   def inspectSE(self):
      """Print the SE Jacobi bindings collected per parameter."""
      print('SE parameters')
      for p,sources in self.seJ.items():
         print(p)
         for compartment,targets in sources.items():
            for t in targets:
               print('\t SE bind {}/{}:{}'.format(compartment,t,targets[t]))

   def parse(self,setupFile,parameterFile):
      """Read the setup and parameter JSON files and build the model.

      The setup file provides 'compartments', 'inputs', 'flows' and
      'bindings' ('diffusion' and 'flow'). It may also provide 'scaled',
      'volumes', 'partitionCoefficients' and 'timeUnit'. A 'total'
      compartment is appended to the compartment list.
      """
      with open(setupFile,'r') as f:
         self.mod=json.load(f)
      with open(parameterFile,'r') as f:
         self.parSetup=json.load(f)

      self.mod['compartments'].append('total')
      for m in self.mod['compartments']:
         self.add_compartment(m)
      self.scaled.extend(self.mod.get('scaled',[]))

      #standard parameters such as one, zero, two
      self.add_default_parameters()
      for s,parName in self.mod['inputs'].items():
         self.add_input(s,parName)

      #map every 'src:target' pair to the 'flows' key it was declared under
      self.flows={}
      for f in self.mod['flows']:
         #keys without ':' are comments
         if ':' not in f:
            continue
         comps=f.split(':')
         for x in splitVector(comps[0]):
            for y in splitVector(comps[1]):
               self.flows['{}:{}'.format(x,y)]=f

      for b,kParName in self.mod['bindings']['diffusion'].items():
         comps=b.split('->')
         pcParName=self._partitionCoefficientPar(b)
         #operate with names to allow for value/function/derived parameters
         self.bind(comps[0],comps[1],kParName,pcParName)

      for q in self.mod['bindings']['flow']:
         comps=q.split('->')
         srcs=splitVector(comps[0])
         tgts=splitVector(comps[1])
         for cs in srcs:
            for ct in tgts:
               pcParName=self._partitionCoefficientPar(cs)
               #the flow may be declared in either direction
               try:
                  qName=self.flows['{}:{}'.format(cs,ct)]
               except KeyError:
                  qName=self.flows['{}:{}'.format(ct,cs)]
               flowParName=self.mod['flows'][qName]
               self.bind(cs,ct,flowParName,pcParName)

      self.build()

   def _partitionCoefficientPar(self,key):
      """Return the partition coefficient parameter name for key, or 'one'."""
      try:
         return self.mod['partitionCoefficients'][key]
      except KeyError:
         return "one"

   def add_default_parameters(self):
      """Add the constant parameters 'one', 'zero' and 'two'."""
      pars=self.parSetup['parameters']
      pars['one']={'value':1}
      pars['zero']={'value':0}
      pars['two']={'value':2}

   @staticmethod
   def _noState(y):
      """Return True when no state vector y was supplied (None or empty)."""
      return y is None or numpy.size(y)==0

   def _unscaledCopy(self,cacheName,key,current):
      """Return a copy of the cached unscaled array, caching current on first use."""
      cache=self.__dict__.setdefault(cacheName,{})
      if key not in cache:
         cache[key]=numpy.copy(current)
      return numpy.copy(cache[key])

   def M(self,t,y=None):
      """Return the system matrix at time t.

      When a state vector y is given (not None and not empty), the rows of
      the scaled compartments are scaled through set_scaledM.
      """
      #undo the scaling a previous call with y applied to fM
      for i,row in getattr(self,'originalK',{}).items():
         self.fM[i,:]=row
      for i,row in self.dM.items():
         for j,fn in row.items():
            self.fM[i,j]=fn(t)
      if self._noState(y):
         return self.fM
      self.set_scaledM(t,y)
      return self.fM

   def set_scaledM(self,t,y):
      """Divide the equations of the scaled compartments by the total amount.

      For every compartment in self.scaled, the unscaled row of fM (cached
      on first use) is refreshed with the current time dependent entries.
      Its diagonal is then reduced by the total input rate u(t)[total], and
      the row is divided by y[total] (plus eps to avoid division by zero).
      """
      eps=1e-8
      for c in self.scaled:
         it=self.lut['total']
         i=self.lut[c]
         k=self._unscaledCopy('originalK',i,self.fM[i,:])
         #the cache holds time dependent entries from its first use only;
         #re-evaluate them at t so they do not stay frozen
         for j,fn in self.dM.get(i,{}).items():
            k[j]=fn(t)
         k[i]=k[i]-self.u(t)[it]
         self.fM[i,:]=k/(y[it]+eps)

   def u(self,t):
      """Return the input vector at time t (time dependent entries refreshed)."""
      for x,obj in self.du.items():
         self.fu[x]=obj['value'](t)
      return self.fu

   def Su(self,t):
      """Return the weighted input derivatives (m x n) at time t.

      Derivatives with respect to parameters that were not selected for the
      SE analysis in buildSE (no row in lutSE) are skipped.
      """
      w=getattr(self,'wSE',None)
      if w is None:
         w=numpy.sqrt(self.getWeights(self.lutSE))
      for x,obj in self.du.items():
         D=obj['derivatives']
         for p in D:
            k=self.lutSE.get(p)
            if k is None:
               continue
            self.fSu[k,x]=w[k]*D[p](t)
      return self.fSu

   def jacobiFull(self,t):
      """Return the block-diagonal Jacobian of the coupled system at time t.

      fJ (built in buildSE) holds m+1 copies of the n x n system matrix.
      Only the time dependent entries from dM are refreshed here.
      NOTE(review): the row scaling of set_scaledM is not applied here.
      """
      for i,row in self.dM.items():
         for j,fn in row.items():
            value=fn(t)
            for b in range(self.m+1):
               self.fJ[b*self.n+i,b*self.n+j]=value
      return self.fJ

   def fSS(self,t,y=None):
      """Return the weighted matrix derivatives SS (m x n x n) at time t.

      When a state vector y is given, the scaled rows are scaled through
      set_scaledSS.
      """
      #undo the scaling a previous call with y applied to SS
      for i,block in getattr(self,'originalSS',{}).items():
         self.SS[:,i,:]=block
      for k,rows in self.qSS.items():
         for i,cols in rows.items():
            for j,fn in cols.items():
               self.SS[k,i,j]=fn(t)
      if self._noState(y):
         return self.SS
      self.set_scaledSS(t,y)
      return self.SS

   def set_scaledSS(self,t,y):
      """Divide SS[:,i,:] of every scaled compartment i by y[total] (+eps).

      The unscaled slice is cached on first use and refreshed with the
      current time dependent entries from qSS.
      NOTE(review): the original notes that errors on u should be added;
      input derivatives are not scaled here.
      """
      eps=1e-8
      for c in self.scaled:
         it=self.lut['total']
         i=self.lut[c]
         dkdp=self._unscaledCopy('originalSS',i,self.SS[:,i,:])
         for k,rows in self.qSS.items():
            for j,fn in rows.get(i,{}).items():
               dkdp[k,j]=fn(t)
         self.SS[:,i,:]=dkdp/(y[it]+eps)

   def fSY(self,y,t):
      """Return the n x (m+1) driving term for simultaneous calculation.

      Column 0 is u(t). Columns 1..m are (SS(t,y)·y)^T, i.e. the parameter
      derivatives of the system matrix applied to the state y.
      """
      #SS is m x n x n, so qS is m x n
      qS=self.fSS(t,y).dot(y)
      tS=numpy.zeros((self.n,self.m+1))
      tS[:,1:]=numpy.transpose(qS)
      tS[:,0]=self.u(t)
      return tS

   def fS(self,t):
      """Return (SS(t)·Y(t))^T, using the solution interpolated by setY."""
      qS=self.fSS(t).dot(self.getY(t))
      return numpy.transpose(qS)

   def getSEJ(self,parName):
      """Return (creating if needed) the SE Jacobi table for parName."""
      return self.seJ.setdefault(parName,{})

   def getSEJ_comp(self,parName,compartmentName):
      """Return (creating if needed) seJ[parName][compartmentName]."""
      return self.getSEJ(parName).setdefault(compartmentName,{})

   def setY(self,t,y):
      """Store interpolating splines (s=0) of every column y[:,i] over t."""
      self.tck=[scipy.interpolate.splrep(t,y[:,i],s=0) for i in range(self.n)]

   def getY(self,t):
      """Return the interpolated state vector (length n) at time t."""
      fY=numpy.zeros(self.n)
      for i,tck in enumerate(self.tck):
         fY[i]=scipy.interpolate.splev(t,tck,der=0)
      return fY

   def getWeight(self,parName):
      """Return the variance (cv*value)^2 of parameter parName.

      The value is taken from self.get, i.e. after unit conversion.
      NOTE: a log-normal variant (sigma^2_lnx=log(cv^2+1) times value^2) was
      sketched in earlier code but is not used.
      """
      par=self.parSetup['parameters'][parName]
      v=self.get(parName)["value"]
      return par["cv"]*par["cv"]*v*v

   @staticmethod
   def getMax(lutSE):
      """Return the largest index in lutSE, or -1 when it is empty."""
      return max((int(x) for x in lutSE.values()),default=-1)

   def getWeights(self,lutSE):
      """Return the weights of all parameters in lutSE, ordered by index."""
      wts=numpy.zeros((model.getMax(lutSE)+1))
      for parName,j in lutSE.items():
         wts[j]=self.getWeight(parName)
      return wts

   def getVolumes(self):
      """Return the volume values of all compartments, ordered by lut index."""
      m=numpy.zeros((len(self.lut)))
      for p,idx in self.lut.items():
         m[idx]=self.getVolume(p)
      return m

   def getVolume(self,p):
      """Return the volume value of compartment p."""
      pV=self.getVolumePar(p)
      return self.get(pV)['value']

   def getDerivatives(self,se,i):
      """Return the squared entries se[-1][i].

      Presumably se is indexed [time][compartment] -> m-vector, which makes
      this the squared derivatives at the latest point (TODO confirm).
      """
      fse=se[-1][i]
      return fse*fse

   def calculateUncertainty(self,se):
      """Return sqrt of the sum of squares of se over its last (length m) axis.

      The SE derivatives are already weighted in buildSE, so unit weights
      are used here.
      """
      se2=numpy.multiply(se,se)
      w=numpy.ones((self.m))
      return numpy.sqrt(numpy.dot(se2,w))

   def get(self,parName):
      """Return the value/function/derived object for parameter parName.

      Prints a message and returns None when the parameter has none of the
      'value', 'function' or 'derived' keys.
      """
      pars=self.parSetup['parameters']
      par=pars[parName]
      #getValue uses the name as the key of the derivative
      par['name']=parName
      if "value" in par:
         return self.getValue(par)
      if "function" in par:
         return self.getFunction(par)
      if "derived" in par:
         return self.getDerived(par)
      print('Paramter {} not found!'.format(parName))
      return None

   @classmethod
   def _convertRate(cls,v,parTimeUnit,timeUnit):
      """Convert a rate v given per parTimeUnit to a rate per timeUnit.

      Supported units are the keys of _SECONDS_PER_UNIT. Any other unit
      leaves v unchanged.
      """
      sec=cls._SECONDS_PER_UNIT
      try:
         pSec=sec[parTimeUnit]
         tSec=sec[timeUnit]
      except KeyError:
         #no idea what to do
         return v
      #the supported units divide each other, so integer ratios are exact
      if tSec>=pSec:
         return v*(tSec//pSec)
      return v/(pSec//tSec)

   def getValue(self,par):
      """Return a value object for par, converted to the model time unit.

      The time unit of a rate is the part after the first '/' in
      par['unit'] (e.g. 'mL/min'). Values without such a unit are returned
      unchanged.
      """
      v=par["value"]
      parName=par['name']
      try:
         parUnits=par['unit'].split('/')
      except (KeyError,AttributeError):
         #no unit given
         return valueObject(v,parName)
      if len(parUnits)<2:
         #no / in unit name
         return valueObject(v,parName)
      return valueObject(self._convertRate(v,parUnits[1],self.getTimeUnit()),
         parName)

   def getFunction(self,par):
      """Return the function object described by parSetup['functions'].

      Every entry of the description except 'type' names a parameter that
      is resolved with self.get. Prints a message and returns None for an
      unknown type.
      """
      fName=par['function']
      df=self.parSetup['functions'][fName]
      skip=['type']
      par1={x:self.get(df[x]) for x in df if x not in skip}
      if df['type']=='linearGrowth':
         return function.linearGrowth(par1)
      if df['type']=='linearGrowthFixedSlope':
         return function.linearGrowthFixedSlope(par1)
      if df['type']=='exp':
         return function.exp(par1)
      print('Function {}/{} not found!'.format(fName,df))
      return None

   def getDerived(self,par):
      """Return a derived parameter (product, power, ratio or sum).

      Prints a message and returns None for an unknown type.
      """
      dName=par['derived']
      d=self.parSetup['derivedParameters'][dName]
      if d['type']=='product':
         return function.product(self.get(d['a']),self.get(d['b']))
      if d['type']=='power':
         return function.power(self.get(d['a']),self.get(d['n']))
      if d['type']=='ratio':
         return function.ratio(pA=self.get(d['a']),pB=self.get(d['b']))
      if d['type']=='sum':
         return function.add(pA=self.get(d['a']),pB=self.get(d['b']))
      print('Derived parameter {}/{} not found!'.format(dName,d))
      return None

def calculateDerivative(par):
   """Return True when the parameter dict par requests SE derivatives.

   A parameter takes part in the sensitivity analysis only if it declares
   a distribution under the key 'dist'.
   """
   hasDistribution='dist' in par
   return hasDistribution

def valueObject(v,parName):
   """Wrap a plain value into a value object.

   The derivative with respect to the parameter itself is 1. A fresh dict
   is returned on every call, so callers may mutate it safely.
   """
   return {'value':v,'derivatives':{parName:1}}

def splitVector(v):
   """Split a compartment specification into a list of names.

   A plain name gives a one-element list. A parenthesised, comma separated
   vector such as '(a,b)' gives its elements. Whitespace around every name
   is stripped and empty elements are dropped, so '(a, b,)' yields
   ['a', 'b'].
   """
   if '(' not in v:
      return [v.strip()]
   inner=v.strip()
   #remove the enclosing parentheses of the '(x,y,...)' form
   if inner.startswith('('):
      inner=inner[1:]
   if inner.endswith(')'):
      inner=inner[:-1]
   return [x.strip() for x in inner.split(',') if x.strip()]

def addValue(qdict,compName,v):
   """Append v to the list stored under compName in qdict.

   Terms may be callables that cannot be summed directly. Each entry of
   qdict is therefore a list, and its elements are combined later, when
   the matrices are generated.
   """
   bucket=qdict.setdefault(compName,[])
   bucket.append(v)




def get(timeUnit,par):
   """Return par["value"] converted to a rate per timeUnit.

   The time unit of the parameter is the part after the first '/' in
   par['unit'] (e.g. 'mL/min'). Supported units are 's', 'min', 'hour' and
   'day'. Values without a unit, without a '/' or with an unknown unit are
   returned unchanged.
   """
   v=par["value"]
   try:
      parUnits=par['unit'].split('/')
   except (KeyError,AttributeError):
      #no unit given
      return v
   if len(parUnits)<2:
      #no / in unit name
      return v
   secondsPerUnit={'s':1,'min':60,'hour':3600,'day':86400}
   try:
      pSec=secondsPerUnit[parUnits[1]]
      tSec=secondsPerUnit[timeUnit]
   except KeyError:
      #no idea what to do
      return v
   #the supported units divide each other, so integer ratios are exact
   if tSec>=pSec:
      return v*(tSec//pSec)
   return v/(pSec//tSec)