import matplotlib.pyplot as plt
from charged_shells import expansion, interactions, mapping
from charged_shells.parameters import ModelParams
import numpy as np
from typing import Literal
from pathlib import Path
from functools import partial
import json

Array = np.ndarray
Expansion = expansion.Expansion
RotConfig = Literal['ep', 'pp']


def peak_energy_plot(kappaR: Array,
                     a_bar: Array,
                     which: RotConfig,
                     R: float = 150,
                     dist: float = 2.,
                     l_max: int = 20,
                     save_as: Path = None):
    """Plot the pair energy of two mapped quadrupolar expansions versus kappaR.

    One curve is drawn per ``a_bar`` value. Each curve is divided by its own
    value at the first ``kappaR`` entry, so every curve starts at 1.

    Args:
        kappaR: 1D sequence of kappaR values (x axis).
        a_bar: 1D sequence of a_bar values, one curve each.
        which: ``'pp'`` uses two identically oriented expansions; ``'ep'``
            rotates the first one by ``beta = pi/2`` before computing energies.
        R: Forwarded to ``ModelParams`` as ``R``.
        dist: Forwarded to ``interactions.charged_shell_energy`` as ``dist``.
        l_max: Forwarded to ``expansion.MappedExpansionQuad``.
        save_as: If not None, the figure is also saved to this path at 600 dpi.

    Raises:
        ValueError: If ``which`` is neither ``'ep'`` nor ``'pp'``.
    """
    # Previously any unknown value silently fell through to the 'pp' case.
    if which not in ('ep', 'pp'):
        raise ValueError(f"which must be 'ep' or 'pp', got {which!r}")

    # Accept plain sequences as well; the broadcasting below needs ndarray indexing.
    kappaR = np.asarray(kappaR)
    a_bar = np.asarray(a_bar)

    params = ModelParams(R=R, kappaR=kappaR)

    # a_bar varies along axis 0 and kappaR along axis 1, so the resulting energy
    # is iterated row-by-row (one row per a_bar) in the plotting loop below.
    ex1 = expansion.MappedExpansionQuad(a_bar[:, None], params.kappaR[None, :], 0.001, l_max=l_max)
    ex2 = ex1.clone()
    if which == 'ep':
        ex1.rotate_euler(alpha=0, beta=np.pi / 2, gamma=0)

    energy_fn = mapping.parameter_map_two_expansions(partial(interactions.charged_shell_energy, dist=dist), 1)
    energy = energy_fn(ex1, ex2, params)

    fig, ax = plt.subplots()
    for curve, ab in zip(energy, a_bar):
        ax.plot(kappaR, curve / curve[0], label=rf'$\bar a = {ab:.1f}$')
    ax.legend(fontsize=12)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\kappa R$', fontsize=15)
    ax.set_ylabel(rf'$\bar V$', fontsize=15)
    fig.tight_layout()
    if save_as is not None:
        fig.savefig(save_as, dpi=600)
    plt.show()


def IC_peak_energy_plot(config_data: dict,
                        a_bar: list,
                        which: RotConfig,
                        save_as: Path = None):
    """Plot reference pair energies from Emanuele's FIG_11 data versus kappaR.

    Loads ``FIG_11/pair_energy_fig11.npz`` from the directory stored under
    ``config_data["emanuele_data"]``. It then plots the ``'fixA'`` table on
    log-log axes, one curve per requested a_bar value. In that table, column 1
    holds a_bar, column 2 holds kappaR, and columns 3/4 hold the 'pp'/'ep'
    energies.

    Args:
        config_data: Parsed config dict; must contain ``"emanuele_data"``.
        a_bar: a_bar values to plot. They are matched to the data with
            ``np.isclose``, so tiny float representation differences do not
            drop a curve.
        which: ``'pp'`` plots column 3, ``'ep'`` plots column 4.
        save_as: If not None, the figure is also saved to this path at 600 dpi.

    Raises:
        ValueError: If ``which`` is neither ``'ep'`` nor ``'pp'``.
    """
    energy_columns = {'pp': 3, 'ep': 4}
    try:
        column_idx = energy_columns[which]
    except KeyError:
        raise ValueError(f"which must be 'ep' or 'pp', got {which!r}") from None

    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_11", "pair_energy_fig11.npz")
    # np.load on an .npz keeps the file open; the context manager closes it.
    # Indexing an NpzFile reads the array fully into memory, so `data` stays valid.
    with np.load(em_data_path) as em_data:
        data = em_data['fixA']

    wanted = np.asarray(a_bar, dtype=float)

    fig, ax = plt.subplots()
    for value in np.unique(data[:, 1]):
        # Exact float membership (`value in a_bar`) can miss values that differ
        # only in their last bits, silently dropping a curve.
        if not np.any(np.isclose(value, wanted)):
            continue
        rows = data[:, 1] == value  # value comes from data itself, so == is exact here
        order = np.argsort(data[rows, 2])
        kR = data[rows, 2][order]
        en = data[rows, column_idx][order]
        ax.plot(kR, en, label=rf'$\bar a = {value:.2f}$')
    ax.legend(fontsize=12)
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
    ax.set_xlabel(r'$\kappa R$', fontsize=15)
    ax.set_ylabel(rf'$\bar V_{{{which}}}$', fontsize=15)
    ax.set_yscale('log')
    ax.set_xscale('log')
    # Lay out after switching to log scales so the final tick labels are accounted for.
    fig.tight_layout()
    if save_as is not None:
        fig.savefig(save_as, dpi=600)
    plt.show()


if __name__ == '__main__':
    import os

    # The default is the author's local config; set CHARGED_SHELLS_CONFIG to use another one.
    config_path = Path(os.environ.get("CHARGED_SHELLS_CONFIG",
                                      "/home/andraz/ChargedShells/charged-shells/config.json"))
    with open(config_path, encoding="utf-8") as config_file:
        config_data = json.load(config_file)

    # Inputs for the (currently disabled) peak_energy_plot call below.
    kappaR = np.arange(0.5, 10, 0.1)
    a_bar = np.arange(0.2, 0.8, 0.2)

    # peak_energy_plot(kappaR, a_bar, which='pp',
    #                  # save_as=Path('/home/andraz/ChargedShells/Figures/nonmonotonicity_check_ep.pdf')
    #                  )

    IC_peak_energy_plot(config_data, a_bar=[0.2, 0.24, 0.36, 0.52, 0.8], which='pp',
                        # save_as=Path('/home/andraz/ChargedShells/Figures/Emanuele_data/nonmonotonicity_check_ep.pdf')
                        )