from charged_shells import expansion, interactions
from charged_shells.parameters import ModelParams
import time
import numpy as np
import matplotlib.pyplot as plt


def v22_distance_test():
    """Plot charged-shell interaction energy versus distance for rotated pairs.

    Both shells start as ``Expansion24(1, 0, 0)``. ``ex0`` is rotated with the
    beta angles ``[0, 0, pi/2]`` and ``ex1`` with ``[0, pi/2, pi/2]``. The
    energy is evaluated on 100 distances in ``[2, 3.2]``, the energy at
    ``dist=2.`` is printed, and the resulting curves are plotted.
    """
    params = ModelParams(R=10, kappaR=3.29)
    ex0 = expansion.Expansion24(1, 0, 0)
    ex1 = ex0.clone()

    # Paired beta angles for the two shells. The i-th entries of each array
    # are meant to form the i-th configuration.
    # NOTE(review): assumes rotate_euler with an array of beta angles yields
    # one configuration per entry -- confirm against expansion.rotate_euler.
    beta0 = np.array([0, 0, np.pi / 2])
    beta1 = np.array([0, np.pi / 2, np.pi / 2])
    ex0.rotate_euler(0, beta0, 0)
    ex1.rotate_euler(0, beta1, 0)

    dist = np.linspace(2, 3.2, 100)
    # Size the buffer from the number of configurations rather than a
    # hard-coded 3, so changing the angle arrays cannot desync its width.
    energy_array = np.zeros((dist.shape[0], len(beta0)))
    for i, d in enumerate(dist):
        energy_array[i, ...] = interactions.charged_shell_energy(ex0, ex1, params, d)

    print(interactions.charged_shell_energy(ex0, ex1, params, dist=2.))

    plt.plot(dist, energy_array)
    plt.show()


def quadrupole_variation_test():
    """Plot interaction energy versus distance for several quadrupole strengths.

    The first shell is an ``Expansion24`` built from five ``sigma2`` values
    with ``sigma0=0.1``. The second shell is a clone of it rotated by
    ``pi/2`` in the second Euler angle. The energy is evaluated on 100
    distances in ``[2, 3.2]``; the buffer has one column per ``sigma2``
    value, and all columns are plotted.
    """
    model = ModelParams(R=10, kappaR=3.29)
    sigma2_values = np.array([0.45, 0.5, 0.55, 0.6, 0.65])

    shell_a = expansion.Expansion24(sigma2_values, 0, sigma0=0.1)
    shell_b = shell_a.clone()
    shell_b.rotate_euler(0, np.pi / 2, 0)

    distances = np.linspace(2, 3.2, 100)
    energies = np.zeros((distances.shape[0], len(sigma2_values)))
    for row, distance in enumerate(distances):
        energies[row, ...] = interactions.charged_shell_energy(shell_a, shell_b, model, distance)

    plt.plot(distances, energies)
    plt.show()


def timing():
    """Time one ``charged_shell_energy`` evaluation and print the result.

    Two identical ``MappedExpansionQuad`` shells (``l_max=20``) are built for
    ``R=150`` and ``kappaR=3`` and evaluated at distance ``2.``. Only the
    energy call itself is timed, using ``time.perf_counter``.
    """
    params = ModelParams(R=150, kappaR=3)
    first = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, l_max=20)
    second = first.clone()

    start = time.perf_counter()
    energy = interactions.charged_shell_energy(first, second, params, 2.)
    elapsed = time.perf_counter() - start

    print('energy: ', energy)
    print('time: ', elapsed)


if __name__ == '__main__':
    # Only the quadrupole sweep runs by default. The other experiments are
    # kept as commented-out toggles:
    #   v22_distance_test()
    #   timing()
    quadrupole_variation_test()