import numpy

#adapters: wrap a model f(t,par) into weighted residuals, a sum-of-squares objective and its gradient
def fDiff(f,t,y,w,par):
   #weighted residual vector: (f(t,par)-y)*w
   model=f(t,par)
   residual=model-y
   return residual*w

def fDiffScalar(f,t,y,w,par):
   #weighted sum of squared residuals: sum(((f(t,par)-y)*w)**2)
   r=(f(t,par)-y)*w
   return numpy.sum(r*r)

def jacScalar(f,t,y,w,jac,par):
   #Gradient of the objective S(par)=sum(((f(t,par)-y)*w)**2) (see fDiffScalar).
   #jac(par) must return the (m,n) Jacobian of f(t,par): m time points, n parameters.
   #Returns the (n,) vector dS/dpar_j = sum_i 2*(f_i-y_i)*w_i**2*J_ij.
   #The weight appears squared: once inside the residual and once more from the
   #chain rule, so a single factor of w is only correct for 0/1 weights.
   r=(f(t,par)-y)*w
   b=2*r*w
   J=jac(par)
   return numpy.dot(b,J)

#add regularization with explicit lambda
def fDiffScalarRegularized(fDiffScalar,fScalarReg,qLambda,par):
   #regularized objective: data term plus qLambda times the penalty;
   #both callables take par and return a scalar to be minimized
   dataTerm=fDiffScalar(par)
   penalty=fScalarReg(par)
   return dataTerm+qLambda*penalty

def jacScalarRegularized(jacScalar,jacScalarReg,qLambda,par):
   #gradient of the regularized objective: data gradient plus qLambda times the
   #regularizer gradient; both callables return an (n,) vector, n = number of parameters
   dataGrad=jacScalar(par)
   regGrad=jacScalarReg(par)
   return dataGrad+qLambda*regGrad

#input function
def fIVF(t,par):
   #Input function model evaluated at times t.
   #par=[A,tau,alpha,dt] or [A,tau,alpha,dt,B,gamma]; B and gamma default to 0 when absent.
   #With t1=t-dt and x=t1/tau (assumes tau>0), for t1>0:
   #  A*alpha*(exp(-alpha*t1)-exp(-x))/(1-alpha*tau) + B*gamma*exp(-gamma*t1)
   #and, when tau==1/alpha, the limit A*alpha*x*exp(-x) for the first term.
   #Every term is zero for t1<=0 because the getExp/getX1/getE helpers mask it.
   A=par[0]
   tau=par[1]
   alpha=par[2]
   dt=par[3]
   try:
      B=par[4]
      gamma=par[5]
   except IndexError:
      #4-parameter model: no (B,gamma) component; this is a normal configuration,
      #so it is handled silently (as in fCompartment) rather than printed on every call
      B=0
      gamma=0

   t1=t-dt
   x=t1/tau
   et=getExp(x,1)
   fB=B*gamma*getExp(t1,gamma)
   if tau==1/alpha:
      fA=A*alpha*getX1(x,et)
   else:
      ea=getExp(t1,alpha)
      fA=A*alpha*getE(x,ea,et,1-alpha*tau)
   return fA+fB


def jacIVF(t,par):
   #Jacobian of fIVF(t,par): returns an (m,n) array, m=t.shape[0] time points and
   #n=par.shape[0] parameters, columns ordered as par=[A,tau,alpha,dt(,B,gamma)].
   #With t1=t-dt, x=t1/tau, et=exp(-x), ea=exp(-alpha*t1), eg=exp(-gamma*t1)
   #(all zero for t1<=0), the model is
   #  A*alpha*(ea-et)/(1-alpha*tau) + B*gamma*eg,   or A*alpha*x*et when tau==1/alpha.
   jac=numpy.zeros((t.shape[0],par.shape[0]))
   A=par[0]
   tau=par[1]
   alpha=par[2]
   dt=par[3]
   try:
      B=par[4]
      gamma=par[5]
   except IndexError:
      #4-parameter model: no (B,gamma) component
      B=0
      gamma=0

   t1=t-dt
   x=t1/tau
   et=getExp(x,1)
   eg=getExp(t1,gamma)
   fb=B*gamma*eg

   if tau==1/alpha:
      #degenerate case: each column is the alpha*tau->1 limit of the general derivative
      x1=getX1(x,et)
      fv=A*alpha*x1
      #df/dA (computed directly, so A==0 does not give 0/0)
      jac[:,0]=alpha*x1
      #df/dtau = -(fv/tau)*(1-x/2)
      jac[:,1]=-fv/tau*(1-0.5*x)
      #df/dalpha = fv*(1/alpha-t1/2)
      jac[:,2]=fv*(1/alpha-0.5*t1)
      #df/d(dt) = -df/dt1
      jac[:,3]=fv*alpha-A*alpha*et/tau+gamma*fb
   else:
      wa=1-alpha*tau
      Q=A*alpha/wa
      ea=getExp(t1,alpha)
      E=ea-et
      fv=Q*E
      #df/dA
      jac[:,0]=alpha/wa*E
      #df/dtau: dQ/dtau*(ea-et) plus Q*d(-et)/dtau, where d(-et)/dtau=-x*et/tau
      jac[:,1]=fv*alpha/wa-Q*getX1(x,et)/tau
      #df/dalpha
      jac[:,2]=fv/alpha/wa-Q*getX1(t1,ea)
      #df/d(dt) = Q*(alpha*ea-et/tau) + gamma*B*gamma*eg
      jac[:,3]=Q*getF(x,alpha,ea,1/tau,et,1)+gamma*fb

   if jac.shape[1]>5:
      #df/dB = gamma*eg ; df/dgamma = B*eg*(1-gamma*t1)
      #(computed directly, so B==0 or gamma==0 does not give 0/0)
      jac[:,4]=gamma*eg
      jac[:,5]=B*eg-B*gamma*getX1(t1,eg)
   return jac

#regularizers: penalize the amplitude A (par[0]) so the fit keeps it as small as possible
def fIVFRegA(par):
   #penalty term: the amplitude A, i.e. the first parameter
   amplitude=par[0]
   return amplitude

def jacIVFRegA(par):
   #gradient of fIVFRegA: unit vector selecting the first parameter
   n=par.shape[0]
   grad=numpy.zeros(n)
   grad[0]=1
   return grad

#helper functions

def getE(x,e,ek,w):
   #(e-ek)/w on the entries where x>0, zero elsewhere.
   #Callers pass ea (with w=w0) to build Ea, or et (with w=wk) to build Et.
   pos=x>0
   E=numpy.zeros(x.shape)
   E[pos]=(e[pos]-ek[pos])/w
   return E

def getF(x,a,e,b,eb,w):
   #(a*e-b*eb)/w on the entries where x>0, zero elsewhere
   pos=x>0
   first=a*e[pos]
   second=b*eb[pos]
   F=numpy.zeros(x.shape)
   F[pos]=(first-second)/w
   return F

def getH(x,a,b,ea=None,eb=None):
   #H(x;a,b)=(exp(-a*x)-exp(-b*x))/(b-a) for x>0 and 0 otherwise, which is the
   #convolution of exp(-a*x) with exp(-b*x); for a==b it returns the limit x*exp(-a*x).
   #ea, eb: optional precomputed getExp(x,a) and getExp(x,b). None means "compute here";
   #an empty array is still accepted with the same meaning for backward compatibility.
   #(None replaces the previous shared numpy.empty(0) default arguments.)
   if ea is None or ea.size==0:
      ea=getExp(x,a)
   if a==b:
      return getX1(x,ea)
   if eb is None or eb.size==0:
      eb=getExp(x,b)
   return (ea-eb)/(b-a)

def getH2(x,a,b):
   #Derivative of H(x;a,b)=(exp(-a*x)-exp(-b*x))/(b-a) with respect to the
   #second rate b:  dH/db = (x*exp(-b*x)-H)/(b-a),
   #with limit -x**2*exp(-a*x)/2 when a==b. Zero for x<=0.
   eb=getExp(x,b)
   if a==b:
      #here eb equals exp(-a*x)
      return -0.5*getX2(x,eb)
   return (getX1(x,eb)-getH(x,a,b,eb=eb))/(b-a)

def getHD(x,a,b):
   #Derivative of H(x;a,b) with respect to the offset dt, where x=t-dt:
   #  -dH/dx = (a*exp(-a*x)-b*exp(-b*x))/(b-a),
   #with limit (a*x-1)*exp(-a*x) when a==b. Zero for x<=0.
   ea=getExp(x,a)
   if a==b:
      #use the function argument x (t1 is not defined in this scope)
      return -ea+a*getX1(x,ea)
   eb=getExp(x,b)
   return (a*ea-b*eb)/(b-a)

def getX1(x,e):
   #x*e where x>0, zero elsewhere.
   #Pass ea to obtain x1, or et to obtain y1.
   pos=x>0
   out=numpy.zeros(x.shape)
   out[pos]=x[pos]*e[pos]
   return out

def getX2(x,e):
   #x**2*e where x>0, zero elsewhere
   pos=x>0
   xp=x[pos]
   out=numpy.zeros(x.shape)
   out[pos]=xp*xp*e[pos]
   return out

def getX3(x,e):
   #x**3*e where x>0, zero elsewhere
   pos=x>0
   xp=x[pos]
   out=numpy.zeros(x.shape)
   out[pos]=xp*xp*xp*e[pos]
   return out

def getExp(x,k):
   #exp(-x*k) where x>0, zero elsewhere (result is a float64 array of x's shape)
   pos=x>0
   out=numpy.zeros(x.shape)
   out[pos]=numpy.exp(-x[pos]*k)
   return out

def fCompartment(ivfPar,t,par):
   #Model curve: BVF-weighted mix of the input function (fIVFStar) and the
   #compartment response (fC1); BVF presumably is a blood volume fraction.
   #ivfPar=[A,tau,alpha,...] with optional B,gamma at indices 4,5 (default 0);
   #par=[k1,BVF,k2,dt].
   A,tau,alpha=ivfPar[0],ivfPar[1],ivfPar[2]
   k1,BVF,k2,dt=par[0],par[1],par[2],par[3]
   try:
      B,gamma=ivfPar[4],ivfPar[5]
   except IndexError:
      B,gamma=0,0
   blood=fIVFStar(t,A,tau,alpha,dt,B,gamma)
   tissue=fC1(t,A,tau,alpha,dt,k1,k2,B,gamma)
   return BVF*blood+(1-BVF)*tissue

def fIVFStar(t,A,tau,alpha,dt,B=0,gamma=0):
   #Input function with its parameters passed explicitly (same functional form as fIVF).
   #For t1=t-dt>0 and x=t1/tau:
   #  A*alpha*(exp(-alpha*t1)-exp(-x))/(1-alpha*tau) + B*gamma*exp(-gamma*t1),
   #and, when alpha*tau==1, the limit A*alpha*x*exp(-alpha*t1) for the first term.
   #Every term is zero for t1<=0.
   t1=t-dt
   x=t1/tau
   ea=getExp(t1,alpha)
   AT=alpha*tau
   if AT==1:
      #getX1 returns a full-length masked array, so it combines with the full-length
      #B-term below. Assigning into fv[x>0] mixed sizes and failed whenever some t1<=0.
      fA=A*alpha*getX1(x,ea)
   else:
      et=getExp(x,1)
      fA=A*alpha*getE(x,ea,et,1-AT)
   return fA+B*gamma*getExp(t1,gamma)

def fC1(t,A,tau,alpha,dt,k1,k2,B=0,gamma=0):
   #Compartment response: k1 times the convolution of the input function
   #fIVFStar(t,A,tau,alpha,dt,B,gamma) with exp(-k2*t1), where t1=t-dt.
   #With H(t1;a,b)=(exp(-a*t1)-exp(-b*t1))/(b-a) from getH (the convolution of two
   #exponentials):
   #  A-part: k1*A*alpha/(1-alpha*tau)*(H(t1;alpha,k2)-H(t1;1/tau,k2))
   #  B-part: k1*B*gamma*H(t1;gamma,k2)
   #When alpha*tau==1, the A-part uses the closed-form limits (options A and D).
   #The result is linear in k1 and zero for t1<=0.
   t1=t-dt
   x=t1/tau

   ea=getExp(t1,alpha)
   AT=alpha*tau
   K2T=k2*tau
   K1T=k1*tau

   #B-part: identical in every branch and, like the A-part, proportional to k1
   rB=k1*B*gamma*getH(t1,gamma,k2)

   if AT==1:
      if K2T==1:
         #option D: alpha*tau==1 and k2*tau==1  ->  k1*A*alpha/tau*t1**2*ea/2
         rA=0.5*A*K1T*alpha*getX2(x,ea)
      else:
         #option A: alpha*tau==1, k2*tau!=1
         w0=AT-K2T
         ek=getExp(t1,k2)
         Ea=getE(x,ea,ek,-w0)
         x1=getX1(x,ea)
         rA=-K1T*A*alpha/w0*(-Ea+x1)
   else:
      #options 0, B and C (alpha*tau!=1); getH handles equal rates internally.
      #The prefactor is k1 (not k1*tau): this matches fIVFStar's normalization and makes
      #this branch continuous with option A as alpha*tau->1.
      D1=getH(t1,alpha,k2,ea=ea)
      D2=getH(t1,1/tau,k2)
      rA=k1*A*alpha/(1-AT)*(D1-D2)
   return rA+rB

def jacDep(ivfPar,t,par):
   #Jacobian of fCompartment(ivfPar,t,par) with respect to par=[k1,BVF,k2,dt].
   #Returns an (m,n) array, m=t.shape[0], n=par.shape[0]. The input-function parameters
   #ivfPar=[A,tau,alpha,...(,B,gamma)] are held fixed.
   #Model: f=BVF*c0+(1-BVF)*c1 with c0=fIVFStar(...) and c1=fC1(...). c1 is k1 times
   #the convolution of c0 with exp(-k2*t1), where t1=t-dt.
   #Uses the helper contracts getH2 = dH/d(second rate) and getHD = dH/d(dt) = -dH/dt1.
   jac=numpy.zeros((t.shape[0],par.shape[0]))
   A=ivfPar[0]
   tau=ivfPar[1]
   alpha=ivfPar[2]
   try:
      B=ivfPar[4]
      gamma=ivfPar[5]
   except IndexError:
      B=0
      gamma=0

   k1=par[0]
   BVF=par[1]
   k2=par[2]
   dt=par[3]

   t1=t-dt
   ea=getExp(t1,alpha)
   eg=getExp(t1,gamma)
   AT=alpha*tau
   K2T=k2*tau

   c0=fIVFStar(t,A,tau,alpha,dt,B,gamma)
   c1=fC1(t,A,tau,alpha,dt,k1,k2,B,gamma)

   #(B,gamma) component: c0 holds B*gamma*exp(-gamma*t1), c1 holds k1*B*gamma*H(t1;gamma,k2)
   dc0B_ddt=B*gamma*gamma*eg
   dc1B_dk2=k1*B*gamma*getH2(t1,gamma,k2)
   dc1B_ddt=k1*B*gamma*getHD(t1,gamma,k2)

   if AT==1:
      #options A and D: the c0 A-part is (A*alpha/tau)*t1*ea and the c1 A-part is
      #pref*P(t1), with P(t1)=int_0^t1 s*exp(-alpha*s)*exp(-k2*(t1-s)) ds
      pref=k1*A*alpha/tau
      x1=getX1(t1,ea)
      dc0A_ddt=A*alpha/tau*(alpha*x1-ea)
      if K2T==1:
         #option D (k2==alpha): P=t1**2*ea/2 and dP/dk2=-t1**3*ea/6
         P=0.5*getX2(t1,ea)
         dP_dk2=-getX3(t1,ea)/6
      else:
         #option A: P=(t1*ea-H(t1;alpha,k2))/(k2-alpha), dP/dk2=-(P+dH/dk2)/(k2-alpha)
         d=k2-alpha
         P=(x1-getH(t1,alpha,k2,ea=ea))/d
         dP_dk2=-(P+getH2(t1,alpha,k2))/d
      dc1A_dk2=pref*dP_dk2
      #dP/dt1=t1*ea-k2*P, and d/d(dt)=-d/dt1
      dc1A_ddt=pref*(k2*P-x1)
   else:
      #options 0, B and C: the c0 A-part is (A*alpha/tau)*H(t1;alpha,1/tau) and
      #the c1 A-part is k1*A*alpha/(1-AT)*(H(t1;alpha,k2)-H(t1;1/tau,k2))
      pref=k1*A*alpha/(1-AT)
      dc0A_ddt=A*alpha/tau*getHD(t1,alpha,1/tau)
      dc1A_dk2=pref*(getH2(t1,alpha,k2)-getH2(t1,1/tau,k2))
      dc1A_ddt=pref*(getHD(t1,alpha,k2)-getHD(t1,1/tau,k2))

   #df/dk1: c1 is linear in k1, so dc1/dk1 is c1 evaluated at k1=1 (no 0/0 when k1==0)
   jac[:,0]=(1-BVF)*fC1(t,A,tau,alpha,dt,1.0,k2,B,gamma)
   #df/dBVF
   jac[:,1]=c0-c1
   #df/dk2 (only c1 depends on k2)
   jac[:,2]=(1-BVF)*(dc1A_dk2+dc1B_dk2)
   #df/d(dt)
   jac[:,3]=BVF*(dc0A_ddt+dc0B_ddt)+(1-BVF)*(dc1A_ddt+dc1B_ddt)
   return jac

def fCRegBVF(par):
   #penalty term: the BVF parameter, i.e. the second entry of par
   bvf=par[1]
   return bvf

def jacDepRegBVF(par):
   #gradient of fCRegBVF: unit vector selecting the second parameter
   n=par.shape[0]
   grad=numpy.zeros(n)
   grad[1]=1
   return grad