import matplotlib.pyplot as plt
import numpy as np
from charged_shells import expansion, interactions, mapping, charge_distributions
from charged_shells.parameters import ModelParams
from functools import partial
from config import *
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from labellines import labelLine, labelLines
from matplotlib.colors import TwoSlopeNorm
from matplotlib.ticker import FuncFormatter

# Short alias for the expansion class, used in the type annotation of `sEE_minimum`.
Expansion = expansion.Expansion


def sEE_minimum(ex: Expansion, params: ModelParams, accuracy=1e-2, dist=2., match_expansion_axis_to_params=None,
                angle_start: float = 0, angle_stop: float = np.pi / 2):
    """Find the common rotation angle that minimizes the energy of two identical charged shells.

    Two copies of ``ex`` are rotated by the same Euler ``beta`` angle, scanned over a uniform grid on
    ``[angle_start, angle_stop]``. The pair energy is evaluated with ``interactions.charged_shell_energy``
    at separation ``dist``, mapped over ``params`` via ``mapping.parameter_map_two_expansions``.

    Args:
        ex: Expansion describing one shell. It is cloned before rotating, so the caller's object is not modified.
        params: Model parameters forwarded to the energy function.
        accuracy: Maximum spacing between scanned angles; must be positive.
        dist: Separation forwarded to ``charged_shell_energy``.
        match_expansion_axis_to_params: Expansion axis (counted before the angle axis is added) that is
            matched to the parameter axis, or ``None`` for no matching.
        angle_start: First scanned angle (included).
        angle_stop: Last scanned angle (included).

    Returns:
        ``(min_energy, min_angle)``: the minimum of the energy over axis 0 (the scanned-angle axis) and the
        angle at which that minimum occurs.

    Raises:
        ValueError: if ``accuracy`` is not positive.
    """
    if accuracy <= 0:
        raise ValueError(f"accuracy must be positive, got {accuracy}")

    # Use enough points that the grid spacing never exceeds `accuracy`. The small tolerance stops
    # floating-point noise in the division from adding a spurious extra point. At least two points are
    # used so both endpoints are always scanned.
    span = abs(angle_stop - angle_start)
    n_points = max(int(np.ceil(span / accuracy - 1e-9)) + 1, 2)
    angle_range = np.linspace(angle_start, angle_stop, n_points, endpoint=True)

    # Rotate private copies so the caller's expansion is left untouched. This assumes rotate_euler
    # works in place (its return value is ignored), as the original code relied on.
    ex1 = ex.clone()
    ex2 = ex.clone()
    ex1.rotate_euler(alpha=0, beta=angle_range, gamma=0)
    ex2.rotate_euler(alpha=0, beta=angle_range, gamma=0)

    # The angle axis is expected to lead the expansion axes after rotation (the minimum below is
    # taken over axis 0). Hence the matched expansion axis is shifted by one.
    param_axis = None if match_expansion_axis_to_params is None else match_expansion_axis_to_params + 1

    energy_fn = mapping.parameter_map_two_expansions(partial(interactions.charged_shell_energy, dist=dist),
                                                     param_axis)
    energy = energy_fn(ex1, ex2, params)

    min_idx = np.argmin(energy, axis=0)
    min_energy = np.min(energy, axis=0)

    return min_energy, angle_range[min_idx]


def _plot_contour_panel(x, y, z, cbar_label, levels=None):
    """Show one filled-contour plot of ``z`` over the (kappaR, a_bar) mesh, with axis labels and a colorbar."""
    fig, ax = plt.subplots()
    # With explicit levels, extend='both' paints out-of-range values in the end colors instead of leaving them blank.
    filled = ax.contourf(x, y, z, levels=levels, extend='neither' if levels is None else 'both')
    fig.colorbar(filled, ax=ax, label=cbar_label)
    ax.set_xlabel(r'$\kappa R$')
    ax.set_ylabel(r'$\bar a$')
    plt.show()


def contours():
    """Contour-plot the energy-minimizing rotation angle and the minimal energy over a (kappaR, a_bar) grid.

    A 20x20 grid (kappaR in [1, 10], a_bar in [0.2, 0.6], R=150) is scanned with ``sEE_minimum``.
    The value ranges are printed and two figures are shown: the optimal angle, and the minimal energy
    on fixed contour levels.
    """
    kappaR = np.linspace(1, 10, 20)
    a_bar = np.linspace(0.2, 0.6, 20)
    params = ModelParams(R=150, kappaR=kappaR)
    # a_bar varies along expansion axis 0 and kappaR along axis 1; axis 1 is matched to params.kappaR.
    ex = charge_distributions.create_mapped_quad_expansion(a_bar[:, None], kappaR[None, :], 0.001)
    min_energy, min_angle = sEE_minimum(ex, params, match_expansion_axis_to_params=1, accuracy=0.01)

    # Default 'xy' indexing yields meshes of shape (len(a_bar), len(kappaR)), the same layout as the expansion grid.
    kR_mesh, a_mesh = np.meshgrid(kappaR, a_bar)

    print(np.min(min_angle), np.max(min_angle))
    print(np.min(min_energy), np.max(min_energy))

    _plot_contour_panel(kR_mesh, a_mesh, min_angle, 'optimal rotation angle')
    _plot_contour_panel(kR_mesh, a_mesh, min_energy, 'U',
                        levels=np.array([-9, -5, -2.5, -1, -0.5, -0.2, 0]))


def kappaR_dependence(kappaR, save_as=None, cmap=cm.jet):
    """Plot minimal energy vs. pry angle as kappaR varies, with one colored curve per a_bar.

    Two scans are drawn:
    - For each of 15 a_bar values, the minimum of ``sEE_minimum`` as a function of ``kappaR``. These are the
      colored curves, and a colorbar maps colors to a_bar.
    - For a few fixed kappaR values, the minimum as a function of a_bar. These are the dotted black reference
      lines, labelled in place.

    Angles are shown as ``pi - angle``, i.e. the position along the quadrupolar rotational path.

    Args:
        kappaR: kappaR values for the colored curves (array-like or scalar).
        save_as: Optional path; when given, the figure is saved there at 600 dpi before being shown.
        cmap: Colormap used to color the a_bar curves.
    """
    # Accept lists or scalars as well as arrays, since the value is indexed with [:, None] below.
    kappaR = np.atleast_1d(kappaR)

    a_bar = np.linspace(0.12, 0.8, 15)
    params = ModelParams(R=150, kappaR=kappaR)

    # kappaR varies along expansion axis 0 (matched to params), a_bar along axis 1.
    ex = charge_distributions.create_mapped_quad_expansion(a_bar[None, :], kappaR[:, None], 0.001)
    min_energy, min_angle = sEE_minimum(ex, params, match_expansion_axis_to_params=0, accuracy=0.001,
                                        angle_start=0.5, angle_stop=1.)

    # Reference lines: a_bar along axis 0, a few fixed kappaR values along axis 1 (matched to params).
    kappaR_alt = np.array([0.01, 2, 5, 10, 50])
    params_alt = ModelParams(R=150, kappaR=kappaR_alt)
    a_bar_alt = np.linspace(np.min(a_bar) - 0.05, np.max(a_bar) + 0.05, 20)
    ex2 = charge_distributions.create_mapped_quad_expansion(a_bar_alt[:, None], kappaR_alt[None, :], 0.001)
    min_energy_alt, min_angle_alt = sEE_minimum(ex2, params_alt, match_expansion_axis_to_params=1, accuracy=0.001,
                                                angle_start=0.5, angle_stop=1.)

    colors = cmap(np.linspace(0, 1, len(a_bar)))
    sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=np.min(a_bar), vmax=np.max(a_bar)))
    sm.set_array([])

    # Transformation to the position along our quadrupolar rotational path
    min_angle = np.pi - min_angle
    min_angle_alt = np.pi - min_angle_alt

    # ':g' keeps small values such as 0.01 readable (':.1f' rendered them as 0.0).
    kappaR_labels = [rf'$\kappa R={kR:g}$' for kR in kappaR_alt]

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    for me, ma, lbl in zip(min_energy_alt.T, min_angle_alt.T, kappaR_labels):
        ax.plot(ma, me, ls=':', c='k', label=lbl)
    # Label only the reference lines: the colored curves are added afterwards.
    labelLines(ax.get_lines(), align=False, fontsize=15, xvals=(2.35, 2.65))

    for me, ma, lbl, c in zip(min_energy.T, min_angle.T, [rf'$\bar a={a:.2f}$' for a in a_bar], colors):
        ax.plot(ma, me, label=lbl, c=c)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel('pry angle', fontsize=20)
    ax.set_ylabel('U', fontsize=20)
    ax.set_xlim(2.2, 2.65)
    ax.set_ylim(-31, 1)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cax.tick_params(labelsize=15)
    cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
    cbar.set_label(r'$\bar a$', rotation=0, labelpad=15, fontsize=20)

    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_kappaR_dependence(which_kappa_lines: list,
                         save_as=None, cmap=cm.jet, file_suffix=""):
    """Plot precomputed (ICi) minimal energies against the pry angle, one colored curve per a_bar.

    Reads ``arr_0`` from ``ICI_DATA_PATH / "FIG_SUPP_sEE" / f"fig9{file_suffix}.npz"``. Its columns are used
    as (a_bar, kappaR, angle, energy):
    - rows are grouped by column 0 to draw one colored curve per a_bar;
    - rows whose column 1 matches an entry of ``which_kappa_lines`` are drawn as dotted black reference lines.

    Args:
        which_kappa_lines: kappaR values to draw as labelled reference lines. Values missing from the data
            are reported and skipped.
        save_as: Optional path; when given, the figure is saved there at 600 dpi before being shown.
        cmap: Colormap used to color the a_bar curves.
        file_suffix: Suffix appended to ``fig9`` in the data file name.

    Raises:
        ValueError: if the data is empty or not 2-D, or if the a_bar groups have differing numbers of rows.
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_SUPP_sEE")
    # The context manager closes the .npz archive once the array has been read into memory.
    with np.load(em_data_path.joinpath(f"fig9{file_suffix}.npz")) as npz:
        em_data = npz['arr_0']

    if em_data.ndim != 2 or em_data.shape[0] == 0:
        raise ValueError("Data not reshapable.")

    # Group rows by a_bar before reshaping. The sort is stable, so each group's original row order is kept.
    # The reshaped rows then follow the same ascending a_bar order as np.unique, so curves, colors and labels
    # refer to the same a_bar even when the file is not pre-sorted.
    em_data = em_data[np.argsort(em_data[:, 0], kind='stable')]

    a_bar, counts = np.unique(em_data[:, 0], return_counts=True)
    print(a_bar)

    if not np.all(counts == counts[0]):
        raise ValueError("Data not reshapable.")

    colors = cmap(np.linspace(0, 1, len(a_bar)))
    sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=np.min(a_bar), vmax=np.max(a_bar)))
    sm.set_array([])

    # One row per a_bar value, one column per data point of that group.
    min_energy = em_data[:, 3].reshape(-1, counts[0])
    min_angle = em_data[:, 2].reshape(-1, counts[0])

    kappa_line_energy = []
    kappa_line_angle = []
    kappaR_labels = []
    for kappa in which_kappa_lines:
        # Tolerant comparison: kappaR values stored in the file need not equal the requested floats bit-for-bit.
        ind = np.isclose(em_data[:, 1], kappa)
        if not np.any(ind):
            print(f'No lines with kR={kappa} in data. Available values: {np.unique(em_data[:, 1])}')
            continue
        kappa_line_energy.append(em_data[ind, 3])
        kappa_line_angle.append(em_data[ind, 2])
        # ':g' keeps small values such as 0.01 readable (':.1f' rendered them as 0.0).
        kappaR_labels.append(rf'$\kappa R={kappa:g}$')

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    for me, ma, lbl in zip(kappa_line_energy, kappa_line_angle, kappaR_labels):
        # Sort by angle so each reference line is drawn left to right.
        sort = np.argsort(ma)
        ax.plot(ma[sort], me[sort], ls=':', c='k', label=lbl)

    # Label only the reference lines: the colored curves are added afterwards.
    labelLines(ax.get_lines(), align=False, fontsize=15, xvals=(2.4, 2.65))

    for me, ma, lbl, c in zip(min_energy, min_angle, [rf'$\bar a={a:.2f}$' for a in a_bar], colors):
        ax.plot(ma, me, label=lbl, c=c)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel('pry angle', fontsize=20)
    ax.set_ylabel('U', fontsize=20)
    ax.set_xlim(2.2, 2.65)
    ax.set_ylim(-41, 1)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    cax.tick_params(labelsize=15)
    cbar = fig.colorbar(sm, cax=cax, orientation='vertical')
    cbar.set_label(r'$\bar a$', rotation=0, labelpad=15, fontsize=20)

    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


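# NOTE: disabled draft, kept for reference. It does not run as written: `kappaR` is never defined (the
# parameter is `charge`), and it builds expansions with `expansion.MappedExpansionQuad` rather than
# `charge_distributions.create_mapped_quad_expansion` as the live functions in this file do.
# def charge_dependence(charge, save_as=None):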
#
#     a_bar = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
#     params = ModelParams(R=150, kappaR=kappaR)
#
#     ex = expansion.MappedExpansionQuad(a_bar[None, :], kappaR[:, None], 0.001)
#     min_energy, min_angle = sEE_minimum(ex, params, match_expansion_axis_to_params=0, accuracy=0.001,
#                                         angle_start=0.5, angle_stop=1.)
#
#     kappaR_alt = np.array([0.01, 2, 5, 10, 50])
#     params_alt = ModelParams(R=150, kappaR=kappaR_alt)
#     a_bar_alt = np.linspace(np.min(a_bar) - 0.05, np.max(a_bar) + 0.05, 20)
#     ex2 = expansion.MappedExpansionQuad(a_bar_alt[:, None], kappaR_alt[None, :], 0.001)
#     min_energy_alt, min_angle_alt = sEE_minimum(ex2, params_alt, match_expansion_axis_to_params=1, accuracy=0.001,
#                                                 angle_start=0.5, angle_stop=1.)
#
#     fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
#     for me, ma, lbl in zip(min_energy_alt.T, min_angle_alt.T, [rf'$\bar a={a:.2f}$' for a in a_bar]):
#         ax.plot(ma, me, ls=':', c='k')
#     for me, ma, lbl in zip(min_energy.T, min_angle.T, [rf'$\bar a={a:.2f}$' for a in a_bar]):
#         ax.plot(ma, me, label=lbl)
#     # ax.plot(min_angle, min_energy)
#     ax.legend(fontsize=17)
#     ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
#     ax.set_xlabel('angle', fontsize=15)
#     ax.set_ylabel('U', fontsize=20)
#     ax.set_xlim(0.56, 0.93)
#     ax.set_ylim(-20, 1)
#     plt.tight_layout()
#     if save_as is not None:
#         plt.savefig(save_as, dpi=600)
#     plt.show()


def main():
    """Script entry point: draw the ICi-data sEE-minimum figure.

    The other figures are kept as commented-out calls and can be enabled as needed.
    """
    # contours()
    # kappaR_dependence(kappaR=np.linspace(0.1, 30, 20),
    #                   save_as=Path(config_data["figures"]).joinpath('sEE_min_kappaR_abar.png'))
    reference_kappas = [0.01, 3., 5., 10., 50.]
    IC_kappaR_dependence(
        which_kappa_lines=reference_kappas,
        file_suffix="_1",
        # save_as=FIGURES_PATH.joinpath("ICi_data").joinpath('IC_sEE_min_kappaR_abar.png'),
    )


# Run the figure script only when executed directly, not on import.
if __name__ == "__main__":
    main()