import numpy as np
from matplotlib import gridspec
from charged_shells.rotational_path import PairRotationalPath, PathEnergyPlot
from charged_shells import interactions, charge_distributions
from charged_shells.parameters import ModelParams
from config import *
from plot_settings import *

# Type alias used in the function annotations below.
Array = np.ndarray


# 100-point sampling grids (both endpoints included) for the three quarter-turn
# angle ranges from which the rotational path segments are assembled.
zero_to_pi_half = np.linspace(0.0, 0.5 * np.pi, num=100)
pi_half_to_pi = np.linspace(0.5 * np.pi, np.pi, num=100)
pi_to_three_halves_pi = np.linspace(np.pi, 1.5 * np.pi, num=100)

# NOTE(review): DipolePath is not referenced elsewhere in this module; every figure
# function below uses DipolePath3.
# Each add_euler(...) call appends one segment. Array-valued angles appear to be swept
# along the segment while scalar ones are held fixed (inferred from usage -- confirm in
# charged_shells.rotational_path); angles that are not passed keep add_euler's defaults.
DipolePath = PairRotationalPath()
DipolePath.set_default_x_axis(zero_to_pi_half)  # presumably the per-segment x-axis samples used for plotting
DipolePath.add_euler(beta2=pi_half_to_pi[::-1])  # beta2: pi -> pi/2
DipolePath.add_euler(beta2=zero_to_pi_half[::-1])  # beta2: pi/2 -> 0
DipolePath.add_euler(beta2=zero_to_pi_half, beta1=zero_to_pi_half)  # beta1 and beta2 together: 0 -> pi/2
DipolePath.add_euler(beta2=np.pi/2, beta1=np.pi/2, alpha2=zero_to_pi_half)  # alpha2: 0 -> pi/2 at beta1 = beta2 = pi/2
DipolePath.add_euler(beta2=np.pi/2, alpha2=np.pi/2, beta1=pi_half_to_pi)  # beta1: pi/2 -> pi
DipolePath.add_euler(beta2=np.pi/2, beta1=pi_half_to_pi[::-1], alpha1=np.pi)  # beta1: pi -> pi/2 with alpha1 = pi
DipolePath.add_euler(beta2=zero_to_pi_half[::-1], beta1=pi_half_to_pi, alpha1=np.pi)  # beta2: pi/2 -> 0, beta1: pi/2 -> pi
DipolePath.add_euler(beta2=zero_to_pi_half, beta1=pi_half_to_pi[::-1], alpha1=np.pi)  # beta2: 0 -> pi/2, beta1: pi -> pi/2
DipolePath.add_euler(beta2=pi_half_to_pi, beta1=zero_to_pi_half[::-1], alpha1=np.pi)  # beta2: pi/2 -> pi, beta1: pi/2 -> 0

# Variant of DipolePath that differs from the sixth segment on.
# NOTE(review): DipolePath2 is not referenced elsewhere in this module.
DipolePath2 = PairRotationalPath()
DipolePath2.set_default_x_axis(zero_to_pi_half)  # presumably the per-segment x-axis samples used for plotting
DipolePath2.add_euler(beta2=pi_half_to_pi[::-1])  # beta2: pi -> pi/2
DipolePath2.add_euler(beta2=zero_to_pi_half[::-1])  # beta2: pi/2 -> 0
DipolePath2.add_euler(beta2=zero_to_pi_half, beta1=zero_to_pi_half)  # beta1 and beta2 together: 0 -> pi/2
DipolePath2.add_euler(beta2=np.pi/2, beta1=np.pi/2, alpha2=zero_to_pi_half)  # alpha2: 0 -> pi/2 at beta1 = beta2 = pi/2
DipolePath2.add_euler(beta2=np.pi/2, alpha2=np.pi/2, beta1=pi_half_to_pi)  # beta1: pi/2 -> pi
DipolePath2.add_euler(beta2=zero_to_pi_half[::-1], beta1=pi_half_to_pi[::-1])  # beta2: pi/2 -> 0, beta1: pi -> pi/2
# NOTE(review): the next segment starts at (beta2=pi/2, beta1=pi), which is where the
# previous segment STARTED, not where it ended (beta2=0, beta1=pi/2). Confirm whether
# the previous segment is intentional or a leftover.
DipolePath2.add_euler(beta2=zero_to_pi_half[::-1], beta1=np.pi)  # beta2: pi/2 -> 0 at beta1 = pi
DipolePath2.add_euler(beta2=zero_to_pi_half, beta1=pi_half_to_pi[::-1], alpha1=np.pi)  # beta2: 0 -> pi/2, beta1: pi -> pi/2
DipolePath2.add_euler(beta2=pi_half_to_pi, beta1=zero_to_pi_half[::-1], alpha1=np.pi)  # beta2: pi/2 -> pi, beta1: pi/2 -> 0

# Rotational path used by every figure function in this module.
# start_name/end_name label the segment endpoints on the plot. EP/EE/PP/tEE presumably
# denote equator/pole contact configurations (t = twisted?) -- confirm with the figure captions.
DipolePath3 = PairRotationalPath()
DipolePath3.set_default_x_axis(zero_to_pi_half)  # presumably the per-segment x-axis samples used for plotting
DipolePath3.add_euler(beta2=np.pi/2, beta1=zero_to_pi_half, start_name="EP", end_name="EE")  # beta1: 0 -> pi/2 at beta2 = pi/2
DipolePath3.add_euler(beta2=pi_half_to_pi, beta1=pi_half_to_pi, end_name="PP")  # beta1 and beta2 together: pi/2 -> pi
DipolePath3.add_euler(beta2=pi_half_to_pi[::-1], beta1=np.pi, end_name="EP")  # beta2: pi -> pi/2 at beta1 = pi
DipolePath3.add_euler(beta2=pi_half_to_pi, beta1=pi_to_three_halves_pi, end_name="EP")  # beta2: pi/2 -> pi, beta1: pi -> 3pi/2
DipolePath3.add_euler(beta1=3 * np.pi/2, beta2=pi_half_to_pi[::-1], alpha2=np.pi/2, end_name="tEE")  # beta2: pi -> pi/2 with alpha2 = pi/2
DipolePath3.add_euler(beta1=3 * np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half[::-1], end_name="EE")  # alpha1: pi/2 -> 0
DipolePath3.add_euler(beta1=3 * np.pi/2, beta2=pi_half_to_pi, end_name="EP")  # beta2: pi/2 -> pi at beta1 = 3pi/2


def sections_plot(kappaR: float = 3, abar: float = 0.5, sigma_tilde=0.001, save_as=None):
    """Plot the CSp energy section by section along DipolePath3 for two identical shells.

    Args:
        kappaR: screening parameter passed to ``ModelParams`` (with R = 150).
        abar: ``abar`` argument of the mapped dipolar expansion.
        sigma_tilde: ``sigma_tilde`` argument of the mapped dipolar expansion.
        save_as: forwarded unchanged to ``PathEnergyPlot.plot_sections``.
    """
    model_params = ModelParams(R=150, kappaR=kappaR)
    shell = charge_distributions.create_mapped_dipolar_expansion(
        abar, model_params.kappaR, sigma_tilde, l_max=30)
    # The second particle is an independent copy of the first.
    energy_plot = PathEnergyPlot(shell, shell.clone(), DipolePath3, dist=2.,
                                 params=model_params, match_expansion_axis_to_params=None)
    energy_plot.plot_sections(save_as=save_as)


def kappaR_dependence(kappaR: Array, abar: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp energies along DipolePath3 for several screening values, one curve each.

    Args:
        kappaR: iterable/array of kappa*R values. The expansion is matched to it via
            ``match_expansion_axis_to_params=0``, presumably pairing expansion axis 0
            with these values.
        abar: ``abar`` argument of the mapped dipolar expansion.
        sigma_tilde: ``sigma_tilde`` argument of the mapped dipolar expansion.
        save_as: forwarded to ``PathEnergyPlot.plot``.
    """
    model_params = ModelParams(R=150, kappaR=kappaR)
    shell = charge_distributions.create_mapped_dipolar_expansion(
        abar, model_params.kappaR, sigma_tilde, l_max=30)
    energy_plot = PathEnergyPlot(shell, shell.clone(), DipolePath3, dist=2.,
                                 params=model_params, match_expansion_axis_to_params=0)

    curve_labels = [rf'$\kappa R$={value}' for value in kappaR]
    # Optional plot() argument, currently disabled: norm_euler_angles={'beta2': np.pi}
    energy_plot.plot(labels=curve_labels, save_as=save_as)


def abar_dependence(abar: Array, kappaR: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp energies along DipolePath3 for several ``abar`` values, one curve each.

    Args:
        abar: iterable/array of ``abar`` values passed to the mapped dipolar expansion.
        kappaR: single screening value passed to ``ModelParams`` (with R = 150).
        sigma_tilde: ``sigma_tilde`` argument of the mapped dipolar expansion.
        save_as: forwarded to ``PathEnergyPlot.plot``.
    """
    model_params = ModelParams(R=150, kappaR=kappaR)
    shell = charge_distributions.create_mapped_dipolar_expansion(
        abar, model_params.kappaR, sigma_tilde, l_max=30)
    energy_plot = PathEnergyPlot(shell, shell.clone(), DipolePath3, dist=2.,
                                 params=model_params, match_expansion_axis_to_params=None)

    curve_labels = [rf'$\bar a$={value}' for value in abar]
    # Optional plot() argument, currently disabled: norm_euler_angles={'beta2': np.pi}
    energy_plot.plot(labels=curve_labels, save_as=save_as)


def sigma0_dependence(sigma0: Array, kappaR: float, abar: float, sigma_tilde=0.001, save_as=None):
    """Plot CSp energies along DipolePath3 for several ``sigma0`` values, one curve each.

    Curves are labelled by eta = sigma0 / sigma_tilde.

    Args:
        sigma0: iterable/array of ``sigma0`` values passed to the mapped dipolar expansion.
        kappaR: single screening value passed to ``ModelParams`` (with R = 150).
        abar: ``abar`` argument of the mapped dipolar expansion.
        sigma_tilde: ``sigma_tilde`` argument of the expansion; also the eta normalization.
        save_as: forwarded to ``PathEnergyPlot.plot``.
    """
    model_params = ModelParams(R=150, kappaR=kappaR)
    shell = charge_distributions.create_mapped_dipolar_expansion(
        abar, model_params.kappaR, sigma_tilde, l_max=30, sigma0=sigma0)
    energy_plot = PathEnergyPlot(shell, shell.clone(), DipolePath3, dist=2.,
                                 params=model_params, match_expansion_axis_to_params=None)

    curve_labels = [rf'$\eta={value / sigma_tilde}$' for value in sigma0]
    # Optional plot() argument, currently disabled: norm_euler_angles={'beta2': np.pi}
    energy_plot.plot(labels=curve_labels, save_as=save_as)


def model_comparison(save_as=None, save_data=False):
    """Overlay the ICi reference pathway and the CSp pathway energy for one setup.

    Setup: kappaR = 3, abar = 0.5, sigma_tilde = 0.001, R = 150 and dist = 2 along
    DipolePath3. ICi is drawn in blue and CSp in orange. Solid lines show the pair of
    identical particles. Dashed lines show the pair whose second particle went through
    ``inverse_sign()`` (for ICi, the second half of the stored pathway array).

    Args:
        save_as: optional path; when given, the figure is written there at 300 dpi.
        save_data: currently ignored; the data-export code is disabled.
    """
    kappaR = 3
    a_bar = 0.5
    sigma_tilde = 0.001
    params = ModelParams(R=150, kappaR=kappaR)

    shell = charge_distributions.create_mapped_dipolar_expansion(a_bar, params.kappaR, sigma_tilde, l_max=30)
    partner = shell.clone()
    inverted_partner = shell.clone().inverse_sign()

    # CSp energies for the identical pair and for the sign-inverted partner.
    # (Comparisons with Gaussian / cap patch models are currently disabled.)
    csp_plot = PathEnergyPlot(shell, partner, DipolePath3, dist=2., params=params)
    csp_energy = csp_plot.evaluate_path()
    x_axis = csp_plot.rot_path.stack_x_axes()
    csp_energy_inv = PathEnergyPlot(shell, inverted_partner, DipolePath3, dist=2., params=params).evaluate_path()

    # ICi reference data: the first half of the array pairs with the identical-particle
    # curve, the second half with the inverted-partner curve.
    pathway_file = ICI_DATA_PATH.joinpath("FIG_5_JANUS").joinpath("FIG_5_JANUS_A").joinpath("pathway.npz")
    ici = np.load(pathway_file)['arr_0']
    split_at = int(len(ici) / 2)
    ici_equal = ici[:split_at]
    ici_inv = ici[split_at:]

    fig, ax = plt.subplots(figsize=0.5 * np.array([8.25, 4.125]))
    ax.plot(ici_equal[:, 0], ici_equal[:, 1], label='ICi', c='tab:blue')
    ax.plot(ici_inv[:, 0], ici_inv[:, 1], ls='--', c='tab:blue')
    ax.plot(x_axis, np.squeeze(csp_energy), label='CSp', c='tab:orange')
    ax.plot(x_axis, np.squeeze(csp_energy_inv), ls='--', c='tab:orange')
    csp_plot.plot_style(fig, ax)

    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_kappaR_dependence(kappaR: list[int], abar: float, sigma_tilde=0.001, save_as=None):
    """Compare ICi and CSp pathway energies along DipolePath3 for several kappaR values.

    For each kappaR, the ICi reference pathway is drawn solid and the CSp result dashed, in
    a shared colour. Both models are shown for the identical pair and for the pair whose
    second particle went through ``inverse_sign()``.

    Args:
        kappaR: kappa*R values to compare. One ICi data set ``EMME_{kR}.`` is loaded per
            value, and the same values are used for the CSp parameters and the legend
            labels, so all three stay in the same order.
        abar: ``abar`` of the mapped dipolar expansion; also selects the ICi directory
            ``ECC_{abar/2}``.
        sigma_tilde: ``sigma_tilde`` argument of the mapped dipolar expansion.
        save_as: optional output path; the figure is saved there at 300 dpi.
    """
    em_data_path = (ICI_DATA_PATH.joinpath("FIG_5_JANUS").joinpath("FIG_5_JANUS_C")
                    .joinpath("FIX_A").joinpath(f"ECC_{np.round(abar/2, 4)}"))

    ic_data = []
    ic_data_inv = []
    for kR in kappaR:
        # np.load returns an NpzFile holding an open file; close it once arr_0 is read.
        with np.load(em_data_path.joinpath(f"EMME_{kR}.").joinpath("pathway.npz")) as npz:
            em_data = npz['arr_0']
        # First half: identical pair; second half: pairs with the inverted-partner curve.
        half = len(em_data) // 2
        ic_data.append(em_data[:half])
        ic_data_inv.append(em_data[half:])

    params = ModelParams(R=150, kappaR=np.asarray(kappaR))

    ex1 = charge_distributions.create_mapped_dipolar_expansion(abar, params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()
    ex3 = ex1.clone().inverse_sign()

    path_plot = PathEnergyPlot(ex1, ex2, DipolePath3, dist=2., params=params, match_expansion_axis_to_params=0)
    energy = path_plot.evaluate_path()
    x_axis = path_plot.rot_path.stack_x_axes()

    path_plot_inv = PathEnergyPlot(ex1, ex3, DipolePath3, dist=2., params=params, match_expansion_axis_to_params=0)
    energy_inv = path_plot_inv.evaluate_path()

    # Labels must follow the requested kappaR values so they match the loaded curves.
    labels = [rf'$\kappa R = {kR}$' for kR in kappaR]

    fig, ax = plt.subplots()
    for d, d_inv, en, en_inv, label, c in zip(ic_data, ic_data_inv, energy.T, energy_inv.T, labels, COLORS):
        ax.plot(d[:, 0], d[:, 1], label=label, c=c)
        ax.plot(d_inv[:, 0], d_inv[:, 1], c=c)
        ax.plot(x_axis, en, ls='--', c=c)
        ax.plot(x_axis, en_inv, ls='--', c=c)
    DipolePath3.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()
    
    
def combined_abar_dependence(kappaR: int, abar: list[float], sigma_tilde=0.001, save_as=None):
    """Compare ICi and CSp pathway energies along DipolePath3 for several ``abar`` values.

    For each abar, the ICi reference pathway is drawn solid and the CSp result dashed, in
    a shared colour. Both models are shown for the identical pair and for the pair whose
    second particle went through ``inverse_sign()``.

    Args:
        kappaR: screening value kappa*R; also selects the ICi directory ``EMME_{kappaR}.``.
        abar: ``abar`` values to compare. They drive the ICi data (``ECC_{abar/2}``),
            the CSp expansions and the legend labels, so all three stay in the same order.
        sigma_tilde: ``sigma_tilde`` argument of the mapped dipolar expansion.
        save_as: optional output path; the figure is saved there at 300 dpi.
    """
    em_data_path = ICI_DATA_PATH.joinpath("FIG_5_JANUS").joinpath("FIG_5_JANUS_B").joinpath("FIX_M")

    ic_data = []
    ic_data_inv = []
    for ab in abar:
        pathway_file = em_data_path.joinpath(f"ECC_{np.round(ab/2, 4)}").joinpath(f"EMME_{kappaR}.").joinpath("pathway.npz")
        # np.load returns an NpzFile holding an open file; close it once arr_0 is read.
        with np.load(pathway_file) as npz:
            em_data = npz['arr_0']
        # First half: identical pair; second half: pairs with the inverted-partner curve.
        half = len(em_data) // 2
        ic_data.append(em_data[:half])
        ic_data_inv.append(em_data[half:])

    params = ModelParams(R=150, kappaR=kappaR)

    ex1 = charge_distributions.create_mapped_dipolar_expansion(np.asarray(abar), params.kappaR, sigma_tilde, l_max=30)
    ex2 = ex1.clone()
    ex3 = ex1.clone().inverse_sign()

    path_plot = PathEnergyPlot(ex1, ex2, DipolePath3, dist=2., params=params, match_expansion_axis_to_params=None)
    energy = path_plot.evaluate_path()
    x_axis = path_plot.rot_path.stack_x_axes()

    path_plot_inv = PathEnergyPlot(ex1, ex3, DipolePath3, dist=2., params=params, match_expansion_axis_to_params=None)
    energy_inv = path_plot_inv.evaluate_path()

    # Labels must follow the requested abar values so they match the loaded curves.
    labels = [rf'$\bar a={a}$' for a in abar]

    fig, ax = plt.subplots()
    for d, d_inv, en, en_inv, label, c in zip(ic_data, ic_data_inv, energy.T, energy_inv.T, labels, COLORS):
        ax.plot(d[:, 0], d[:, 1], label=label, c=c)
        ax.plot(d_inv[:, 0], d_inv[:, 1], c=c)
        ax.plot(x_axis, en, ls='--', c=c)
        ax.plot(x_axis, en_inv, ls='--', c=c)
    DipolePath3.plot_style(fig, ax)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def combined_sigma0_dependence(kappaR=3., abar=0.5, sigma0=(-0.0002, 0.00, 0.0002), sigma_tilde=0.001, save_as=None):
    """Compare ICi and CSp pathway energies for under-charged, neutral and over-charged shells.

    ICi curves are drawn solid and CSp curves dashed, one colour per charge state. Both
    are shown for the identical pair and for a partner passed through
    ``inverse_sign(exclude_00=True)``.

    NOTE(review): the charged ICi data sets (``ZC_-56`` / ``ZC_56``) are fixed files that
    do not depend on ``kappaR``, ``abar`` or ``sigma0``. Only the neutral set follows
    ``kappaR`` / ``abar``. ``sigma0`` is paired positionally with (under-charged, neutral,
    over-charged), so it is expected in (negative, zero, positive) order.

    Args:
        kappaR: screening value kappa*R (its int value selects the neutral ICi data).
        abar: ``abar`` of the mapped dipolar expansion (selects ``ECC_{abar/2}``).
        sigma0: ``sigma0`` values for the CSp expansions; also used for the eta labels.
        sigma_tilde: ``sigma_tilde`` of the expansion and the eta normalization.
        save_as: optional output path; the figure is saved there at 300 dpi.
    """
    def first_and_second_half(arr):
        cut = int(len(arr) / 2)
        return arr[:cut], arr[cut:]

    janus_dir = ICI_DATA_PATH.joinpath("FIG_5_JANUS")
    charge_dir = janus_dir.joinpath("FIG_5_JANUS_D").joinpath("CHANGE_ZC")
    neutral_dir = (janus_dir.joinpath("FIG_5_JANUS_B").joinpath("FIX_M")
                   .joinpath(f"ECC_{np.round(abar/2, 4)}").joinpath(f"EMME_{int(kappaR)}."))

    under = np.load(charge_dir.joinpath("ZC_-56").joinpath("pathway.npz"))['arr_0']
    over = np.load(charge_dir.joinpath("ZC_56").joinpath("pathway.npz"))['arr_0']
    neutral = np.load(neutral_dir.joinpath("pathway.npz"))['arr_0']

    # Order: under-charged, neutral, over-charged -> matches the default sigma0 ordering.
    halves = [first_and_second_half(arr) for arr in (under, neutral, over)]
    ic_equal = [pair[0] for pair in halves]
    ic_inverted = [pair[1] for pair in halves]

    params = ModelParams(R=150, kappaR=kappaR)
    shell = charge_distributions.create_mapped_dipolar_expansion(abar, params.kappaR, sigma_tilde, l_max=30, sigma0=np.asarray(sigma0))
    partner = shell.clone()
    inverted_partner = shell.clone().inverse_sign(exclude_00=True)

    equal_plot = PathEnergyPlot(shell, partner, DipolePath3, dist=2., params=params, match_expansion_axis_to_params=None)
    energy = equal_plot.evaluate_path()
    x_axis = equal_plot.rot_path.stack_x_axes()
    energy_inv = PathEnergyPlot(shell, inverted_partner, DipolePath3, dist=2., params=params,
                                match_expansion_axis_to_params=None).evaluate_path()

    labels = [rf'$\eta={s0/sigma_tilde}$' for s0 in sigma0]

    fig, ax = plt.subplots()
    curves = zip(ic_equal, ic_inverted, energy.T, energy_inv.T, labels, COLORS)
    for ic, ic_inv, en, en_inv, label, colour in curves:
        ax.plot(ic[:, 0], ic[:, 1], label=label, c=colour)
        ax.plot(ic_inv[:, 0], ic_inv[:, 1], c=colour)
        ax.plot(x_axis, en, ls='--', c=colour)
        ax.plot(x_axis, en_inv, ls='--', c=colour)
    DipolePath3.plot_style(fig, ax)

    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


def _load_ici_pathway(directory) -> tuple[Array, Array]:
    """Load ``directory/pathway.npz`` (key ``arr_0``) and split it into its two halves.

    The first half is plotted against the identical-pair CSp curve and the second half
    against the inverted-partner CSp curve (layout inferred from how the figures use it).
    """
    # np.load returns an NpzFile holding an open file; close it once arr_0 is read.
    with np.load(directory.joinpath("pathway.npz")) as npz:
        data = npz['arr_0']
    half = len(data) // 2
    return data[:half], data[half:]


def _csp_path_energies(ex1, params, match_axis):
    """Evaluate CSp energies along DipolePath3 for ``ex1`` with a copy and with an inverted copy.

    Returns:
        ``(energy, energy_inv, x_axis)``:
        - energies for ``ex1`` paired with its clone;
        - energies for ``ex1`` paired with a clone passed through
          ``inverse_sign(exclude_00=True)``;
        - the stacked path x-axis.
    """
    ex2 = ex1.clone()
    ex3 = ex1.clone().inverse_sign(exclude_00=True)
    path_plot = PathEnergyPlot(ex1, ex2, DipolePath3, dist=2., params=params,
                               match_expansion_axis_to_params=match_axis)
    energy = path_plot.evaluate_path()
    path_plot_inv = PathEnergyPlot(ex1, ex3, DipolePath3, dist=2., params=params,
                                   match_expansion_axis_to_params=match_axis)
    energy_inv = path_plot_inv.evaluate_path()
    return energy, energy_inv, path_plot.rot_path.stack_x_axes()


def _plot_ici_csp_panel(fig, ax, title, x_axis, curves, legend_anchor):
    """Draw one panel: ICi curves solid, CSp curves dashed, one colour per parameter value.

    ``curves`` yields ``((ic, ic_inv), energy, energy_inv, label, colour)`` tuples. ``ic``
    and ``ic_inv`` are 2-D arrays whose columns 0 and 1 are plotted as x and energy.
    """
    ax.set_title(title, fontsize=11, y=0.98)
    for (ic, ic_inv), en, en_inv, label, c in curves:
        ax.plot(ic[:, 0], ic[:, 1], label=label, c=c)
        ax.plot(x_axis, en, ls='--', c=c)
        ax.plot(ic_inv[:, 0], ic_inv[:, 1], c=c)
        ax.plot(x_axis, en_inv, ls='--', c=c)
    # Style once, after every curve exists, so the legend lists all labelled lines and
    # anything plot_style draws is added a single time rather than once per curve.
    DipolePath3.plot_style(fig, ax, size=None)
    legend = ax.get_legend()
    if legend is not None:  # no labelled curves -> no legend to position
        legend.set_bbox_to_anchor(legend_anchor)


def combined_all(save_as=None, include_abar: bool = False):
    """Plot ICi vs CSp pathway energies for several parameter sweeps in one stacked figure.

    Panels, top to bottom:
    - Screening: kappaR in ``kappaR_list`` at fixed ``abar``.
    - Eccentricity (only if ``include_abar``): abar in ``abar_list`` at fixed ``kappaR``.
    - Net charge: sigma0 in ``sigma0_list`` at fixed ``kappaR`` and ``abar``.

    In every panel, ICi reference data is solid and CSp results are dashed. Each is shown
    for the identical pair and for the sign-inverted partner.

    Args:
        save_as: optional output path; the figure is saved there at 300 dpi.
        include_abar: also compute and draw the eccentricity panel. Off by default, which
            gives the two-panel figure and skips loading that panel's ICi data and
            evaluating its two CSp paths.
    """
    sigma_tilde = 0.00099
    kappaR_list = [1, 3, 10]
    abar_list = [0.5, 0.4, 0.3]
    sigma0_list = [-0.000198, 0.00, 0.000198]
    kappaR = 3
    abar = 0.5

    janus_dir = ICI_DATA_PATH.joinpath("FIG_5_JANUS")
    fix_m_dir = janus_dir.joinpath("FIG_5_JANUS_B").joinpath("FIX_M")

    # --- Screening: kappaR varies, abar fixed.
    kappaR_dir = (janus_dir.joinpath("FIG_5_JANUS_C")
                  .joinpath("FIX_A").joinpath(f"ECC_{np.round(abar / 2, 4)}"))
    ic_kappaR = [_load_ici_pathway(kappaR_dir.joinpath(f"EMME_{kR}.")) for kR in kappaR_list]
    params_kappaR = ModelParams(R=150, kappaR=np.asarray(kappaR_list))
    ex_kappaR = charge_distributions.create_mapped_dipolar_expansion(
        abar, params_kappaR.kappaR, sigma_tilde, l_max=30)
    energy_kappaR, energy_kappaR_inv, x_axis_kappaR = _csp_path_energies(ex_kappaR, params_kappaR, match_axis=0)
    labels_kappaR = [rf'$\kappa R={kR}$' for kR in kappaR_list]

    # Shared by the eccentricity and net-charge panels.
    params = ModelParams(R=150, kappaR=kappaR)

    # --- Eccentricity (optional): abar varies, kappaR fixed.
    if include_abar:
        ic_abar = [_load_ici_pathway(fix_m_dir.joinpath(f"ECC_{np.round(ab / 2, 4)}").joinpath(f"EMME_{kappaR}."))
                   for ab in abar_list]
        ex_abar = charge_distributions.create_mapped_dipolar_expansion(
            np.asarray(abar_list), params.kappaR, sigma_tilde, l_max=30)
        energy_abar, energy_abar_inv, x_axis_abar = _csp_path_energies(ex_abar, params, match_axis=None)
        labels_abar = [rf'$\bar a={a}$' for a in abar_list]

    # --- Net charge: sigma0 varies, kappaR and abar fixed.
    charge_dir = janus_dir.joinpath("FIG_5_JANUS_D").joinpath("CHANGE_ZC")
    # Order must match sigma0_list: under-charged, neutral, over-charged.
    ic_sigma0 = [
        _load_ici_pathway(charge_dir.joinpath("ZC_-56")),
        _load_ici_pathway(fix_m_dir.joinpath(f"ECC_{np.round(abar / 2, 4)}").joinpath(f"EMME_{int(kappaR)}.")),
        _load_ici_pathway(charge_dir.joinpath("ZC_56")),
    ]
    ex_sigma0 = charge_distributions.create_mapped_dipolar_expansion(
        abar, params.kappaR, sigma_tilde, l_max=30, sigma0=np.asarray(sigma0_list))
    energy_sigma0, energy_sigma0_inv, x_axis_sigma0 = _csp_path_energies(ex_sigma0, params, match_axis=None)
    labels_sigma0 = [rf'$\eta={s0/sigma_tilde:.1f}$' for s0 in sigma0_list]

    # --- Figure layout.
    if include_abar:
        n_panels = 3
        fig = plt.figure(figsize=(6, 7.8))
        gs = gridspec.GridSpec(n_panels, 1, figure=fig)
        gs.update(left=0.12, right=0.975, top=0.96, bottom=0.04, hspace=0.3)
    else:
        n_panels = 2
        fig = plt.figure(figsize=(4, 3.6))
        gs = gridspec.GridSpec(n_panels, 1, figure=fig)
        gs.update(left=0.12, right=0.975, top=0.94, bottom=0.06, hspace=0.3)
    axs = [fig.add_subplot(gs[row, 0]) for row in range(n_panels)]

    _plot_ici_csp_panel(fig, axs[0], 'Screening', x_axis_kappaR,
                        zip(ic_kappaR, energy_kappaR.T, energy_kappaR_inv.T, labels_kappaR, COLOR_LIST),
                        legend_anchor=(0.65, 1.03))
    if include_abar:
        _plot_ici_csp_panel(fig, axs[1], 'Eccentricity', x_axis_abar,
                            zip(ic_abar, energy_abar.T, energy_abar_inv.T, labels_abar, COLOR_LIST),
                            legend_anchor=(0.65, 1.02))
    # Over-charged is drawn (and listed) first; the colours are reversed to match.
    _plot_ici_csp_panel(fig, axs[-1], 'Net charge', x_axis_sigma0,
                        reversed(list(zip(ic_sigma0, energy_sigma0.T, energy_sigma0_inv.T,
                                          labels_sigma0, COLOR_LIST[:3][::-1]))),
                        legend_anchor=(0.65, 1.02))

    for ax in axs:
        ax.yaxis.set_label_coords(-0.08, 0.5)
    if save_as is not None:
        plt.savefig(save_as, dpi=300)
    plt.show()


if __name__ == '__main__':
    # Only the combined figure is produced by default. Other figures can be generated by
    # calling one of these instead (save paths as used previously):
    #   sections_plot(save_as=Path("/home/andraz/ChargedShells/Figures/dipole_test_path.png"))
    #   kappaR_dependence(np.array([3, 5]), 0.5,
    #                     save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_kappaR_dep.png"))
    #   abar_dependence(np.array([0.3, 0.4, 0.5]), 3,
    #                   save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_abar_dep.png"))
    #   sigma0_dependence(np.array([-0.0002, 0.00, 0.0002]), 3, 0.5,
    #                     save_as=Path("/home/andraz/ChargedShells/Figures/quadrupole_charge_dep_abar05_kappaR3.png"))
    #   model_comparison(save_as=FIGURES_PATH.joinpath("ICi_data").joinpath('IC_CS_janus_path.pdf'))
    #   combined_kappaR_dependence(kappaR=[1, 3, 10], abar=0.5,
    #                              save_as=FIGURES_PATH.joinpath("final_figures").joinpath('janus_kappaR_dep.png'))
    #   combined_abar_dependence(kappaR=3, abar=[0.3, 0.4, 0.5],
    #                            save_as=FIGURES_PATH.joinpath("final_figures").joinpath('janus_abar_dep.png'))
    #   combined_sigma0_dependence(save_as=FIGURES_PATH.joinpath("final_figures").joinpath('janus_charge_dep.png'))
    # To save the combined figure, pass
    #   save_as=FIGURES_PATH.joinpath("final_figures").joinpath('janus_combined_dep.png')
    combined_all(save_as=None)