|  | @@ -19,24 +19,35 @@ def potential_patch_size(ex: Expansion, params: ModelParams,
 | 
	
		
			
				|  |  |                           phi: float = 0, theta0: Array | float = 0, theta1: Array | float = np.pi / 2,
 | 
	
		
			
				|  |  |                           match_expansion_axis_to_params: int = None):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    meatp = match_expansion_axis_to_params
 | 
	
		
			
				|  |  | +    # this is more complicate to map over leading axes of the expansion as potential also depends on model parameters,
 | 
	
		
			
				|  |  | +    # with some, such as kappaR, also being the parameters of the expansion in the first place. When mapping,
 | 
	
		
			
				|  |  | +    # we must therefore provide the expansion axis that should match the collection of parameters in params.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    meap = match_expansion_axis_to_params  # just a shorter variable name
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @expansion.map_over_expansion
 | 
	
		
			
				|  |  |      def potential_zero(exp: Expansion, prms: ModelParams):
 | 
	
		
			
				|  |  |          return bisect(lambda theta: potentials.charged_shell_potential(theta, phi, 1, exp, prms), theta0, theta1)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    print(ex.shape)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      params_list = params.unravel()
 | 
	
		
			
				|  |  | -    if meatp is not None:
 | 
	
		
			
				|  |  | -        expansion_list = [Expansion(ex.l_array, np.take(ex.coefs, i, axis=meatp)) for i in range(ex.shape[meatp])]
 | 
	
		
			
				|  |  | +    if meap is not None:
 | 
	
		
			
				|  |  | +        expansion_list = [Expansion(ex.l_array, np.take(ex.coefs, i, axis=meap)) for i in range(ex.shape[meap])]
 | 
	
		
			
				|  |  |      else:
 | 
	
		
			
				|  |  | -        expansion_list = [ex]
 | 
	
		
			
				|  |  | +        expansion_list = [ex for _ in params_list]
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    if not len(expansion_list) == len(params_list):
 | 
	
		
			
				|  |  | +        raise ValueError(f'Axis of expansion that is supposed to match params does not have the same length, got '
 | 
	
		
			
				|  |  | +                         f'len(params.unravel()) = {len(params_list)} and '
 | 
	
		
			
				|  |  | +                         f'expansion.shape[{meap}] = {len(expansion_list)}')
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      results = []
 | 
	
		
			
				|  |  |      for exp, prms in zip(expansion_list, params_list):
 | 
	
		
			
				|  |  |          results.append(potential_zero(exp, prms))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    if meap is not None:
 | 
	
		
			
				|  |  | +        return np.array(results).swapaxes(0, meap)  # return the params-matched axis to where it belongs
 | 
	
		
			
				|  |  | +    return np.array(results)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  if __name__ == '__main__':
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -44,12 +55,11 @@ if __name__ == '__main__':
 | 
	
		
			
				|  |  |      kappaR = np.array([0.26, 1, 3, 10, 26])
 | 
	
		
			
				|  |  |      params = ModelParams(R=150, kappaR=kappaR)
 | 
	
		
			
				|  |  |      ex = expansion.MappedExpansionQuad(a_bar=a_bar[:, None], sigma_m=0.001, l_max=20, kappaR=kappaR[None, :])
 | 
	
		
			
				|  |  | -    print(ex.shape)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    patch_size = charge_patch_size(ex)
 | 
	
		
			
				|  |  | +    # patch_size = charge_patch_size(ex)
 | 
	
		
			
				|  |  |      patch_size_pot = potential_patch_size(ex, params, match_expansion_axis_to_params=1)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    plt.plot(a_bar, patch_size * 180 / np.pi, label=kappaR)
 | 
	
		
			
				|  |  | +    plt.plot(a_bar, patch_size_pot * 180 / np.pi, label=kappaR)
 | 
	
		
			
				|  |  |      plt.legend()
 | 
	
		
			
				|  |  |      plt.show()
 | 
	
		
			
				|  |  |  
 |