from charged_shells import expansion, interactions, mapping, functions
from charged_shells.parameters import ModelParams
import numpy as np
from typing import Literal
from pathlib import Path
from functools import partial
import json
from plot_settings import *
from dataclasses import dataclass

# Short aliases used in the type hints throughout this module.
Array = np.ndarray
Expansion = expansion.Expansion

# Rotation configurations accepted by the ``which`` argument of the plotting functions.
RotConfig = Literal["ep", "pp", "sep"]


@dataclass
class Peak:
    """Pair configuration whose interaction energy is compared between IC data and charged shells.

    The subclasses in this module define their own ``__init__`` and assign these
    attributes directly; the dataclass decorator still generates ``__repr__`` and
    ``__eq__`` over all fields listed here.
    """
    rot_config: str  # rotation-configuration label (cf. RotConfig: 'ep', 'pp', 'sep')
    y_label: str  # LaTeX label used for the energy (y) axis of the plots
    ex1: Expansion  # first expansion passed to the pairwise charged-shell energy function
    ex2: Expansion  # second expansion passed to the pairwise charged-shell energy function
    emanuele_data_column: int  # column of the IC data array holding this configuration's energy
    log_y: bool  # if True, energies are converted to |U| and plotted on a logarithmic y-axis
    # axis of the expansion's parameter grid along which kappaR varies; forwarded to
    # mapping.parameter_map_two_expansions (left as None when kappaR is a scalar)
    kappaR_axis_in_expansion: int = None


class PeakEP(Peak):
    """'ep' configuration: ``ex1`` is a clone of ``ex`` rotated by beta = pi/2, ``ex2`` an unrotated clone.

    Energies are read from column 4 of the IC data arrays.
    """

    def __init__(self, ex: Expansion, log_y: bool = False, kappaR_axis_in_expansion: int = None):
        # Peak is a dataclass whose generated __repr__/__eq__ read every declared field;
        # this __init__ bypasses the generated one, so rot_config must be set explicitly.
        self.rot_config = 'ep'
        self.emanuele_data_column = 4
        # On a log axis the helpers plot |U|, so the label says so.
        self.y_label = r'$|U_{EP}|$' if log_y else r'$U_{EP}$'
        self.ex1 = ex.clone()
        self.ex2 = ex.clone()
        self.ex1.rotate_euler(alpha=0, beta=np.pi / 2, gamma=0)
        self.log_y = log_y
        self.kappaR_axis_in_expansion = kappaR_axis_in_expansion


class PeakPP(Peak):
    """'pp' configuration: both expansions are unrotated clones of ``ex``.

    Energies are read from column 3 of the IC data arrays.
    """

    def __init__(self, ex: Expansion, log_y: bool = False, kappaR_axis_in_expansion: int = None):
        # Peak is a dataclass whose generated __repr__/__eq__ read every declared field;
        # this __init__ bypasses the generated one, so rot_config must be set explicitly.
        self.rot_config = 'pp'
        self.emanuele_data_column = 3
        # NOTE(review): unlike PeakEP/PeakSEP, the label does not switch to |U| when log_y
        # is set, although the energy helpers plot np.abs(...) in that case -- presumably
        # U_PP is positive in the plotted range; confirm.
        self.y_label = r'$U_{PP}$'
        self.ex1 = ex.clone()
        self.ex2 = ex.clone()
        self.log_y = log_y
        self.kappaR_axis_in_expansion = kappaR_axis_in_expansion


class PeakSEP(Peak):
    """'sep' configuration: both expansions are clones of ``ex`` rotated by beta = pi/4.

    Energies are read from column 5 of the IC data arrays.
    """

    def __init__(self, ex: Expansion, log_y: bool = False, kappaR_axis_in_expansion: int = None):
        # Peak is a dataclass whose generated __repr__/__eq__ read every declared field;
        # this __init__ bypasses the generated one, so rot_config must be set explicitly.
        self.rot_config = 'sep'
        self.emanuele_data_column = 5
        # On a log axis the helpers plot |U|, so the label says so.
        self.y_label = r'$|U_{sEP}|$' if log_y else r'$U_{sEP}$'
        self.ex1 = ex.clone()
        self.ex2 = ex.clone()
        self.ex1.rotate_euler(alpha=0, beta=np.pi / 4, gamma=0)
        self.ex2.rotate_euler(alpha=0, beta=np.pi / 4, gamma=0)
        self.log_y = log_y
        self.kappaR_axis_in_expansion = kappaR_axis_in_expansion


def get_charge_energy_dicts(peak: Peak, params: ModelParams, emanuele_data: Array, sigma0: Array, abar_cs: list, sigma_tilde=0.001):
    """Build per-abar energy-vs-charge curves for IC data and for charged shells.

    Args:
        peak: configuration providing the two expansions, the IC energy column and ``log_y``.
        params: model parameters passed to the charged-shell energy function.
        emanuele_data: IC data array; column 0 is read as charge, column 1 as abar.
        sigma0: charged-shell sigma0 grid; the CS x-values are ``sigma0 / sigma_tilde``.
        abar_cs: abar values to keep (compared after rounding to 5 decimals).
        sigma_tilde: normalisation of ``sigma0`` for the x-axis.

    Returns:
        ``(data_dict_ic, data_dict_cs)`` keyed by the rounded abar value, each value an
        ``(n, 2)`` array of ``(x, energy)`` rows; energies are ``|U|`` when ``peak.log_y``.

    Note:
        ``energy[k]`` is paired with the k-th abar present in both data sets, in ascending
        order (``np.unique`` sorts). This assumes the expansions were built from abar values
        sorted ascending and that every requested abar exists in the IC data.
    """
    abar_ic, inverse = np.unique(emanuele_data[:, 1], return_inverse=True)

    energy_fn = mapping.parameter_map_two_expansions(partial(interactions.charged_shell_energy, dist=2),
                                                     peak.kappaR_axis_in_expansion)
    energy = energy_fn(peak.ex1, peak.ex2, params)

    # Loop-invariant: round the requested abar values once.
    abar_cs_rounded = np.around(abar_cs, 5)

    def _maybe_abs(values):
        # Log-scale plots show magnitudes.
        return np.abs(values) if peak.log_y else values

    data_dict_ic = {}
    data_dict_cs = {}
    k = 0
    for i, abar in enumerate(abar_ic):
        ab = np.around(abar, 5)
        if ab not in abar_cs_rounded:
            continue
        idx, = np.nonzero(inverse == i)
        charge = emanuele_data[idx, 0]
        en = emanuele_data[idx, peak.emanuele_data_column]
        sort = np.argsort(charge)
        # charge / 280 + 2 maps the IC charge column onto the x-axis shared with the CS
        # curve (sigma0 / sigma_tilde) -- NOTE(review): constants taken as-is, confirm meaning.
        data_dict_ic[ab] = np.stack((charge[sort] / 280 + 2, _maybe_abs(en)[sort])).T
        data_dict_cs[ab] = np.stack((sigma0 / sigma_tilde, _maybe_abs(energy[k]))).T
        k += 1

    return data_dict_ic, data_dict_cs


def get_kappaR_energy_dicts(peak: Peak, params: ModelParams, emanuele_data: Array, kappaR: Array, abar_cs: list, max_kappaR: float = 50):
    """Build per-abar energy-vs-kappaR curves for IC data and for charged shells.

    Args:
        peak: configuration providing the two expansions, the IC energy column and ``log_y``.
        params: model parameters passed to the charged-shell energy function.
        emanuele_data: IC data array; column 1 is read as abar, column 2 as kappaR.
        kappaR: kappaR grid used for the charged-shell curves (their x-values).
        abar_cs: abar values to keep (compared after rounding to 5 decimals).
        max_kappaR: IC rows with kappaR above this value are dropped.

    Returns:
        ``(data_dict_ic, data_dict_cs)`` keyed by the rounded abar value, each value an
        ``(n, 2)`` array of ``(kappaR, energy)`` rows; energies are ``|U|`` when ``peak.log_y``.

    Note:
        ``energy[k]`` is paired with the k-th abar present in both data sets, in ascending
        order (``np.unique`` sorts). This assumes the expansions were built from abar values
        sorted ascending and that every requested abar exists in the IC data.
    """
    abar_ic, inverse = np.unique(emanuele_data[:, 1], return_inverse=True)

    energy_fn = mapping.parameter_map_two_expansions(partial(interactions.charged_shell_energy, dist=2),
                                                     peak.kappaR_axis_in_expansion)
    energy = energy_fn(peak.ex1, peak.ex2, params)

    # Loop-invariant: round the requested abar values once.
    abar_cs_rounded = np.around(abar_cs, 5)

    def _maybe_abs(values):
        # Log-scale plots show magnitudes.
        return np.abs(values) if peak.log_y else values

    data_dict_ic = {}
    data_dict_cs = {}
    k = 0
    for i, abar in enumerate(abar_ic):
        ab = np.around(abar, 5)
        if ab not in abar_cs_rounded:
            continue
        idx, = np.nonzero(inverse == i)
        kR_all = emanuele_data[idx, 2]
        # One mask on the kappaR column for both x and y so the arrays stay aligned
        # (the energy column was previously masked with the abar column by mistake).
        in_range = kR_all <= max_kappaR
        kR = kR_all[in_range]
        en = emanuele_data[idx, peak.emanuele_data_column][in_range]
        sort = np.argsort(kR)
        data_dict_ic[ab] = np.stack((kR[sort], _maybe_abs(en)[sort])).T
        data_dict_cs[ab] = np.stack((kappaR, _maybe_abs(energy[k]))).T
        k += 1

    return data_dict_ic, data_dict_cs


def get_abar_energy_dicts(peak: Peak, params: ModelParams, emanuele_data: Array, a_bar: Array, kR_cs: list, max_abar: float = 0.8):
    """Build per-kappaR energy-vs-abar curves for IC data and for charged shells.

    Args:
        peak: configuration providing the two expansions, the IC energy column and ``log_y``.
        params: model parameters passed to the charged-shell energy function.
        emanuele_data: IC data array; column 1 is read as abar, column 2 as kappaR.
        a_bar: abar grid used for the charged-shell curves (their x-values).
        kR_cs: kappaR values to keep (compared after rounding to 5 decimals).
        max_abar: IC rows with abar above this value are dropped.

    Returns:
        ``(ic_curves, cs_curves)`` keyed by the rounded kappaR value, each value an
        ``(n, 2)`` array of ``(abar, energy)`` rows; energies are ``|U|`` when ``peak.log_y``.
    """
    unique_kR, group_of_row = np.unique(emanuele_data[:, 2], return_inverse=True)

    pair_energy = partial(interactions.charged_shell_energy, dist=2)
    energy_fn = mapping.parameter_map_two_expansions(pair_energy, peak.kappaR_axis_in_expansion)
    cs_energy = energy_fn(peak.ex1, peak.ex2, params)

    wanted_kR = np.around(kR_cs, 5)
    ic_curves = {}
    cs_curves = {}
    row = 0  # index into cs_energy, advanced only for matched kappaR values
    for group, kR_value in enumerate(unique_kR):
        key = np.around(kR_value, 5)
        if key not in wanted_kR:
            continue
        members, = np.nonzero(group_of_row == group)
        keep = emanuele_data[members, 1] <= max_abar
        abar_ic = emanuele_data[members, 1][keep]
        energy_ic = emanuele_data[members, peak.emanuele_data_column][keep]
        order = np.argsort(abar_ic)
        energy_cs = cs_energy[row]
        if peak.log_y:
            energy_ic = np.abs(energy_ic)
            energy_cs = np.abs(energy_cs)
        ic_curves[key] = np.stack((abar_ic[order], energy_ic[order])).T
        cs_curves[key] = np.stack((a_bar, energy_cs)).T
        row += 1

    return ic_curves, cs_curves


def save_data(data_dict_ic, data_dict_cs, a_bar: list,
              config_path: Path = Path("/home/andraz/ChargedShells/charged-shells/config.json")):
    """Save IC and charged-shell curves to ``fig_11_IC.npz`` and ``fig_11_CS.npz``.

    The output directory is ``config_data["figure_data"]`` from the JSON file at
    ``config_path``. Dictionary keys are stored as ``str(key)``; ``a_bar`` is added
    to both archives under the key ``'abar'``.

    Args:
        data_dict_ic: mapping of key -> array for the IC data.
        data_dict_cs: mapping of key -> array for the charged-shell data.
        a_bar: abar values stored alongside the curves.
        config_path: JSON config file providing ``"figure_data"``.
    """
    # The former ``if save_data:`` tested this function object itself (always truthy),
    # so the save always ran; the dead check is removed.
    ic_save = {str(key): val for key, val in data_dict_ic.items()}
    cs_save = {str(key): val for key, val in data_dict_cs.items()}
    ic_save['abar'] = a_bar
    cs_save['abar'] = a_bar
    with open(Path(config_path)) as config_file:
        config_data = json.load(config_file)
    figure_dir = Path(config_data["figure_data"])
    np.savez(figure_dir.joinpath("fig_11_IC.npz"), **ic_save)
    np.savez(figure_dir.joinpath("fig_11_CS.npz"), **cs_save)


def IC_peak_energy_plot(config_data: dict,
                        a_bar: list,
                        which: RotConfig,
                        min_kappaR: float = 0.01,
                        max_kappaR: float = 50.,
                        R: float = 150,
                        save_as: Path = None,
                        save_data: bool = False,
                        quad_correction: bool = False,
                        log_y: bool = True):
    """Plot one peak energy versus kappa R for several abar values, IC data vs charged shells.

    Solid lines show IC data (``'fixA'`` in ``FIG_4_PANELS_BDF/newpair_energy.npz``).
    Dashed lines show charged-shell energies on a geometric kappa R grid
    (100 points for ``'sep'``, 30 otherwise).

    Args:
        config_data: configuration dict; ``config_data["emanuele_data"]`` is the IC data directory.
        a_bar: abar values to compare.
        which: rotation configuration, one of ``'ep'``, ``'pp'``, ``'sep'``.
        min_kappaR, max_kappaR: kappa R grid range; max_kappaR also limits the IC data.
        R: passed to ``ModelParams``.
        save_as: if given, the figure is saved there (dpi=600) before being shown.
        save_data, quad_correction: accepted for call compatibility; not used here.
        log_y: plot |U| on a logarithmic y-axis.

    Raises:
        ValueError: if ``which`` is not a known configuration.
    """
    # Alternative data set: Path(config_data["emanuele_data"]).joinpath("FIG_11")
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_4_PANELS_BDF")
    # np.load on an .npz keeps the file open until closed; the context manager releases it.
    with np.load(em_data_path.joinpath("newpair_energy.npz")) as em_data:
        data = em_data['fixA']

    num = 100 if which == 'sep' else 30
    kappaR = np.geomspace(min_kappaR, max_kappaR, num)
    params = ModelParams(R=R, kappaR=kappaR)

    # get_kappaR_energy_dicts pairs energy rows with abar values in ascending order,
    # so the expansion must be built from sorted abar values.
    ex = expansion.MappedExpansionQuad(np.sort(np.array(a_bar))[:, None], params.kappaR[None, :], 0.001, l_max=20)

    if which == 'ep':
        peak = PeakEP(ex, log_y, kappaR_axis_in_expansion=1)
    elif which == 'pp':
        peak = PeakPP(ex, log_y, kappaR_axis_in_expansion=1)
    elif which == 'sep':
        peak = PeakSEP(ex, log_y, kappaR_axis_in_expansion=1)
    else:
        raise ValueError(f"which must be one of 'ep', 'pp', 'sep'; got {which!r}")

    data_dict_ic, data_dict_cs = get_kappaR_energy_dicts(peak, params, data, kappaR, a_bar, max_kappaR)

    colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

    _, ax = plt.subplots(figsize=(4.25, 3))
    for (key, data1), data2 in zip(data_dict_ic.items(), data_dict_cs.values()):
        current_color = next(colors)
        ax.plot(data1[:, 0], data1[:, 1], label=rf'$\bar a = {key:.2f}$', c=current_color)
        ax.plot(data2[:, 0], data2[:, 1], ls='--', c=current_color)
    ax.legend(fontsize=14, ncol=1, frameon=False, handlelength=0.7, loc='upper right',
              bbox_to_anchor=(0.42, 0.42),
              # bbox_to_anchor=(0.42, 0.9)
              )
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\kappa R$', fontsize=15)

    ax.set_ylabel(peak.y_label, fontsize=15)
    if log_y:
        ax.set_yscale('log')
    ax.set_xscale('log')
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_peak_energy_abar_plot(config_data: dict,
                             kappaR: list,
                             which: RotConfig,
                             min_abar: float = 0.01,
                             max_abar: float = 50.,
                             R: float = 150,
                             save_as: Path = None,
                             save_data: bool = False,
                             quad_correction: bool = False,
                             log_y: bool = True):
    """Plot one peak energy versus abar for several kappa R values, IC data vs charged shells.

    Solid lines show IC data (``'fixA'`` in ``FIG_4_PANELS_BDF/newpair_energy.npz``).
    Dashed lines show charged-shell energies on a linear abar grid
    (100 points for ``'sep'``, 30 otherwise).

    Args:
        config_data: configuration dict; ``config_data["emanuele_data"]`` is the IC data directory.
        kappaR: kappa R values, one curve pair each.
        which: rotation configuration, one of ``'ep'``, ``'pp'``, ``'sep'``.
        min_abar, max_abar: abar grid range; max_abar also limits the IC data.
        R: passed to ``ModelParams``.
        save_as: if given, the figure is saved there (dpi=600) before being shown.
        save_data, quad_correction: accepted for call compatibility; not used here.
        log_y: plot |U| on a logarithmic y-axis.

    Raises:
        ValueError: if ``which`` is not a known configuration.
    """
    # Alternative data set: Path(config_data["emanuele_data"]).joinpath("FIG_11")
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_4_PANELS_BDF")
    # np.load on an .npz keeps the file open until closed; the context manager releases it.
    with np.load(em_data_path.joinpath("newpair_energy.npz")) as em_data:
        data = em_data['fixA']

    num = 100 if which == 'sep' else 30
    a_bar = np.linspace(min_abar, max_abar, num)
    # get_abar_energy_dicts pairs energy rows with kappa R values in ascending order,
    # so the model parameters (and hence the expansion) must use sorted kappa R values.
    params = ModelParams(R=R, kappaR=np.sort(np.array(kappaR)))

    ex = expansion.MappedExpansionQuad(np.array(a_bar)[None, :], params.kappaR[:, None], 0.001, l_max=20)

    if which == 'ep':
        peak = PeakEP(ex, log_y, kappaR_axis_in_expansion=0)
    elif which == 'pp':
        peak = PeakPP(ex, log_y, kappaR_axis_in_expansion=0)
    elif which == 'sep':
        peak = PeakSEP(ex, log_y, kappaR_axis_in_expansion=0)
    else:
        raise ValueError(f"which must be one of 'ep', 'pp', 'sep'; got {which!r}")

    data_dict_ic, data_dict_cs = get_abar_energy_dicts(peak, params, data, a_bar, kappaR, max_abar)

    colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

    _, ax = plt.subplots(figsize=(4.25, 3))
    for (key, data1), data2 in zip(data_dict_ic.items(), data_dict_cs.values()):
        current_color = next(colors)
        ax.plot(data1[:, 0], data1[:, 1], label=rf'$\kappa R = {key:.2f}$', c=current_color)
        ax.plot(data2[:, 0], data2[:, 1], ls='--', c=current_color)
    ax.legend(fontsize=14, ncol=1, frameon=False, handlelength=0.7,
              # loc='lower right',
              # bbox_to_anchor=(0.42, 0.42),
              # bbox_to_anchor=(0.42, 0.9)
              )
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\bar a$', fontsize=15)

    ax.set_ylabel(peak.y_label, fontsize=15)
    if log_y:
        ax.set_yscale('log')
    # ax.set_xscale('log')
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


def IC_peak_energy_charge_plot(config_data: dict,
                               a_bar: list,
                               which: RotConfig,
                               max_kappaR: float = 30.,
                               R: float = 150,
                               save_as: Path = None,
                               save_data: bool = False,
                               quad_correction: bool = False,
                               log_y: bool = False):
    """Plot one peak energy versus eta for several abar values at kappa R = 3, IC vs charged shells.

    Solid lines show IC data (``'changeZc'`` in ``FIG_4_PANELS_BDF/newpair_energy.npz``).
    Dashed lines show charged-shell energies over 300 sigma0 values in [-0.0003, 0.0003],
    plotted against ``sigma0 / sigma_tilde``.

    Args:
        config_data: configuration dict; ``config_data["emanuele_data"]`` is the IC data directory.
        a_bar: abar values to compare.
        which: rotation configuration, one of ``'ep'``, ``'pp'``, ``'sep'``.
        max_kappaR, save_data, quad_correction: accepted for call compatibility; not used here.
        R: passed to ``ModelParams``.
        save_as: if given, the figure is saved there (dpi=300) before being shown.
        log_y: plot |U| on a logarithmic y-axis.

    Raises:
        ValueError: if ``which`` is not a known configuration.
    """
    # Alternative data set: Path(config_data["emanuele_data"]).joinpath("FIG_11")
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_4_PANELS_BDF")
    # np.load on an .npz keeps the file open until closed; the context manager releases it.
    with np.load(em_data_path.joinpath("newpair_energy.npz")) as em_data:
        data = em_data['changeZc']

    kappaR = 3  # shared by the model parameters and the expansion
    params = ModelParams(R=R, kappaR=kappaR)
    sigma0 = np.linspace(-0.0003, 0.0003, 300)
    sigma_tilde = 0.001

    # get_charge_energy_dicts pairs energy rows with abar values in ascending order,
    # so the expansion must be built from sorted abar values.
    ex = expansion.MappedExpansionQuad(np.sort(np.array(a_bar)), kappaR=kappaR, sigma0=sigma0, l_max=20,
                                       sigma_tilde=sigma_tilde)

    if which == 'ep':
        peak = PeakEP(ex, log_y)
    elif which == 'pp':
        peak = PeakPP(ex, log_y)
    elif which == 'sep':
        peak = PeakSEP(ex, log_y)
    else:
        raise ValueError(f"which must be one of 'ep', 'pp', 'sep'; got {which!r}")

    data_dict_ic, data_dict_cs = get_charge_energy_dicts(peak, params, data, sigma0, a_bar, sigma_tilde)

    colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

    _, ax = plt.subplots(figsize=(4.25, 3))
    for (key, data1), data2 in zip(data_dict_ic.items(), data_dict_cs.values()):
        current_color = next(colors)
        ax.plot(data1[:, 0], data1[:, 1], label=rf'$\bar a = {key:.2f}$', c=current_color)
        ax.plot(data2[:, 0], data2[:, 1], ls='--', c=current_color)
    ax.legend(fontsize=14, ncol=1, frameon=False, handlelength=0.7, loc='upper right',
              # bbox_to_anchor=(0.42, 0.95),
              # bbox_to_anchor=(0.7, 1),
              bbox_to_anchor=(0.42, 1),
              )
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\eta$', fontsize=15)

    ax.set_ylabel(peak.y_label, fontsize=15)
    if log_y:
        ax.set_yscale('log')
    # ax.set_xscale('log')
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def IC_peak_energy_kappaR_combined_plot(config_data: dict,
                                        a_bar: list,
                                        R: float = 150,
                                        save_as: Path = None,
                                        min_kappaR: float = 0.01,
                                        max_kappaR: float = 50.,
                                        ):
    """Three stacked panels (EP, PP, sEP) of peak energy versus kappa R, IC vs charged shells.

    Solid lines show IC data (``'fixA'`` in ``FIG_4_PANELS_BDF/newpair_energy.npz``).
    Dashed lines show charged-shell energies on a 50-point geometric kappa R grid.
    The EP and PP panels plot |U| on a log y-axis; the sEP panel uses a linear y-axis.
    Dotted vertical lines mark kappa R = 1, 3 and 10.

    Args:
        config_data: configuration dict; ``config_data["emanuele_data"]`` is the IC data directory.
        a_bar: abar values to plot. Each must be present in the IC data, otherwise the
            per-panel dict lookup raises KeyError.
        R: passed to ``ModelParams``.
        save_as: if given, the figure is saved there (dpi=300) before being shown.
        min_kappaR, max_kappaR: kappa R grid range; max_kappaR also limits the IC data.
    """
    # Alternative data set: Path(config_data["emanuele_data"]).joinpath("FIG_11")
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_4_PANELS_BDF")
    # np.load on an .npz keeps the file open until closed; the context manager releases it.
    with np.load(em_data_path.joinpath("newpair_energy.npz")) as em_data:
        data = em_data['fixA']

    num = 50
    kappaR = np.geomspace(min_kappaR, max_kappaR, num)
    params = ModelParams(R=R, kappaR=kappaR)

    # Sorting is necessary: get_kappaR_energy_dicts pairs energy rows with abar values
    # in ascending order.
    ex = expansion.MappedExpansionQuad(np.sort(np.array(a_bar))[:, None],
                                       params.kappaR[None, :], 0.001, l_max=20)

    peak_ep = PeakEP(ex, log_y=True, kappaR_axis_in_expansion=1)
    peak_pp = PeakPP(ex, log_y=True, kappaR_axis_in_expansion=1)
    peak_sep = PeakSEP(ex, log_y=False, kappaR_axis_in_expansion=1)
    peaks = [peak_ep, peak_pp, peak_sep]

    data_ic = []
    data_cs = []
    for peak in peaks:
        dict_ic, dict_cs = get_kappaR_energy_dicts(peak, params, data, kappaR, a_bar, max_kappaR)
        data_ic.append(dict_ic)
        data_cs.append(dict_cs)

    # Per-panel legend anchors (EP, PP, sEP).
    legend_coords = [(0.58, 0.45), (0.58, 0.43), (0.58, 0.9)]

    _, axs = plt.subplots(3, 1, figsize=(3, 7.8))
    for ax, data_dict_ic, data_dict_cs, peak, lc in zip(axs, data_ic, data_cs, peaks, legend_coords):
        # Restart the colour cycle per panel so each abar has the same colour in all panels.
        colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
        for ab in a_bar:
            key = np.around(ab, 5)  # same rounding as the dict keys
            current_color = next(colors)
            ax.plot(data_dict_ic[key][:, 0], data_dict_ic[key][:, 1], label=rf'$\bar a = {key:.2f}$', c=current_color)
            ax.plot(data_dict_cs[key][:, 0], data_dict_cs[key][:, 1], ls='--', c=current_color)
        ax.axvline(x=10, c='k', ls=':')
        ax.axvline(x=3, c='k', ls=':')
        ax.axvline(x=1, c='k', ls=':')
        ax.legend(fontsize=13, ncol=1, frameon=False, handlelength=0.7, loc='upper right', bbox_to_anchor=lc)
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
        ax.set_xlabel(r'$\kappa R$', fontsize=15)

        ax.set_ylabel(peak.y_label, fontsize=15)
        if peak.log_y:
            ax.set_yscale('log')
        ax.set_xscale('log')
        ax.yaxis.set_label_coords(-0.2, 0.5)
    plt.subplots_adjust(left=0.3)
    # NOTE(review): tight_layout() recomputes the subplot margins, so the subplots_adjust
    # call above likely has no visible effect; confirm before relying on it.
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def IC_peak_energy_charge_combined_plot(config_data: dict,
                                        a_bar: list,
                                        R: float = 150,
                                        save_as: Path = None,
                                        log_y: bool = False):
    """Three stacked panels (EP, PP, sEP) of peak energy versus eta at kappa R = 3, IC vs charged shells.

    Solid lines show IC data (``'changeZc'`` in ``FIG_4_PANELS_BDF/newpair_energy.npz``).
    Dashed lines show charged-shell energies over 300 sigma0 values in [-0.0003, 0.0003],
    plotted against ``sigma0 / sigma_tilde``.

    Args:
        config_data: configuration dict; ``config_data["emanuele_data"]`` is the IC data directory.
        a_bar: abar values to plot. Each must be present in the IC data, otherwise the
            per-panel dict lookup raises KeyError.
        R: passed to ``ModelParams``.
        save_as: if given, the figure is saved there (dpi=300) before being shown.
        log_y: plot |U| on a logarithmic y-axis in all panels.
    """
    # Alternative data set: Path(config_data["emanuele_data"]).joinpath("FIG_11")
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_4_PANELS_BDF")
    # np.load on an .npz keeps the file open until closed; the context manager releases it.
    with np.load(em_data_path.joinpath("newpair_energy.npz")) as em_data:
        data = em_data['changeZc']

    params = ModelParams(R=R, kappaR=3)
    sigma0 = np.linspace(-0.0003, 0.0003, 300)
    sigma_tilde = 0.001

    # Sorting is necessary: get_charge_energy_dicts pairs energy rows with abar values
    # in ascending order.
    ex = expansion.MappedExpansionQuad(np.sort(np.array(a_bar)),
                                       kappaR=3, sigma0=sigma0, l_max=20, sigma_tilde=sigma_tilde)

    peak_ep = PeakEP(ex, log_y)
    peak_pp = PeakPP(ex, log_y)
    peak_sep = PeakSEP(ex, log_y)
    peaks = [peak_ep, peak_pp, peak_sep]

    data_ic = []
    data_cs = []
    for peak in peaks:
        dict_ic, dict_cs = get_charge_energy_dicts(peak, params, data, sigma0, a_bar, sigma_tilde)
        data_ic.append(dict_ic)
        data_cs.append(dict_cs)

    _, axs = plt.subplots(3, 1, figsize=(3, 7.8))
    for ax, data_dict_ic, data_dict_cs, peak in zip(axs, data_ic, data_cs, peaks):
        # Restart the colour cycle per panel so each abar has the same colour in all panels.
        colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
        for ab in a_bar:
            key = np.around(ab, 5)  # same rounding as the dict keys
            current_color = next(colors)
            ax.plot(data_dict_ic[key][:, 0], data_dict_ic[key][:, 1], label=rf'$\bar a = {key:.2f}$', c=current_color)
            ax.plot(data_dict_cs[key][:, 0], data_dict_cs[key][:, 1], ls='--', c=current_color)
        ax.legend(fontsize=13, ncol=1, frameon=False, handlelength=0.7, loc='upper right',
                  bbox_to_anchor=(0.77, 1.03),
                  )
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
        ax.set_xlabel(r'$\eta$', fontsize=15)

        ax.set_ylabel(peak.y_label, fontsize=15)
        if peak.log_y:
            ax.set_yscale('log')
        ax.yaxis.set_label_coords(-0.2, 0.5)
    plt.subplots_adjust(left=0.3)
    # NOTE(review): tight_layout() recomputes the subplot margins, so the subplots_adjust
    # call above likely has no visible effect; confirm before relying on it.
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def main():
    """Load the project config and produce the currently selected figure.

    The commented-out calls below are alternative figures; enable one at a time.
    """
    config_path = Path("/home/andraz/ChargedShells/charged-shells/config.json")
    with config_path.open() as handle:
        config_data = json.load(handle)

    # a_bar = [0.2, 0.5, 0.8]
    # IC_peak_energy_plot(config_data, a_bar=a_bar, which='ep', max_kappaR=50,
    #                     # save_as=Path('/home/andraz/ChargedShells/Figures/Emanuele_data/peak_pp_kappaR.png'),
    #                     save_data=False,
    #                     quad_correction=False,
    #                     log_y=True
    #                     )
    # IC_peak_energy_kappaR_combined_plot(config_data, a_bar,
    #                                     save_as=Path(
    #                                         '/home/andraz/ChargedShells/Figures/final_figures/peak_combined_kappaR.png')
    #                                     )

    # a_bar = [0.1, 0.2, 0.3]
    # IC_peak_energy_charge_plot(config_data, a_bar=a_bar, which='sep',
    #                            # save_as=Path('/home/andraz/ChargedShells/Figures/Emanuele_data/peak_sep_charge.png'),
    #                            )
    # IC_peak_energy_charge_combined_plot(config_data, a_bar,
    #                                     save_as=Path('/home/andraz/ChargedShells/Figures/final_figures/peak_combined_charge.png')
    #                                     )

    kappaR_values = [0.01, 3.02407, 30]
    figure_path = Path('/home/andraz/ChargedShells/Figures/Emanuele_data/peak_sep_abar.png')
    IC_peak_energy_abar_plot(
        config_data,
        kappaR=kappaR_values,
        which='sep',
        min_abar=0.2,
        max_abar=0.8,
        save_as=figure_path,
        save_data=False,
        quad_correction=False,
        log_y=False,
    )


if __name__ == "__main__":
    # Script entry point: build the figure selected in main().
    main()