import matplotlib.pyplot as plt
import numpy as np
from charged_shells import expansion, interactions, mapping, charge_distributions
from charged_shells.parameters import ModelParams
from functools import partial
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import TwoSlopeNorm
from matplotlib.ticker import FuncFormatter
from config import *

# Short alias so the Expansion class can be named directly in type annotations below.
Expansion = expansion.Expansion


def energy_gap(ex1: Expansion, params: ModelParams, dist=2., match_expansion_axis_to_params=None):
    """Return the relative energy gap ``(E_pp - E_ep) / E_pp`` for a pair of identical shells.

    Two interaction energies are evaluated with ``interactions.charged_shell_energy``
    at the given ``dist``:

    * ``E_ep`` -- ``ex1`` against a copy of itself rotated by ``beta = pi/2``;
    * ``E_pp`` -- ``ex1`` against an unrotated copy of itself.

    Args:
        ex1: Expansion describing one shell; it is cloned, never modified here.
        params: Model parameters forwarded to the mapped energy function.
        dist: Separation forwarded to ``charged_shell_energy`` (units are whatever
            that function expects -- not visible from this file).
        match_expansion_axis_to_params: Forwarded unchanged to
            ``mapping.parameter_map_two_expansions``; presumably selects which
            expansion axis is paired with array-valued parameters -- TODO confirm.

    Returns:
        ``(E_pp - E_ep) / E_pp`` in whatever scalar/array form the mapped energy
        function returns.
    """
    rotated_copy = ex1.clone()
    aligned_copy = ex1.clone()

    # A quarter-turn about the second Euler angle gives the EP configuration with ex1.
    rotated_copy.rotate_euler(alpha=0, beta=np.pi/2, gamma=0)

    pair_energy = partial(interactions.charged_shell_energy, dist=dist)
    mapped_energy = mapping.parameter_map_two_expansions(pair_energy, match_expansion_axis_to_params)

    e_ep = mapped_energy(ex1, rotated_copy, params)
    e_pp = mapped_energy(ex1, aligned_copy, params)

    return (e_pp - e_ep) / e_pp


def abar_kappaR_dependence(save_as=None):
    """Plot the relative energy gap against kappaR, one curve per a_bar value.

    Uses four fixed a_bar values (0.2 ... 0.8) and 25 kappaR samples in
    [0.01, 25], with R=150 and a third expansion argument of 0.001.

    Args:
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown.
    """
    a_bar = np.array([0.2, 0.4, 0.6, 0.8])
    kappaR = np.linspace(0.01, 25, 25)
    model = ModelParams(R=150, kappaR=kappaR)

    # a_bar is passed as a column and kappaR as a row, so the expansion is
    # presumably evaluated on the full (a_bar, kappaR) grid -- verify in charge_distributions.
    quad_ex = charge_distributions.create_mapped_quad_expansion(a_bar[:, None], kappaR[None, :], 0.001)
    rel_gap = energy_gap(quad_ex, model, match_expansion_axis_to_params=1)

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    # Iterating rel_gap walks its first axis: one curve per a_bar value.
    for a, curve in zip(a_bar, rel_gap):
        ax.plot(kappaR, curve, label=rf'$\bar a={a}$')

    ax.set_xlabel(r'$\kappa R$', fontsize=15)
    ax.set_ylabel(r'$\frac{V_{pp}-V_{ep}}{V_{pp}}$', fontsize=20)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
    ax.legend(fontsize=17)

    fig.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def abar_kappaR_dependence2(save_as=None):
    """Plot the relative energy gap against a_bar, one curve per kappaR value.

    Uses four fixed kappaR values (1, 3, 10, 30) and 30 a_bar samples in
    [0.2, 0.8], with R=150 and a third expansion argument of 0.001.

    Args:
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown.
    """
    a_bar = np.linspace(0.2, 0.8, 30)
    kappaR = np.array([1, 3, 10, 30])
    model = ModelParams(R=150, kappaR=kappaR)

    # a_bar is passed as a column and kappaR as a row, so the expansion is
    # presumably evaluated on the full (a_bar, kappaR) grid -- verify in charge_distributions.
    quad_ex = charge_distributions.create_mapped_quad_expansion(a_bar[:, None], kappaR[None, :], 0.001)
    rel_gap = energy_gap(quad_ex, model, match_expansion_axis_to_params=1)

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    # Transpose so that iteration yields one curve per kappaR value.
    for kR, curve in zip(kappaR, rel_gap.T):
        ax.plot(a_bar, curve, label=rf'$\kappa R={kR}$')

    ax.set_xlabel(r'$\bar a$', fontsize=15)
    ax.set_ylabel(r'$\frac{V_{pp}-V_{ep}}{V_{pp}}$', fontsize=20)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
    ax.legend(fontsize=17)

    fig.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def charge_kappaR_dependence(a_bar, min_charge, max_charge, save_as=None, cmap=cm.jet):
    """Plot the relative energy gap against kappaR for a sweep of ``sigma0`` values.

    One line is drawn for each of 100 evenly spaced charges between
    ``min_charge`` and ``max_charge``; lines are coloured along ``cmap`` and a
    colorbar labelled eta spans ``charge / sigma_tilde`` with ``sigma_tilde = 0.001``.

    Args:
        a_bar: Forwarded as the first argument of ``create_mapped_quad_expansion``.
        min_charge: Lower end of the ``sigma0`` sweep.
        max_charge: Upper end of the ``sigma0`` sweep.
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown.
        cmap: Matplotlib colormap used for both the lines and the colorbar.
    """
    sigma_tilde = 0.001
    kappaR = np.linspace(0.01, 10, 50)
    charge = np.linspace(min_charge, max_charge, 100)

    model = ModelParams(R=150, kappaR=kappaR)
    quad_ex = charge_distributions.create_mapped_quad_expansion(a_bar, kappaR, sigma_tilde, sigma0=charge)
    rel_gap = energy_gap(quad_ex, model, match_expansion_axis_to_params=0)

    # Evenly spaced colours match the evenly spaced charge sweep.
    line_colors = cmap(np.linspace(0, 1, len(charge)))
    eta_norm = plt.Normalize(vmin=np.min(charge) / sigma_tilde,
                             vmax=np.max(charge) / sigma_tilde)
    colorbar_source = cm.ScalarMappable(cmap=cmap, norm=eta_norm)
    colorbar_source.set_array([])

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    # Transpose so that iteration yields one kappaR curve per charge value.
    for curve, color in zip(rel_gap.T, line_colors):
        ax.plot(kappaR, curve, c=color)
    ax.set_xlabel(r'$\kappa R$', fontsize=20)
    ax.set_ylabel(r'$(V_{pp}-V_{ep})/V_{pp}$', fontsize=20)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)

    # Carve a narrow axis off the right edge of the plot for the colorbar.
    cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
    cax.tick_params(labelsize=15)
    cbar = fig.colorbar(colorbar_source, cax=cax, orientation='vertical')
    cbar.set_label(r'$\eta$', rotation=0, labelpad=15, fontsize=20)

    # Manual margins instead of tight_layout.
    fig.subplots_adjust(left=0.1, right=0.9, top=0.95, bottom=0.12)
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def charge_kappaR_dependence_heatmap(a_bar, min_charge, max_charge, save_as=None, cmap=cm.jet):
    """Draw the relative energy gap as a heatmap over kappaR (x) and ``sigma0`` (y).

    The gap is evaluated on 50 kappaR samples in [0.01, 10] and 100 charges
    between ``min_charge`` and ``max_charge``. Colours use a ``TwoSlopeNorm``
    centred on zero, shared by the image and the colorbar.

    Args:
        a_bar: Forwarded as the first argument of ``create_mapped_quad_expansion``.
        min_charge: Lower end of the ``sigma0`` sweep.
        max_charge: Upper end of the ``sigma0`` sweep.
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown.
        cmap: Matplotlib colormap for the image and colorbar.

    Raises:
        ValueError: Raised by ``TwoSlopeNorm`` when the computed gap does not
            take values on both sides of zero (it requires vmin < 0 < vmax).
    """
    kappaR = np.linspace(0.01, 10, 50)
    params = ModelParams(R=150, kappaR=kappaR)
    charge = np.linspace(min_charge, max_charge, 100)
    ex = charge_distributions.create_mapped_quad_expansion(a_bar, kappaR, 0.001, sigma0=charge)
    gap = energy_gap(ex, params, match_expansion_axis_to_params=0)

    norm = TwoSlopeNorm(vmin=np.min(gap), vcenter=0, vmax=np.max(gap))

    def pixel_formatter(values):
        """Build a tick formatter translating pixel positions into ``values``."""
        pixel_index = np.arange(len(values))

        def fmt(x, pos):
            # Without an extent, imshow centres pixel i on coordinate i, so the
            # label is values[i]; interpolation also covers non-integer ticks
            # and clamps the off-screen ticks matplotlib still asks us to format.
            return f"{np.interp(x, pixel_index, values):.2f}"

        return FuncFormatter(fmt)

    fig, ax = plt.subplots(figsize=plt.figaspect(1))
    # Pass the norm to the image itself so its colours agree with the colorbar.
    image = ax.imshow(gap.T, cmap=cmap, norm=norm, origin='lower')
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
    ax.set_xlabel(r'$\kappa R$', fontsize=15)
    ax.set_ylabel(r'$\tilde \sigma_0$', fontsize=15)

    # gap.T has one row per charge and one column per kappaR sample.
    ax.xaxis.set_major_formatter(pixel_formatter(kappaR))
    ax.yaxis.set_major_formatter(pixel_formatter(charge))

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cbar = fig.colorbar(image, cax=cax, orientation='vertical')
    cbar.set_label(r'$\frac{V_{pp}-V_{ep}}{V_{pp}}$', rotation=90, labelpad=20, fontsize=12)

    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_gap_plot(save_as=None):
    """Print a summary of every array in the FIG_10 ``relative_gap.npz`` archive.

    For each stored array the key and shape are printed, followed by the
    unique values found in each of its first three columns. Nothing is plotted.

    Args:
        save_as: Unused; the function produces no figure.
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_10")
    # np.load keeps the .npz archive open; the context manager closes it again.
    with np.load(em_data_path.joinpath("relative_gap.npz")) as em_data:
        for k in em_data.files:
            data = em_data[k]
            print(k, data.shape)
            for i in range(3):
                print(np.unique(data[:, i]))
            print('\n')


def IC_gap_kappaR(save_as=None):
    """Plot the stored gap against kappaR from the ``fixA`` table of ``relative_gap.npz``.

    Column 2 of the table is plotted on the x axis (labelled kappaR) and
    column 3 on the y axis (labelled gap), after sorting by the x values.
    The raw table is also printed.

    Args:
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown (previously this argument was ignored).
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_10")
    # np.load keeps the .npz archive open; the context manager closes it again.
    with np.load(em_data_path.joinpath("relative_gap.npz")) as em_data:
        data = em_data['fixA']

    print(data)

    sort = np.argsort(data[:, 2])
    xdata = data[:, 2][sort]
    ydata = data[:, 3][sort]

    plt.plot(xdata, ydata)
    plt.xlabel('kappaR')
    plt.ylabel('gap')
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_gap_abar(save_as=None):
    """Plot the stored gap against a_bar from the ``fixM`` table of ``relative_gap.npz``.

    Column 1 of the table is plotted on the x axis (labelled abar) and
    column 3 on the y axis (labelled gap), after sorting by the x values.
    The raw table is also printed.

    Args:
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown (previously this argument was ignored).
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_10")
    # np.load keeps the .npz archive open; the context manager closes it again.
    with np.load(em_data_path.joinpath("relative_gap.npz")) as em_data:
        data = em_data['fixM']

    print(data)

    sort = np.argsort(data[:, 1])
    xdata = data[:, 1][sort]
    ydata = data[:, 3][sort]

    plt.plot(xdata, ydata)
    plt.xlabel('abar')
    plt.ylabel('gap')
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_gap_charge_at_abar(a_bar, save_as=None, cmap=cm.coolwarm, which_change='changezp',
                          eta_min: float = None, eta_max: float = None):
    """Plot stored gap-vs-kappaR curves for one a_bar, coloured by eta = charge / 0.001.

    Rows of the selected table in ``relative_gap_ZC.npz`` are read as
    ``(charge, a_bar, kappaR, gap)`` in columns 0-3. One line is drawn per
    distinct charge, and each line's colour is the colorbar colour of its eta.

    Args:
        a_bar: Value to select in column 1; matched with ``np.isclose`` so that
            float round-off in the stored data does not hide valid rows.
        save_as: Optional output path; when given the figure is written there
            at 600 dpi before being shown.
        cmap: Matplotlib colormap for the lines and colorbar.
        which_change: Key of the table to read from the archive.
        eta_min: Lower colorbar limit; defaults to the smallest eta present.
        eta_max: Upper colorbar limit; defaults to the largest eta present.

    Raises:
        ValueError: If no row of the table has the requested ``a_bar``.
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_10")
    # np.load keeps the .npz archive open; the context manager closes it again.
    with np.load(em_data_path.joinpath("relative_gap_ZC.npz")) as em_data:
        data = em_data[which_change]

    sigma_tilde = 0.001

    relevant_indices = np.isclose(data[:, 1], a_bar)
    if not np.any(relevant_indices):
        raise ValueError(f'No results for given a_bar = {a_bar}. Possible values: {np.unique(data[:, 1])}')

    data = data[relevant_indices]

    # `inverse` maps every row to the index of its charge in the sorted `charge` array.
    charge, inverse = np.unique(data[:, 0], return_inverse=True)

    eta = charge / sigma_tilde

    if eta_min is None:
        eta_min = np.min(eta)
    if eta_max is None:
        eta_max = np.max(eta)

    # Colour each curve by its own eta (not by its rank) so lines and colorbar
    # agree even for unevenly spaced charges. Normalize also copes with
    # eta_min == eta_max, and the clip pins out-of-range etas to the end colours.
    norm = plt.Normalize(vmin=eta_min, vmax=eta_max)
    colors = cmap(np.clip(norm(eta), 0.0, 1.0))
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    for i, c in enumerate(colors):
        idx, = np.nonzero(inverse == i)
        kR = data[idx, 2]
        gap = data[idx, 3]
        sort = np.argsort(kR)
        ax.plot(kR[sort], gap[sort], c=c)

    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\kappa R$', fontsize=20)
    ax.set_ylabel(r'$(V_{pp}-V_{ep})/V_{pp}$', fontsize=20)
    ax.set_xlim(-0.25, 10.25)
    ax.set_ylim(-4, 5)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cax.tick_params(labelsize=15)
    cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
    cbar.set_label(r'$\eta$', rotation=0, labelpad=15, fontsize=20)

    # Manual margins instead of tight_layout.
    plt.subplots_adjust(left=0.1, right=0.9, top=0.95, bottom=0.12)
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def test_gap(a_bar, kappaR, charge):
    """Print the relative energy gap for one (a_bar, kappaR, charge) parameter set.

    Builds the model with R=150 and the expansion with a third argument of
    0.001, then prints the result of ``energy_gap``.
    """
    model = ModelParams(R=150, kappaR=kappaR)
    quad_ex = charge_distributions.create_mapped_quad_expansion(a_bar, kappaR, 0.001, sigma0=charge)
    print(energy_gap(quad_ex, model, match_expansion_axis_to_params=None))


def main():
    """Generate the currently selected figure.

    The commented-out calls below are the other figures this script can
    produce; uncomment one to run it.
    """
    # test_gap(0.3, 10, charge=-0.003)

    # abar_kappaR_dependence(Path("/home/andraz/ChargedShells/Figures/full_amplitude_kappaR_dep.png"))
    # abar_kappaR_dependence2(Path("/home/andraz/ChargedShells/Figures/full_amplitude_abar_dep.png"))

    # charge_kappaR_dependence(a_bar=0.8, min_charge=-0.002, max_charge=0.002,
    #                          save_as=Path("/home/andraz/ChargedShells/Figures/full_amplitude_charge_abar08.png"),
    #                          cmap=cm.coolwarm)

    # charge_kappaR_dependence_heatmap(a_bar=0.5, min_charge=-0.003, max_charge=0.003,
    #                          save_as=Path("/home/andraz/ChargedShells/Figures/full_amplitude_heatmap_abar05.png"),
    #                          cmap=cm.bwr)

    # IC_gap_plot()

    # IC_gap_kappaR()
    # IC_gap_abar()

    figure_path = FIGURES_PATH.joinpath('ICi_data').joinpath('IC_full_amplitude_charge_abar03.png')
    IC_gap_charge_at_abar(0.3, which_change='changezc', eta_min=-2, eta_max=2, save_as=figure_path)


# Run the selected figure only when executed as a script, not on import.
if __name__ == '__main__':

    main()