import expansion
import functions as fn
import numpy as np
import parameters
import matplotlib.pyplot as plt


# Short aliases used in the type annotations below: numpy's array type and the
# project classes for model parameters and the spherical-harmonic expansion.
Array, ModelParams, Expansion = np.ndarray, parameters.ModelParams, expansion.Expansion


def charged_shell_potential(theta: Array | float,
                            phi: Array | float,
                            dist: float,
                            ex: Expansion,
                            params: ModelParams) -> Array:
    """
    Electrostatic potential around a charged shell with patches given by expansion over spherical harmonics.

    Angles may be scalars (int or float), sequences, or arrays of any shape. A scalar angle is broadcast
    to the shape of the other one; two non-scalar angles must have identical shapes. Both are converted
    to float64 first, so integer-valued inputs are never used as a dtype that would truncate the other angle.

    :param theta: angle(s) passed as the third argument of ``fn.sph_harm`` (originally documented as the
        azimuthal angle; NOTE(review): the ``__main__`` demo sweeps it over [0, pi], which looks like a
        polar range -- confirm against ``fn.sph_harm``)
    :param phi: angle(s) passed as the fourth argument of ``fn.sph_harm`` (originally documented as the
        polar angle; see the note on ``theta``)
    :param dist: distance at which the potential is evaluated, documented as being in units of radius R
        (see the NOTE(review) in the body about how it is scaled)
    :param ex: Expansion object detailing patch distribution
    :param params: ModelParams object specifying parameter values for the model
    :return: real-valued potential with the same shape as the (broadcast) angle arrays
    :raises ValueError: if theta and phi are both non-scalar and their shapes differ
    """
    theta = np.asarray(theta, dtype=float)
    phi = np.asarray(phi, dtype=float)

    # Broadcast a scalar angle to the other angle's shape. Both are float64 at this point, so
    # np.full_like (which copies the prototype's dtype) cannot truncate a fractional value.
    if theta.ndim == 0:
        theta = np.full_like(phi, theta)
    if phi.ndim == 0:
        phi = np.full_like(theta, phi)

    if theta.shape != phi.shape:
        raise ValueError('theta and phi arrays should have the same shape.')

    out_shape = theta.shape
    # Flatten so evaluation points always sit on axis 0, against the coefficient axis 1 below;
    # this also makes 0-d (both-scalar) and multi-dimensional inputs work.
    theta_flat = theta.reshape(-1)
    phi_flat = phi.reshape(-1)

    # (l, m) index for each expansion term; presumably one entry per element of ex.coeffs.
    l_per_coeff, m_per_coeff = ex.lm_arrays

    # Radial factors depend only on l: they are evaluated on ex.l_array and then expanded to one value
    # per (l, m) term via ex.repeat_over_m.
    # NOTE(review): dist is documented in units of R, but the Bessel argument is kappa * dist rather than
    # kappaR * dist; presumably kappaR == kappa * R, so these agree only when R == 1. Confirm the intent.
    l_factors = (fn.coefficient_Cpm(ex.l_array, params.kappaR) * fn.sph_bessel_k(ex.l_array, params.kappa * dist)
                 / fn.sph_bessel_k(ex.l_array, params.kappaR))

    # Sum over expansion terms (axis 1) for every evaluation point (axis 0) and keep the real part.
    terms = (ex.repeat_over_m(l_factors)[None, :] * ex.coeffs
             * fn.sph_harm(l_per_coeff[None, :], m_per_coeff[None, :], theta_flat[:, None], phi_flat[:, None]))
    potential = (1 / (params.kappa * params.epsilon * params.epsilon0)
                 * np.real(np.sum(terms, axis=1)))

    return potential.reshape(out_shape)


if __name__ == '__main__':
    # Demo: evaluate the potential at dist = 1 along a sweep of theta with phi held fixed.

    params = ModelParams(1, 3, 1, 1)
    ex = expansion.MappedExpansion(1, params.kappaR, 0.001, max_l=20)

    theta = np.linspace(0, np.pi, 1000)
    phi = 0.
    dist = 1.

    potential = charged_shell_potential(theta, phi, dist, ex, params)

    # Plot against theta itself so the x axis shows the angle rather than the sample index.
    plt.plot(theta, potential)
    plt.xlabel(r'$\theta$')
    plt.ylabel('potential')
    plt.show()