import matplotlib.pyplot
import fitData
import fitModel
import numpy
import functools
import geometry
import segmentation

def plotIVF(t,ivf,samples,threshold,file0=None,file1=None):
    """Plot sampled IVF model curves and their residuals.

    Figure number 0 receives the measured points ``(t, ivf)`` plus
    ``fitModel.fIVF`` evaluated on a 300-point grid over ``[0, max(t)]`` for
    every sample. Figure number 1 receives the residual vector returned by
    ``fitModel.fDiff`` (unit weights) for every sample. Figures are selected
    by number, so already-existing figures 0/1 are drawn on again rather
    than replaced.

    Samples whose recomputed chi2 (sum of squared residuals) exceeds
    ``threshold`` are drawn in silver with opacity ``1/n_samples``; the
    others are drawn opaque in the axes' default color cycle. A diagnostic
    line ``cost/chi2 params`` is printed whenever the recomputed chi2 does
    not agree (within ``numpy.isclose`` tolerance) with the stored cost.

    Args:
        t: sample times (numpy array).
        ivf: measured values at ``t``.
        samples: 2D array; column j is one sample, row 0 holds its stored
            cost and rows 1: hold the ``fIVF`` parameters.
        threshold: chi2 cut-off between rejected (silver) and kept samples.
        file0: optional path; if given, figure 0 is saved there.
        file1: optional path; if given, figure 1 is saved there.
    """
    t1 = numpy.linspace(0, numpy.max(t), 300)
    w = numpy.ones(t.shape)
    fun = functools.partial(fitModel.fDiff, fitModel.fIVF, t, ivf, w)

    fig0 = matplotlib.pyplot.figure(0)
    fig1 = matplotlib.pyplot.figure(1)
    ax0 = fig0.gca()
    ax1 = fig1.gca()
    ax0.scatter(t, ivf)

    n = samples.shape[1]
    # Faint alpha lets many rejected curves blend into a grey band;
    # max() avoids ZeroDivisionError when there are no samples.
    alphaRejected = 1 / max(n, 1)
    alphaKept = 1
    for j in range(n):
        cPar = samples[1:, j]
        df = fun(cPar)
        chi2 = (df * df).sum()
        cost = samples[0, j]
        # chi2 is a recomputed float sum: exact != would flag rounding noise.
        if not numpy.isclose(chi2, cost):
            print('{}/{} {}'.format(cost, chi2, cPar))

        fv = fitModel.fIVF(t1, cPar)
        if chi2 > threshold:
            ax0.plot(t1, fv, color='silver', alpha=alphaRejected)
            ax1.plot(t, df, color='silver', alpha=alphaRejected)
        else:
            ax0.plot(t1, fv, alpha=alphaKept)
            ax1.plot(t, df, alpha=alphaKept)

    if file0:
        fig0.savefig(file0)
    if file1:
        fig1.savefig(file1)

def _maskCenterSlices(mask):
    """Return, per axis, the centroid index (truncated to int) of the True voxels of *mask*.

    Raises:
        ValueError: if *mask* selects no voxels (the centroid would be NaN).
    """
    if not numpy.any(mask):
        raise ValueError('mask selects no voxels; cannot locate its center')
    return [int(numpy.mean(idx)) for idx in numpy.nonzero(mask)]


def plotIVFCenter(centerMapSPECT,centerMapCT,m,data,ct,file0=None,file1=None):
    """Show three orthogonal slices through the region labelled *m*, on SPECT and on CT.

    For each modality the slices pass through the (integer-truncated)
    centroid of the voxels whose label equals *m*, and the region mask is
    overlaid in red. Two new figures are created, one per modality.

    Args:
        centerMapSPECT: label map in SPECT voxel space (indexed with 3 indices).
        centerMapCT: label map in CT voxel space (indexed with 3 indices).
        m: label value identifying the IVF sampling region.
        data: SPECT data; only the last entry along the final axis
            (``data[..., -1]``) is displayed.
        ct: CT volume in the same voxel space as ``centerMapCT``.
        file0: optional path; if given, the SPECT figure is saved there.
        file1: optional path; if given, the CT figure is saved there.

    Raises:
        ValueError: if either label map contains no voxel equal to *m*.
    """
    # SPECT: last frame with the region overlaid.
    um = centerMapSPECT == m
    iSlice = _maskCenterSlices(um)
    d20 = data[..., -1]
    fig = matplotlib.pyplot.figure(figsize=(12, 3))
    axs = fig.subplots(1, 3)
    axs[0].imshow(d20[iSlice[0], ...], cmap='gray_r')
    axs[0].imshow(um[iSlice[0], ...], alpha=0.2, cmap='Reds')
    axs[1].imshow(d20[:, iSlice[1], :], cmap='gray_r')
    axs[1].imshow(um[:, iSlice[1], :], alpha=0.2, cmap='Reds')
    axs[2].imshow(d20[..., iSlice[2]], cmap='gray_r')
    axs[2].imshow(um[..., iSlice[2]], alpha=0.2, cmap='Reds')
    if file0:
        fig.savefig(file0)

    # CT: same region, located independently in CT voxel space.
    um1 = centerMapCT == m
    iSlice1 = _maskCenterSlices(um1)
    fig = matplotlib.pyplot.figure(figsize=(12, 3))
    axs = fig.subplots(1, 3)
    axs[0].imshow(ct[iSlice1[0], ...], cmap='gray')
    axs[0].imshow(um1[iSlice1[0], ...], alpha=0.4, cmap='Reds')
    axs[1].imshow(ct[:, iSlice1[1], :], aspect='auto', origin='lower', cmap='gray')
    axs[1].imshow(um1[:, iSlice1[1], :], alpha=0.4, aspect='auto', origin='lower', cmap='Reds')
    axs[2].imshow(ct[..., iSlice1[2]], aspect='auto', origin='lower', cmap='gray')
    axs[2].imshow(um1[..., iSlice1[2]], alpha=0.4, aspect='auto', origin='lower', cmap='Reds')
    if file1:
        fig.savefig(file1)

def plotIVFRealizations(t,ivf,samples,threshold,nplot=50,file=None):
    """Overlay IVF model realizations generated around a fit on the measured points.

    A fit is obtained from ``fitData.getFit(samples, threshold)``. Parameter
    bounds are taken from ``fitData.generateIVF().generate()``. ``nplot``
    parameter vectors are then produced by
    ``fitData.generateGauss(fit, bounds, nplot)`` (presumably random draws
    around the fit; confirm in fitData). Each vector is drawn as a faint
    blue ``fitModel.fIVF`` curve on a 300-point grid over ``[0, max(t)]``,
    in a new figure.

    Args:
        t: sample times (numpy array).
        ivf: measured values at ``t`` (scatter-plotted).
        samples: sample array passed to ``fitData.getFit``.
        threshold: passed to ``fitData.getFit``.
        nplot: number of realizations to draw; columns 0..nplot-1 of the
            generated array are used.
        file: optional path; if given, the figure is saved there.
    """
    fit = fitData.getFit(samples, threshold)
    t1 = numpy.linspace(0, numpy.max(t), 300)

    # Only the bounds are needed; the first value returned by generate() is unused here.
    _, bounds = fitData.generateIVF().generate()
    ivfSamples = fitData.generateGauss(fit, bounds, nplot)

    fig = matplotlib.pyplot.figure()
    ax = fig.gca()
    ax.scatter(t, ivf)
    for i in range(nplot):
        fv = fitModel.fIVF(t1, ivfSamples[:, i])
        ax.plot(t1, fv, alpha=0.05, color='blue')
    if file:
        fig.savefig(file)


def plotSamples(t,evalArray,file0=None,file1=None):
    """Plot compartment-model sample curves and residuals for several data sets.

    Two new figures are created. The first shows each measured curve as a
    scatter plus ``fitModel.fCompartment`` evaluated on a 300-point grid over
    ``[0, max(t)]`` for every sample. The second shows the residual vector
    from ``fitModel.fDiff`` (unit weights) for every sample. Within each data
    set, samples whose stored cost is above the median cost are drawn in
    silver with opacity ``1/n_samples``; the others use the data set's color
    at opacity 0.5. For every sample, ``stored_cost/recomputed_chi2`` is
    printed.

    Args:
        t: sample times (numpy array).
        evalArray: iterable of ``(samples, measured, color)`` sequences.
            ``samples`` is a 2D array whose column j is one sample. Row 0
            holds the stored cost, rows ``1..nC`` the compartment
            parameters, and the remaining rows the IVF parameters. ``nC`` is
            ``fitData.generateCModel().getN()``. ``measured`` holds the
            values at ``t``.
        file0: optional path; if given, the curve figure is saved there.
        file1: optional path; if given, the residual figure is saved there.
    """
    # evalArray is traversed twice; materialize it so generators also work.
    evalArray = list(evalArray)
    nC = fitData.generateCModel().getN()

    t1 = numpy.linspace(0, numpy.max(t), 300)
    w = numpy.ones(t.shape)
    f1 = matplotlib.pyplot.figure()
    f2 = matplotlib.pyplot.figure()
    ax1 = f1.gca()
    ax2 = f2.gca()

    for x in evalArray:
        ax1.scatter(t, x[1], color=x[2])

    alphaKept = 0.5
    for x in evalArray:
        samples = x[0]
        qf = x[1]
        color = x[2]
        chi2 = samples[0, :]
        n = chi2.shape[0]
        if n == 0:
            # Nothing to draw; also avoids 1/0 and the median of an empty array.
            continue
        median = numpy.median(chi2)
        alphaRejected = 1 / n
        for j in range(n):
            cPar = samples[1:1 + nC, j]
            ivfPar = samples[1 + nC:, j]
            # Bind the IVF parameters so the model is a function of (t, cPar) only.
            fc1 = functools.partial(fitModel.fCompartment, ivfPar)
            fv = fc1(t1, cPar)
            df = fitModel.fDiff(fc1, t, qf, w, cPar)

            cost = chi2[j]
            c1 = (df * df).sum()
            print(f'{cost}/{c1}')
            if cost > median:
                ax1.plot(t1, fv, color='silver', alpha=alphaRejected)
                ax2.plot(t, df, color='silver', alpha=alphaRejected)
            else:
                ax1.plot(t1, fv, color=color, alpha=alphaKept)
                ax2.plot(t, df, color=color, alpha=alphaKept)

    if file0:
        f1.savefig(file0)
    if file1:
        f2.savefig(file1)

def plotSegmentation(r,setup,spect,vmax=1000,file0=None,fb=None,pId=None):
    """Show a grid of SPECT slices around a reference point, with labelled points overlaid.

    The record whose ``regionId`` is 0 fixes the reference voxel ``slc``.
    A 3 x 7 grid of panels is drawn. Column i shows slices at offset
    ``i - 3`` from the reference voxel along axis 1 (row 0), axis 2 (row 1)
    and axis 0 (row 2). Each panel shows a fixed 20x20 crop starting at
    index 20. Points whose ``sliceId`` lists '0', '1' or '2' are scattered
    on the centre column of rows 0, 1 and 2 respectively.

    Args:
        r: iterable of point records (mappings) with keys 'regionId', 'x',
            'y', 'z' and 'sliceId' (a ';'-separated list of slice labels).
            Values may be numbers or numeric strings, e.g. rows from a CSV
            reader (assumed; verify against callers).
        setup: unused; kept for interface compatibility.
        spect: 3D SPECT volume indexed as ``spect[x, y, z]``.
        vmax: upper display limit for the grey scale (lower limit is 0).
        file0: optional path; if given, the figure is saved there.
        fb: unused; kept for interface compatibility.
        pId: optional label drawn in the top-left panel.

    Raises:
        ValueError: if no record has ``regionId`` equal to 0.
    """
    fp = {}
    slc = None
    for q in r:
        xyz = [float(q[k]) for k in ('x', 'y', 'z')]
        # float() accepts both numeric values and numeric strings such as '0'.
        if float(q['regionId']) == 0:
            slc = [int(v) for v in xyz]
        for s in q['sliceId'].split(';'):
            fp.setdefault(s, []).append(xyz)
    if slc is None:
        raise ValueError('plotSegmentation: no record with regionId 0')

    cut0 = 20
    w0 = 20
    cut1 = 20
    w1 = 20
    cut2 = 20
    w2 = 20
    vmin = 0
    nd = 3
    fig, ax = matplotlib.pyplot.subplots(3, 2 * nd + 1, figsize=(20, 12))
    for i in range(2 * nd + 1):
        ax[0, i].set_xlabel('z')
        ax[0, i].set_ylabel('x')
        ax[0, i].imshow(spect[cut2:cut2 + w2, slc[1] - nd + i, cut0:cut0 + w0],
                        cmap='gray_r', vmax=vmax, vmin=vmin)
        ax[1, i].set_xlabel('x')
        ax[1, i].set_ylabel('y')
        ax[1, i].imshow(spect[cut2:cut2 + w2, cut0:cut0 + w0, slc[2] - nd + i].T,
                        cmap='gray_r', vmax=vmax, vmin=vmin)
        ax[2, i].set_xlabel('z')
        ax[2, i].set_ylabel('y')
        ax[2, i].imshow(spect[slc[0] - nd + i, cut1:cut1 + w1, cut1:cut1 + w1],
                        cmap='gray_r', vmax=vmax, vmin=vmin)
        if i == nd:
            # A slice label with no points simply draws nothing.
            pt = fp.get('0', [])
            ax[0, i].scatter([x[2] - cut0 for x in pt], [x[0] - cut2 for x in pt])
            pt = fp.get('1', [])
            ax[1, i].scatter([x[0] - cut2 for x in pt], [x[1] - cut0 for x in pt])
            pt = fp.get('2', [])
            ax[2, i].scatter([x[2] - cut1 for x in pt], [x[1] - cut1 for x in pt])

    # The original nested this under `i == nd`, which made it unreachable.
    if pId is not None:
        ax[0, 0].text(2, 2, pId, fontsize='large')
    if file0:
        fig.savefig(file0)