import numpy
import scipy.optimize
import scipy.linalg   #provides sqrtm, used in generateGauss
import functools
import fitModel

class initialValueGenerator:
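   """Draw randomized starting values (and bounds) for model fitting.

   Each call to add() registers one parameter together with the distribution
   its starting value is drawn from and the bounds handed to the optimizer;
   generate() then draws a complete starting vector."""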
   
   def __init__(self):
      #map distribution names to the draw functions below
      self.map={'gaus':initialValueGenerator.drawGauss,
         'loggaus':initialValueGenerator.drawLogGauss,
         'poisson':initialValueGenerator.drawPoisson,
         'flat':initialValueGenerator.drawFlat,
         'const':initialValueGenerator.drawConst}
      self.vals=[]

   def add(self,gtype='gaus',parameters=[0,1],bounds=[-numpy.inf,numpy.inf]):
      #register a parameter: prior type, prior parameters and fit bounds
      self.vals.append({'type':gtype,'parameters':parameters,'bounds':bounds})

   def getN(self):
      return len(self.vals)
   
   @staticmethod
   def drawGauss(p):
      #p=[mean,sigma]
      return numpy.random.normal(p[0],p[1])

   @staticmethod
   def drawLogGauss(p):
      #draw the exponent from a Gaussian and return 10**x
      return numpy.power(10,initialValueGenerator.drawGauss(p))

   @staticmethod
   def drawPoisson(p):
      #p=[lambda]
      return numpy.random.poisson(p[0])

   @staticmethod
   def drawFlat(p):
      #uniform draw on [p[0],p[1])
      return p[0]+numpy.random.random()*(p[1]-p[0])

   @staticmethod
   def drawConst(p):
      #fixed starting value p[0]
      return p[0]

   def draw(self,typ,p,lb,ub):
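      #rejection sampling: redraw until the value lands inside [lb,ub]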
      f=self.map[typ]
      while True:
         x=f(p)
         if x<lb:
            continue
         if x>ub:
            continue
         return x
    

   def generate(self):
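      #draw a starting value for every registered parameter; returns (par,(lb,ub))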
      n=len(self.vals)
      par=numpy.zeros(n)
      lb=numpy.zeros(n)
      ub=numpy.zeros(n)
      for i in range(n):
         x=self.vals[i]
         p=x['parameters']
         lb[i]=x['bounds'][0]
         ub[i]=x['bounds'][1]
         par[i]=self.draw(x['type'],p,lb[i],ub[i])
      return par, (lb,ub)

   def getBoundsScalar(self):
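      #per-parameter [lb,ub] pairs, used as the bounds argument of the global optimizers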
      gbounds=[]
      for x in self.vals:
         gbounds.append(x['bounds'])
      return gbounds

def generateIVF():
    #IVF (input function) priors; same parameters as generateIVFFinite,
    #but with (mostly) unbounded fit ranges
    ig=initialValueGenerator()
    #A
    ig.add('poisson',[200],[0,numpy.inf])
    #tau
    ig.add('loggaus',[1,0.5],[0,numpy.inf])
    #alpha
    ig.add('loggaus',[-1,1],[0,numpy.inf])
    #dt
    ig.add('const',[6,1],[0,numpy.inf])
    #B
    ig.add('poisson',[200],[0,numpy.inf])
    #gamma
    ig.add('loggaus',[-2,1],[0,0.1])
    return ig

def generateIVFFinite():
    ig=initialValueGenerator()
    #A
    ig.add('poisson',[200],[0,10000])
    #tau
    ig.add('loggaus',[1,0.5],[0,40])
    #alpha
    ig.add('loggaus',[-1,1],[0,5])
    #dt
    ig.add('const',[6,1],[6,50])
    #B
    ig.add('poisson',[200],[0,10000])
    #gamma
    ig.add('loggaus',[-2,1],[0,0.1])
    return ig


def generateCModel():
    #generate approx candidate for the compartment model; same parameters as
    #generateCModelFinite, with a slightly different dt prior and bounds
    ig=initialValueGenerator()
    #k1
    ig.add('loggaus',[-3,1],[0,0.1])
    #BVF
    ig.add('flat',[0,1],[0,1])
    #k2
    ig.add('loggaus',[-7,2],[0,0.1])
    #dt
    ig.add('gaus',[10,5],[0,30])
    return ig

def generateCModelFinite():
    #generate approx candidate
    ig=initialValueGenerator()
    #k1
    ig.add('loggaus',[-3,1],[0,0.1])
    #BVF
    ig.add('flat',[0,1],[0,1])
    #k2
    ig.add('loggaus',[-7,2],[0,0.1])
    #dt
    ig.add('gaus',[10,2],[5,50])
    return ig

def generateGauss(fit,bounds,n=1):
   #generate n realizations of a multivariate Gaussian with mean vector fit.mu
   #and covariance matrix fit.cov, rejecting draws that fall outside bounds=(lb,ub)
   #output is an (r,n) matrix where r is the size of fit.mu and n the number of realizations
   sig=scipy.linalg.sqrtm(fit.cov)
   sig=numpy.real(sig)   #sqrtm of a (numerically) PSD covariance may carry a tiny imaginary residue
   r=fit.mu.shape[0]
   out=numpy.ones((r,n))
   i=0
   while True:
      s=numpy.matmul(sig,numpy.random.normal(size=r))
      s+=fit.mu
      if in_bounds(s,bounds[0],bounds[1]):
         out[:,i]=s
         i+=1
      if i==n:
         break
   return out
   #return numpy.outer(fit.mu,numpy.ones((1,n)))+numpy.matmul(sig,s)
  
def in_bounds(x0,lb,ub):
   r=x0.shape[0]
   for i in range(r):
      if x0[i]<lb[i]:
         return False
      if x0[i]>ub[i]:
         return False
   return True

def getFit(samples,threshold=numpy.inf,verbose=0):
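    #samples: row 0 holds the fit cost (chi2) of each repetition, rows 1: the fitted
    #parameters; summarize the sub-threshold fits by their mean and covariance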
    class fit:pass
    chi2=samples[0,:]
    #mScore=numpy.min(chi2)
    
    fit.mu=numpy.mean(samples[1:,chi2<threshold],axis=1)
    fit.cov=numpy.cov(samples[1:,chi2<threshold])
    if verbose>0:
        print(fit.mu)
        print(fit.cov)

    return fit
#fit input function

def fitIVF(t, centers,nfit=10,nbatch=20):
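   #fit the input-function model to the center with the largest value in centers;
   #returns its index m and an (npar+1,nfit) sample matrix with the best cost in
   #row 0 and the corresponding parameters below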
   #t,dt=loadData.loadTime(r,xsetup)
   #centers=loadCenters(r,xsetup)
   #find center with maximum content/uptake 
   m=numpy.unravel_index(numpy.argmax(centers),centers.shape)[0]
   #treat it as ivf sample
   ivf=centers[m]
   #create a partial specialization to be used in optimization
   w=numpy.ones(t.shape)
   fun=functools.partial(fitModel.fDiff,fitModel.fIVF,t,ivf,w)  
   jac=functools.partial(fitModel.jacIVF,t)
   
   #generate approx candidate
   ig=generateIVF()
   n=ig.getN()
   samples=numpy.zeros((n+1,nfit))
   for j in range(nfit):
      fMin=1e30
      for i in range(nbatch):
         par,bounds=ig.generate()
         result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds,jac=jac)
         #result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds)
         #the fit is invariant under (p[1],p[2]) -> (1/p[2],1/p[1]),
         #so keep the branch with the larger p[1]
         xm=[x for x in result.x]
         xm[1]=numpy.max([result.x[1],1/result.x[2]])
         xm[2]=numpy.max([1/result.x[1],result.x[2]])
         if result.cost<fMin:
            fMin=result.cost
            x1=xm
      samples[1:,j]=x1
      samples[0,j]=fMin
   
   return m,samples


#global version with annealing
def fitIVFGlobal(t, centers,nfit=10, qLambda=0):
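    #same as fitIVF, but each of the nfit repetitions runs a dual-annealing global
    #optimization (optionally regularized with weight qLambda) instead of a batch of
    #local least-squares starts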
    #find the center with the maximum value
    m=numpy.unravel_index(numpy.argmax(centers),centers.shape)[0]
    #treat it as ivf sample
    ivf=centers[m]
    #this is the (relative) weight of each time point in fit
    w=numpy.ones(t.shape)
    #create a partial specialization to be used in optimizations
    #function: difference between model prediction fIVF and measured values ivf at points t, using point weights w
    funScalar=functools.partial(fitModel.fDiffScalar,fitModel.fIVF,t,ivf,w)
    #add regularizer on A
    funScalarRegularized=functools.partial(fitModel.fDiffScalarRegularized,funScalar,fitModel.fIVFRegA,qLambda)
    #Jacobian
    jac=functools.partial(fitModel.jacIVF,t)
    #convert it to scalar for global minimization
    jacScalar=functools.partial(fitModel.jacScalar,fitModel.fIVF,t,ivf,w,jac)
    #add regularization on A
    jacScalarRegularized=functools.partial(fitModel.jacScalarRegularized,jacScalar,fitModel.jacIVFRegA,qLambda)

    ig=generateIVFFinite()
    boundsScalar=ig.getBoundsScalar()
    #par,bounds=ig.generate()
    n=ig.getN()
    samples=numpy.zeros((n+1,nfit))
    
    minSetup=dict(method='L-BFGS-B',jac=jacScalar)
    if qLambda>0:
       minSetup['jac']=jacScalarRegularized

    for j in range(nfit):
        if qLambda>0:
           result=scipy.optimize.dual_annealing(func=funScalarRegularized,bounds=boundsScalar,minimizer_kwargs=minSetup)
        else:
           result=scipy.optimize.dual_annealing(func=funScalar,bounds=boundsScalar,minimizer_kwargs=minSetup)
        #the fit is invariant under (p[1],p[2]) -> (1/p[2],1/p[1]),
        #so keep the branch with the larger p[1]
        xm=[qx for qx in result.x]
        xm[1]=numpy.max([result.x[1],1/result.x[2]])
        xm[2]=numpy.max([1/result.x[1],result.x[2]])
        #if result.x[2]!=xm[2]: print(f'[{j}] switch')
        samples[1:,j]=xm
        samples[0,j]=result.fun
        #print(f'[{j}] {result.fun}')
    return m,samples


#fit compartment

def fitCompartment(ivfFit, t, qf ,nfit=10,nbatch=20, useJac=False):
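   #fit the compartment model to the target curve qf at time points t; the input-function
   #parameters of each repetition are drawn from the Gaussian summary ivfFit
   #returns a (1+nC+nIVF,nfit) sample matrix: best cost, compartment parameters, IVF parameters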
   
   igIVF=generateIVF()
   pars,boundsIVF=igIVF.generate()
   ivfSamples=generateGauss(ivfFit,boundsIVF,nfit)
   nIVF=igIVF.getN()
   ig=generateCModel()
   nC=ig.getN()
   boundsScalar=ig.getBoundsScalar()
   samples=numpy.zeros((1+nC+nIVF,nfit))
   for j in range(nfit):
      #create a partial specialization to be used in optimization
      ivfPar=ivfSamples[:,j]
      w=numpy.ones(t.shape)
      fc1=functools.partial(fitModel.fCompartment,ivfPar)
      fun=functools.partial(fitModel.fDiff,fc1,t,qf,w)  
      jac=functools.partial(fitModel.jacDep,ivfPar,t)
      #minimize requires scalar function
      funScalar=functools.partial(fitModel.fDiffScalar,fc1,t,qf,w)
      jacScalar=functools.partial(fitModel.jacScalar,fc1,t,qf,w,jac)
   
      fMin=1e30
      for i in range(nbatch):
         #generate approx candidate
         par,bounds=ig.generate()
         
         if useJac:
            result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds,jac=jac)
            #scalar, just for reference
            #result=scipy.optimize.minimize(fun=funScalar,x0=par,bounds=boundsScalar,
            #                                           jac=jacScalar,method='L-BFGS-B')
         else:
            result=scipy.optimize.least_squares(fun=fun,x0=par,bounds=bounds)
         if result.cost<fMin:
            fMin=result.cost
            x1=result.x
      
      samples[0,j]=fMin
      samples[1:nC+1,j]=x1
      samples[(1+nC):,j]=ivfPar
        
   
   return samples


def fitCompartmentGlobal(ivfFit, t, qf ,nfit=10,nbatch=20, useJac=False,qLambda=0):
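    #global (dual-annealing) variant of fitCompartment with optional regularization on
    #BVF (weight qLambda) and rescaling of weak target curves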
    #nclass=setup['nclass'][0]
    #t,dt=loadData.loadTime(r,setup)
    #centers=loadData.loadCenters(r,setup,ir)
    #qf=centers[m]

    igIVF=generateIVF()
    xs,boundsIVF=igIVF.generate()
    ivfSamples=generateGauss(ivfFit,boundsIVF,nfit)

    #samples=numpy.zeros((4+1,nfit))
    nIVF=igIVF.getN()
    ig=generateCModelFinite()
    nC=ig.getN()
    samples=numpy.zeros((1+nC+nIVF,nfit))
   
    boundsScalar=ig.getBoundsScalar()

    #optimize has a hard time dealing with small values, so scale the target up and later scale the parameters (approximately correct)
    scale=1
    fi=qf.sum()
    if fi<0.1:
        scale=10/fi
        print('scale={}'.format(scale))
    #print(qf.sum()*scale)
    qf*=scale   #note: this rescales the caller's array in place
    for j in range(nfit):
        ivfPar=ivfSamples[:,j]
        w=numpy.ones(t.shape)
        fc1=functools.partial(fitModel.fCompartment,ivfPar)
        funScalar=functools.partial(fitModel.fDiffScalar,fc1,t,qf,w)  
        funScalarRegularized=functools.partial(fitModel.fDiffScalarRegularized,funScalar,fitModel.fCRegBVF,qLambda)
        jac=functools.partial(fitModel.jacDep,ivfPar,t)
        jacScalar=functools.partial(fitModel.jacScalar,fc1,t,qf,w,jac)
        jacScalarRegularized=functools.partial(fitModel.jacScalarRegularized,jacScalar,fitModel.jacDepRegBVF,qLambda)

        #minSetup=dict(method='L-BFGS-B',jac=jacScalar)
        minSetup=dict(method='L-BFGS-B',jac=jacScalarRegularized)
        
        result=scipy.optimize.dual_annealing(func=funScalarRegularized,bounds=boundsScalar,minimizer_kwargs=minSetup)

        print(f'[{j}/{nfit}]')
        
        samples[0,j]=result.fun/scale
        qx=result.x
        qx[0]/=scale
        qx[1]/=scale
        print('{} {} {}'.format(scale,result.fun/scale,qx))
        samples[1:nC+1,j]=qx
        samples[(1+nC):,j]=ivfPar
        
    
    return samples
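

if __name__=='__main__':
    #minimal usage sketch (illustration only): exercise the initial-value and
    #Gaussian-resampling helpers on synthetic data; the fit* routines above
    #additionally require the fitModel module and measured time/activity curves
    ig=generateIVFFinite()
    nfit=50
    samples=numpy.zeros((ig.getN()+1,nfit))
    for j in range(nfit):
        par,bounds=ig.generate()
        samples[1:,j]=par
        samples[0,j]=numpy.random.rand()   #stand-in for the fit cost
    fit=getFit(samples,verbose=1)
    newStarts=generateGauss(fit,bounds,n=5)
    print(newStarts.shape)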