import unittest
import numpy as np
from charged_shells import expansion, interactions, parameters, units_and_constants, mapping, potentials
from functools import partial


# Vacuum permittivity taken from the project's constants table; used in the analytic
# interaction energy in IsotropicTest.test_interaction.
# NOTE(review): presumably expressed in the same unit system as the model energies — confirm
# against charged_shells.units_and_constants.
EPSILON_0 = units_and_constants.CONSTANTS.epsilon0


class IsotropicTest(unittest.TestCase):
    """Cross-check the charged-shell model for a uniformly charged particle.

    The expansion is built as ``MappedExpansionQuad(0, kappaR, 0, ...)`` with surface
    charge density ``sigma0`` carrying the full ``charge``; the zero arguments are
    presumably the anisotropic (quadrupolar) contributions — TODO confirm against
    ``expansion.MappedExpansionQuad``. Every check is performed for each value of
    ``kappaR`` at once via the ``mapping`` helpers.
    """

    def setUp(self):
        self.charge = 560
        self.kappaR = np.array([0.1, 1, 3, 5, 10, 20])
        self.radius = 150
        # Center-to-center distance: twice the radius, i.e. two equal shells in contact.
        self.dist = 2 * self.radius
        # The model functions are passed the distance divided by the radius; derive it
        # once so every comparison is evaluated at the same separation.
        self.dist_scaled = self.dist / self.radius
        # Expansion cutoff shared by all computations so compared results are truncated
        # identically.
        self.l_max = 10
        # Uniform surface charge density whose integral over the sphere equals `charge`.
        self.sigma0 = self.charge / (4 * np.pi * self.radius ** 2)
        self.ex1 = expansion.MappedExpansionQuad(0, self.kappaR, 0, l_max=self.l_max, sigma0=self.sigma0)
        self.ex2 = self.ex1.clone()
        self.params = parameters.ModelParams(kappaR=self.kappaR, R=self.radius)

    def test_potential(self):
        """Charged-shell and inverse-patchy-particle potentials agree for every kappaR."""
        theta = np.linspace(0, np.pi, 100)

        def cs_potential_fn(ex, params):
            return potentials.charged_shell_potential(theta, 0., self.dist_scaled, ex, params)

        mapped_cs_potential_fn = mapping.parameter_map_single_expansion(cs_potential_fn, 0)
        cs_potential = mapped_cs_potential_fn(self.ex1, self.params)

        # Reference potential, evaluated one unravelled parameter set at a time and
        # stacked in the same order as the mapped charged-shell result.
        ic_potential = np.array([
            potentials.inverse_patchy_particle_potential(theta, self.dist_scaled, 0, self.sigma0,
                                                         (0., 0.), p, lmax=self.l_max)
            for p in mapping.unravel_params(self.params)
        ])

        np.testing.assert_almost_equal(cs_potential, ic_potential)

    def test_interaction(self):
        """Interaction energy of two identical shells matches the closed-form expression."""
        kappaR = self.params.kappaR
        # U = q^2 / (4 pi eps0 eps) * (e^{kappaR} / (1 + kappaR))^2 * e^{-kappa r} / r
        # NOTE(review): assumes this expression is already in the units produced by
        # charged_shell_energy(units='eV') — confirm against units_and_constants.
        int_analytic = (self.charge ** 2 / (4 * np.pi * EPSILON_0 * self.params.epsilon) *
                        (np.exp(kappaR) / (1 + kappaR)) ** 2 *
                        np.exp(-self.params.kappa * self.dist) / self.dist)

        energy_fn = mapping.parameter_map_two_expansions(
            partial(interactions.charged_shell_energy, dist=self.dist_scaled, units='eV'),
            0)
        int_comp = energy_fn(self.ex1, self.ex2, self.params)

        # NOTE(review): assert_almost_equal checks an absolute difference (7 decimals);
        # if the energies span many orders of magnitude a relative check may be more telling.
        np.testing.assert_almost_equal(int_comp, int_analytic)