import numpy as np
from matplotlib.lines import Line2D

from charged_shells.rotational_path import PairRotationalPath, PathEnergyPlot
from charged_shells import expansion
from charged_shells.parameters import ModelParams
from pathlib import Path
import json
import quadrupole_model_mappings
from plot_settings import *

# Type alias used in the function signatures of this module for numpy arrays.
Array = np.ndarray

# 100-point angle grids, endpoints included: [0, pi/2] and [pi/2, pi].
# NOTE(review): pi_half_to_pi is not referenced elsewhere in this module; kept
# presumably for importers or future paths.
zero_to_pi_half = np.linspace(0, np.pi/2, 100, endpoint=True)
pi_half_to_pi = np.linspace(np.pi/2, np.pi, 100, endpoint=True)

# Rotational path for a particle pair, assembled from six consecutive Euler-angle
# sweeps between named configurations: EP -> EE -> PP -> EP -> EP -> tEE -> EE.
# Segment order matters: each add_euler call appends the next leg of the path.
QuadPath = PairRotationalPath()
# presumably the x-axis grid used for segments that do not supply their own — confirm in PairRotationalPath
QuadPath.set_default_x_axis(zero_to_pi_half)
# beta1 fixed at pi/2, beta2 sweeps 0 -> pi/2
QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half, start_name="EP", end_name="EE")
# beta1 and beta2 both sweep pi/2 -> 0
QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half[::-1], end_name="PP")
# beta1 sweeps 0 -> pi/2 (other angles take add_euler's defaults)
QuadPath.add_euler(beta1=zero_to_pi_half, end_name="EP")
# beta1 sweeps pi/2 -> 0 while beta2 sweeps 0 -> pi/2
QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half, end_name="EP")
# beta1 fixed at pi/2, beta2 sweeps 0 -> pi/2 with alpha2 fixed at pi/2
QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half, alpha2=np.pi/2, end_name="tEE")
# both betas fixed at pi/2, alpha1 sweeps pi/2 -> 0
QuadPath.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half[::-1], end_name="EE")


def model_comparison(config_data: dict, save_as=None, save_data=False):
    """Compare the ICi path energy (Emanuele's data) with the mapped CSp model.

    Model parameters are fixed: R = 150, kappaR = 3, a_bar = 0.5,
    sigma_tilde = 0.001, l_max = 30, distance 2.

    :param config_data: parsed config; reads "emanuele_data" and, when
        ``save_data`` is true, "figure_data".
    :param save_as: optional path; the figure is saved there at 300 dpi.
    :param save_data: if true, store both curves in ``fig_7_kR{kappaR}.npz``
        under the keys ``ICi`` and ``CSp``.
    """
    kappaR = 3
    a_bar = 0.5
    sigma_tilde = 0.001
    params = ModelParams(R=150, kappaR=kappaR)

    ex1 = expansion.MappedExpansionQuad(a_bar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    # Alternative models matched to CSp by equal patch size in potential (disabled):
    # ex_gauss = quadrupole_model_mappings.ic_to_gauss(sigma_tilde, a_bar, params, l_max=30, sigma0=0)
    # ex_cap = quadrupole_model_mappings.ic_to_cap(sigma_tilde, a_bar, params, l_max=30, sigma0=0)

    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params)
    csp_energy = np.squeeze(path_plot.evaluate_path())
    x_axis = path_plot.rot_path.stack_x_axes()

    # Other datasets used previously: FIG_7/pathway.npz, FIG_8/FIXEDCHARGE/FIX_A/ECC_0.25/EMME_{kappaR}./pathway.npz
    emanuele_dir = Path(config_data["emanuele_data"])
    em_data = np.load(emanuele_dir / "FIG_3C" / "pathway.npz")['arr_0']

    if save_data:
        target = Path(config_data["figure_data"]) / f"fig_7_kR{kappaR}.npz"
        np.savez(target, ICi=em_data, CSp=np.stack((x_axis, csp_energy)).T)

    fig, ax = plt.subplots(figsize=(8.25, 3))
    # Same color for both curves: solid = ICi data, dashed = CSp model.
    ax.plot(em_data[:, 0], em_data[:, 1], label='ICi', c=COLOR_LIST[1])
    ax.plot(x_axis, csp_energy, label='CSp', c=COLOR_LIST[1], ls='--')
    path_plot.plot_style(fig, ax, size=(8.25, 3.5))
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def kappaR_dependence(kappaR: Array, abar: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp path energies for several screening values kappaR at fixed ``abar``.

    :param kappaR: array of kappa*R values, one curve each.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path forwarded to PathEnergyPlot.plot.
    """
    params = ModelParams(R=150, kappaR=kappaR)
    shell = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    plotter = PathEnergyPlot(shell, shell.clone(), QuadPath, dist=2., params=params,
                             match_expansion_axis_to_params=0)
    curve_labels = [rf'$\kappa R$={kR}' for kR in kappaR]
    plotter.plot(labels=curve_labels, save_as=save_as)


def abar_dependence(abar: Array, kappaR: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp path energies for several patch sizes ``abar`` at fixed kappaR.

    :param abar: array of patch-size values, one curve each.
    :param kappaR: screening parameter kappa*R.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path forwarded to PathEnergyPlot.plot.
    """
    params = ModelParams(R=150, kappaR=kappaR)
    shell = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    plotter = PathEnergyPlot(shell, shell.clone(), QuadPath, dist=2., params=params,
                             match_expansion_axis_to_params=None)
    curve_labels = [rf'$\bar a$={a}' for a in abar]
    plotter.plot(labels=curve_labels, save_as=save_as)


def sigma0_dependence(sigma0: Array, kappaR: float, abar: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp path energies for several net-charge offsets ``sigma0``.

    Curves are labelled by eta = sigma0 / sigma_tilde.

    :param sigma0: array of net-charge offsets, one curve each.
    :param kappaR: screening parameter kappa*R.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path forwarded to PathEnergyPlot.plot.
    """
    params = ModelParams(R=150, kappaR=kappaR)
    shell = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30, sigma0=sigma0)
    plotter = PathEnergyPlot(shell, shell.clone(), QuadPath, dist=2., params=params,
                             match_expansion_axis_to_params=None)
    curve_labels = [rf'$\eta={s0 / sigma_tilde}$' for s0 in sigma0]
    plotter.plot(labels=curve_labels, save_as=save_as)


def distance_dependence(dist: Array, kappaR: float, abar: float, sigma_tilde=0.001, save_as=None):
    """Plot the rescaled CSp path energy U * kappa*rho * exp(kappa*rho) for several distances.

    The rescaling presumably removes the screened-Coulomb-like exp(-kappa*rho)/rho
    decay so curves at different separations share one axis.

    :param dist: center-to-center distances in units of R (rho/R); must be non-empty.
    :param kappaR: screening parameter kappa*R.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    :raises ValueError: if ``dist`` is empty.
    """
    if np.size(dist) == 0:
        # Without this guard ``path_plot`` below would be unbound (NameError).
        raise ValueError("dist must contain at least one distance")

    params = ModelParams(R=150, kappaR=kappaR)

    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    plots = []
    for d in dist:
        path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=d, params=params)
        x = d * kappaR  # kappa * rho, since d = rho / R
        plots.append(path_plot.evaluate_path() * np.exp(x) * x)

    # The x-axis depends only on the rotational path, so the last plotter suffices.
    x_axis = path_plot.rot_path.stack_x_axes()
    labels = [rf'$\rho/R ={d}$' for d in dist]

    fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
    for pl, lbl in zip(plots, labels):
        ax.plot(x_axis, pl, label=lbl)
    QuadPath.plot_style(fig, ax)
    ax.set_ylabel(r'$U \kappa\rho e^{\kappa\rho}$', fontsize=15)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def IC_kappaR_dependence(config_data: dict, save_as=None):
    """Plot Emanuele's ICi path energies for kappaR = 1, 3 and 10.

    Data are read from ``FIG_4_Panels_ACE/FIXEDCHARGE/FIX_A/ECC_0.25/EMME_{kR}./pathway.npz``.
    (An older dataset lived under ``FIG_8/FIXEDCHARGE/FIX_A/ECC_0.25``.)

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param save_as: optional output path (saved at 300 dpi).
    """
    base = Path(config_data["emanuele_data"]) / "FIG_4_Panels_ACE" / "FIXEDCHARGE" / "FIX_A" / "ECC_0.25"
    kappa_values = [1, 3, 10]
    curves = [np.load(base / f"EMME_{kR}." / "pathway.npz")['arr_0'] for kR in kappa_values]

    fig, ax = plt.subplots()
    for curve, kR in zip(curves, kappa_values):
        ax.plot(curve[:, 0], curve[:, 1], label=rf'$\kappa R$={kR}')
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def IC_abar_dependence(config_data: dict, save_as=None):
    """Plot Emanuele's ICi path energies for abar = 0.3, 0.4 and 0.5 at kappaR = 3.

    Each abar maps to an ``ECC_{abar/2}`` folder under
    ``FIG_4_Panels_ACE/FIXEDCHARGE/FIX_M`` (older data: ``FIG_8/FIXEDCHARGE/FIX_M``).

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param save_as: optional output path (saved at 300 dpi).
    """
    base = Path(config_data["emanuele_data"]) / "FIG_4_Panels_ACE" / "FIXEDCHARGE" / "FIX_M"
    cases = (("ECC_0.15", 0.3), ("ECC_0.2", 0.4), ("ECC_0.25", 0.5))
    curves = [np.load(base / ecc / "EMME_3." / "pathway.npz")['arr_0'] for ecc, _ in cases]

    fig, ax = plt.subplots()
    for curve, (_, a) in zip(curves, cases):
        ax.plot(curve[:, 0], curve[:, 1], label=rf'$\bar a$={a}')
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def IC_sigma0_dependence(config_data: dict, save_as=None):
    """Plot Emanuele's ICi path energies for three net-charge states.

    Curves are drawn in the order overcharged (ZC_-842.74, eta=-0.1),
    neutral (ZC_-560, eta=0) and undercharged (ZC_-277.27, eta=0.1).
    Data live under ``FIG_4_Panels_ACE/CHARGE_ZC`` (older data: ``FIG_8/CHARGE_ZC``).

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param save_as: optional output path (saved at 300 dpi).
    """
    base = Path(config_data["emanuele_data"]) / "FIG_4_Panels_ACE" / "CHARGE_ZC"
    undercharged, neutral, overcharged = (
        np.load(base / folder / "pathway.npz")['arr_0']
        for folder in ("ZC_-277.27", "ZC_-560", "ZC_-842.74")
    )

    fig, ax = plt.subplots()
    for curve, eta in ((overcharged, -0.1), (neutral, 0), (undercharged, 0.1)):
        ax.plot(curve[:, 0], curve[:, 1], label=rf'$\eta={eta}$')
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_distance_dependence(config_data: dict, dist: Array = 2 * np.array([1., 1.15, 1.3, 1.45]),
                                 kappaR: float = 3,
                                 abar: float = 0.5,
                                 sigma_tilde=0.001,
                                 save_as=None,
                                 max_curves=3):
    """Overlay ICi data (solid) and the CSp model (dashed) at several distances.

    The first ICi curve is ``FIG_3C/pathway.npz``; the rest are the arrays stored in
    ``FIG_3D_LONG_DIST/pathway_fig12A.npz`` in file order. ICi curves, CSp curves,
    labels and colors are paired positionally.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param dist: distances in units of R (rho/R); must be non-empty.
    :param kappaR: screening parameter kappa*R.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    :param max_curves: maximum number of (ICi, CSp) pairs drawn; ``None`` draws all.
    :raises ValueError: if ``dist`` is empty.
    """
    if np.size(dist) == 0:
        # Without this guard ``path_plot`` below would be unbound (NameError).
        raise ValueError("dist must contain at least one distance")

    # Older location of the long-distance data: FIG_12
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_3D_LONG_DIST")
    # Read every array while the archive is open; the context manager closes the file handle.
    with np.load(em_data_path.joinpath("pathway_fig12A.npz")) as em_data:
        long_range = [arr for _, arr in em_data.items()]

    em_data_d2 = np.load(Path(config_data["emanuele_data"]).joinpath("FIG_3C").joinpath("pathway.npz"))['arr_0']
    ic_data = [em_data_d2, *long_range]

    params = ModelParams(R=150, kappaR=kappaR)
    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    plots = []
    for d in dist:
        path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=d, params=params)
        plots.append(path_plot.evaluate_path())

    x_axis = path_plot.rot_path.stack_x_axes()
    labels = [rf'$\rho/R ={d}$' for d in dist]

    # Extra legend that explains the line styles.
    line1 = Line2D([0], [0], color='black', linewidth=1, label='ICi')
    line2 = Line2D([0], [0], color='black', linestyle='--', linewidth=1, label='CSp')

    fig, ax = plt.subplots()
    for d, en, label, c in list(zip(ic_data, plots, labels, COLORS))[:max_curves]:
        ax.plot(d[:, 0], d[:, 1], label=label, c=c)
        ax.plot(x_axis, en, ls='--', c=c)
    QuadPath.plot_style(fig, ax)
    # ax.legend() replaces the Axes' current legend, so the distance legend (if
    # plot_style created one) is re-attached as an artist. The new style legend is
    # already owned by the Axes; adding it again would draw it twice.
    main_legend = ax.get_legend()
    ax.legend(handles=[line1, line2], loc='upper left', fontsize=15, frameon=False)
    if main_legend is not None:
        ax.add_artist(main_legend)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_rescaled_distance_dependence(config_data: dict,
                                          dist: Array = 2 * np.array([1, 1.5, 2, 3, 5, 10]),
                                          kappaR: float = 3,
                                          abar: float = 0.5,
                                          sigma_tilde=0.001,
                                          save_as=None):
    """Overlay ICi data (solid) and rescaled CSp energies (dashed) over long distances.

    CSp energies are multiplied by kappa*rho * exp(kappa*rho), matching the y-label.
    ICi curves are the arrays in ``FIG_3D_LONG_DIST/pathway_fig12B.npz`` in file
    order. They are plotted as stored (presumably already rescaled — confirm) and
    paired positionally with ``dist``.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param dist: distances in units of R (rho/R); must be non-empty.
    :param kappaR: screening parameter kappa*R.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    :raises ValueError: if ``dist`` is empty.
    """
    if np.size(dist) == 0:
        # Without this guard ``path_plot`` below would be unbound (NameError).
        raise ValueError("dist must contain at least one distance")

    # Older location of the long-distance data: FIG_12
    em_data_path = Path(config_data["emanuele_data"]).joinpath("FIG_3D_LONG_DIST")
    ic_data = []
    # Read every array while the archive is open; the context manager closes the file handle.
    with np.load(em_data_path.joinpath("pathway_fig12B.npz")) as em_data:
        for key, curve in em_data.items():
            print(key, np.max(curve[:, 1]))  # report each ICi curve's peak value
            ic_data.append(curve)

    params = ModelParams(R=150, kappaR=kappaR)
    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    plots = []
    for d in dist:
        path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=d, params=params)
        x = d * kappaR  # kappa * rho, since d = rho / R
        plots.append(path_plot.evaluate_path() * np.exp(x) * x)

    x_axis = path_plot.rot_path.stack_x_axes()
    labels = [rf'$\rho/R ={d}$' for d in dist]

    fig, ax = plt.subplots()
    for curve, en, label, c in zip(ic_data, plots, labels, COLORS):
        ax.plot(curve[:, 0], curve[:, 1], label=label, c=c)
        ax.plot(x_axis, en, ls='--', c=c)
    QuadPath.plot_style(fig, ax)
    ax.set_ylabel(r'$U \kappa\rho e^{\kappa\rho}$', fontsize=15)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_kappaR_dependence(config_data: dict, kappaR: list[int], abar: float, sigma_tilde=0.001, save_as=None):
    """Overlay ICi data (solid) and the CSp model (dashed) for several kappaR values.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param kappaR: screening values; an ``EMME_{kR}.`` data folder must exist for each.
    :param abar: patch-size parameter; selects the ``ECC_{abar/2}`` data folder.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    """
    # Older location: FIG_8/FIXEDCHARGE/FIX_A/ECC_{abar/2}
    em_data_path = (Path(config_data["emanuele_data"]).joinpath("FIG_4_Panels_ACE").joinpath("FIXEDCHARGE")
                    .joinpath("FIX_A").joinpath(f"ECC_{np.round(abar/2, 4)}"))

    ic_data = [np.load(em_data_path.joinpath(f"EMME_{kR}.").joinpath("pathway.npz"))['arr_0']
               for kR in kappaR]

    params = ModelParams(R=150, kappaR=np.asarray(kappaR))

    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params, match_expansion_axis_to_params=0)
    # presumably one column per kappaR value (iterated via energy.T below) — confirm
    energy = path_plot.evaluate_path()
    x_axis = path_plot.rot_path.stack_x_axes()

    # Derive labels from the requested values so they stay in sync with the data and model curves.
    labels = [rf'$\kappa R={kR}$' for kR in kappaR]

    fig, ax = plt.subplots()
    for d, en, label, c in zip(ic_data, energy.T, labels, COLORS):
        ax.plot(d[:, 0], d[:, 1], label=label, c=c)
        ax.plot(x_axis, en, ls='--', c=c)
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_abar_dependence(config_data: dict, kappaR: int, abar: list[float], sigma_tilde=0.001, save_as=None):
    """Overlay ICi data (solid) and the CSp model (dashed) for several patch sizes abar.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param kappaR: screening parameter; selects the ``EMME_{kappaR}.`` data folders.
    :param abar: patch sizes; an ``ECC_{abar/2}`` data folder must exist for each.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    """
    # Older location: FIG_8/FIXEDCHARGE/FIX_M
    em_data_path = (Path(config_data["emanuele_data"]).joinpath("FIG_4_Panels_ACE").joinpath("FIXEDCHARGE").joinpath("FIX_M"))

    ic_data = [np.load(em_data_path.joinpath(f"ECC_{np.round(ab/2, 4)}")
                       .joinpath(f"EMME_{kappaR}.").joinpath("pathway.npz"))['arr_0']
               for ab in abar]

    params = ModelParams(R=150, kappaR=kappaR)

    ex1 = expansion.MappedExpansionQuad(np.asarray(abar), params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()

    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params, match_expansion_axis_to_params=None)
    # presumably one column per abar value (iterated via energy.T below) — confirm
    energy = path_plot.evaluate_path()
    x_axis = path_plot.rot_path.stack_x_axes()

    # Derive labels from the requested values so they stay in sync with the data and model curves.
    labels = [rf'$\bar a={a}$' for a in abar]

    fig, ax = plt.subplots()
    for d, en, label, c in zip(ic_data, energy.T, labels, COLORS):
        ax.plot(d[:, 0], d[:, 1], label=label, c=c)
        ax.plot(x_axis, en, ls='--', c=c)
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_sigma0_dependence(config_data: dict, kappaR=3., abar=0.5, sigma0=(0.0002, 0.00, -0.0002), sigma_tilde=0.001, save_as=None):
    """Overlay ICi data (solid) and the CSp model (dashed) for several net charges.

    Curves are labelled by eta = sigma0 / sigma_tilde.

    NOTE(review): the ICi files are fixed (ZC_-503, ZC_-560, ZC_-617, in that order)
    and are paired positionally with ``sigma0``. They presumably correspond only to
    the default sigma0/kappaR/abar; verify before passing other values.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param kappaR: screening parameter kappa*R.
    :param abar: patch-size parameter passed to MappedExpansionQuad.
    :param sigma0: net-charge offsets, one model curve each.
    :param sigma_tilde: charge parameter passed to MappedExpansionQuad.
    :param save_as: optional output path (saved at 300 dpi).
    """
    # Older location: FIG_8/CHARGE_ZC
    base = Path(config_data["emanuele_data"]) / "FIG_4_Panels_ACE" / "CHARGE_ZC"
    ic_curves = [np.load(base / zc / "pathway.npz")['arr_0']
                 for zc in ("ZC_-503", "ZC_-560", "ZC_-617")]  # under-, neutral, overcharged

    params = ModelParams(R=150, kappaR=kappaR)
    shell = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30, sigma0=np.asarray(sigma0))
    plotter = PathEnergyPlot(shell, shell.clone(), QuadPath, dist=2., params=params,
                             match_expansion_axis_to_params=None)
    model_energy = plotter.evaluate_path()
    path_x = plotter.rot_path.stack_x_axes()

    curve_labels = [rf'$\eta={s0/sigma_tilde}$' for s0 in sigma0]

    fig, ax = plt.subplots()
    for ic_curve, model_curve, text, color in zip(ic_curves, model_energy.T, curve_labels, COLORS):
        ax.plot(ic_curve[:, 0], ic_curve[:, 1], label=text, c=color)
        ax.plot(path_x, model_curve, ls='--', c=color)
    QuadPath.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_all(config_data: dict, save_as=None):
    """Three-panel figure comparing ICi data (solid) with the CSp model (dashed).

    Panels, top to bottom:
      * Screening: kappaR in {1, 3, 10} at abar = 0.5;
      * Asymmetry: abar in {0.5, 0.4, 0.3} at kappaR = 3;
      * Net charge: sigma0 in {0.0002, 0, -0.0002} at kappaR = 3, abar = 0.5.
    All model curves use sigma_tilde = 0.001, R = 150 and distance 2.

    :param config_data: parsed config; "emanuele_data" locates the ICi data.
    :param save_as: optional output path (saved at 300 dpi).
    """
    sigma_tilde = 0.001
    kappaR_list = [1, 3, 10]
    abar_list = [0.5, 0.4, 0.3]
    sigma0_list = [0.0002, 0.00, -0.0002]
    kappaR = 3
    abar = 0.5
    ic_root = Path(config_data["emanuele_data"]).joinpath("FIG_4_Panels_ACE")

    # --- Panel 1: screening (kappaR varies, abar fixed) ---
    em_data_kappaR = ic_root.joinpath("FIXEDCHARGE").joinpath("FIX_A").joinpath(f"ECC_{np.round(abar/2, 4)}")
    ic_data_kappaR = [np.load(em_data_kappaR.joinpath(f"EMME_{kR}.").joinpath("pathway.npz"))['arr_0']
                      for kR in kappaR_list]

    params = ModelParams(R=150, kappaR=np.asarray(kappaR_list))
    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()
    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params, match_expansion_axis_to_params=0)
    energy_kappaR = path_plot.evaluate_path()
    x_axis_kappaR = path_plot.rot_path.stack_x_axes()
    # Labels derived from kappaR_list so they cannot drift from the plotted data.
    labels_kappaR = [rf'$\kappa R={kR}$' for kR in kappaR_list]

    # --- Panel 2: asymmetry (abar varies, kappaR fixed) ---
    em_data_abar = ic_root.joinpath("FIXEDCHARGE").joinpath("FIX_M")
    ic_data_abar = [np.load(em_data_abar.joinpath(f"ECC_{np.round(ab/2, 4)}")
                            .joinpath(f"EMME_{kappaR}.").joinpath("pathway.npz"))['arr_0']
                    for ab in abar_list]

    params = ModelParams(R=150, kappaR=kappaR)
    ex1 = expansion.MappedExpansionQuad(np.asarray(abar_list), params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()
    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params, match_expansion_axis_to_params=None)
    energy_abar = path_plot.evaluate_path()
    x_axis_abar = path_plot.rot_path.stack_x_axes()
    labels_abar = [rf'$\bar a={a}$' for a in abar_list]

    # --- Panel 3: net charge (sigma0 varies) ---
    # ICi files are paired positionally with sigma0_list (under-, neutral, overcharged).
    em_data_charge = ic_root.joinpath("CHARGE_ZC")
    undercharged = np.load(em_data_charge.joinpath("ZC_-503").joinpath("pathway.npz"))['arr_0']
    neutral = np.load(em_data_charge.joinpath("ZC_-560").joinpath("pathway.npz"))['arr_0']
    overcharged = np.load(em_data_charge.joinpath("ZC_-617").joinpath("pathway.npz"))['arr_0']
    ic_data_sigma0 = [undercharged, neutral, overcharged]

    params = ModelParams(R=150, kappaR=kappaR)
    ex1 = expansion.MappedExpansionQuad(abar, params.kappaR, sigma_tilde, l_max=30, sigma0=np.asarray(sigma0_list))
    ex2 = ex1.clone()
    path_plot = PathEnergyPlot(ex1, ex2, QuadPath, dist=2., params=params, match_expansion_axis_to_params=None)
    energy_sigma0 = path_plot.evaluate_path()
    x_axis_sigma0 = path_plot.rot_path.stack_x_axes()
    labels_sigma0 = [rf'$\eta={s0/sigma_tilde}$' for s0 in sigma0_list]

    # --- Figure ---
    fig, axs = plt.subplots(3, 1, figsize=(6, 7.8))

    axs[0].set_title('Screening', fontsize=15)
    for d, en, label, c in zip(ic_data_kappaR, energy_kappaR.T, labels_kappaR, COLOR_LIST):
        axs[0].plot(d[:, 0], d[:, 1], label=label, c=c)
        axs[0].plot(x_axis_kappaR, en, ls='--', c=c)
    QuadPath.plot_style(fig, axs[0], size=None)

    axs[1].set_title('Asymmetry', fontsize=15)
    for d, en, label, c in zip(ic_data_abar, energy_abar.T, labels_abar, COLOR_LIST):
        axs[1].plot(d[:, 0], d[:, 1], label=label, c=c)
        axs[1].plot(x_axis_abar, en, ls='--', c=c)
    QuadPath.plot_style(fig, axs[1], size=None)

    axs[2].set_title('Net charge', fontsize=15)
    for d, en, label, c in zip(ic_data_sigma0, energy_sigma0.T, labels_sigma0, COLOR_LIST):
        axs[2].plot(d[:, 0], d[:, 1], label=label, c=c)
        axs[2].plot(x_axis_sigma0, en, ls='--', c=c)
    # Align y-labels across panels. NOTE(review): this runs after panels 1-2 are
    # styled but before panel 3; order kept in case plot_style touches label positions.
    for ax in axs:
        ax.yaxis.set_label_coords(-0.09, 0.5)
    QuadPath.plot_style(fig, axs[2], size=None)
    for ax in axs:
        ax.get_legend().set_bbox_to_anchor((0.65, 1))
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def main():
    """Load the project config and regenerate the selected figure(s).

    The config path defaults to the author's checkout. Set the
    ``CHARGED_SHELLS_CONFIG`` environment variable to use another file.
    Uncomment the calls below to regenerate other figures.
    """
    import os

    config_path = Path(os.environ.get("CHARGED_SHELLS_CONFIG",
                                      "/home/andraz/ChargedShells/charged-shells/config.json"))
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale.
    with open(config_path, encoding="utf-8") as config_file:
        config_data = json.load(config_file)

    # model_comparison(config_data, save_data=False,
    #     save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_og_comparison.png')
    # )

    # kappaR_dependence(np.array([1, 3, 10]), 0.5,
    #                   # save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_kappaR_dep.png")
    #                   )
    #
    # abar_dependence(np.array([0.3, 0.4, 0.5]), 3,
    #                 save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_abar_dep.png")
    #                 )

    # sigma0_dependence(np.array([-0.0002, 0.00, 0.0002]), 3, 0.5,
    #                   save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_charge_dep_abar05_kappaR3.png")
    #                   )

    # distance_dependence(dist=np.array([2, 3, 4, 6, 10, 20]), kappaR=3, abar=0.5,
    #                     # save_as=Path(config_data["figures"]).joinpath('quadrupole_distance_dep.png')
    #                     )

    # IC_kappaR_dependence(config_data,
    #                      save_as=Path(config_data["figures"]).joinpath("Emanuele_data").joinpath('IC_quadrupole_kappaR_dep.png')
    #                      )
    #
    # IC_abar_dependence(config_data, save_as=Path(config_data["figures"]).joinpath("Emanuele_data").
    #                    joinpath('IC_quadrupole_abar_dep.png'))
    #
    # IC_sigma0_dependence(config_data, save_as=Path(config_data["figures"]).joinpath("Emanuele_data").
    #                      joinpath('IC_quadrupole_charge_dep_abar05_kappaR3.png'))

    # combined_kappaR_dependence(config_data, kappaR=[1, 3, 10], abar=0.5,
    #                      save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_kappaR_dep.png')
    #                      )

    # combined_sigma0_dependence(config_data,
    #                      save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_charge_dep.png')
    #                      )

    # combined_abar_dependence(config_data, kappaR=3, abar=[0.3, 0.4, 0.5],
    #                          save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_abar_dep.png')
    #                          )

    # combined_rescaled_distance_dependence(config_data)

    combined_distance_dependence(config_data,
                                 save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_dist_dep.png')
                                 )

    # combined_all(config_data,
    #              save_as=Path(config_data["figures"]).joinpath("final_figures").joinpath('quad_combined_dep.png')
    #              )


if __name__ == '__main__':
    # Script entry point: regenerate the figure(s) selected in main().
    main()