fitting_example_CB.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import numpy as np
  2. from scipy.optimize import curve_fit
  3. import matplotlib.pyplot as plt
  4. #############################################################################################################################
  5. # Load data
  6. fontsize = 16
  7. fileName = "DATA/original_histograms/mass_mm_higgs_Signal.npz"
  8. with np.load(fileName) as data:
  9. bin_edges = data['bin_edges']
  10. bin_centers = data['bin_centers']
  11. bin_values = data['bin_values']
  12. bin_errors = data['bin_errors']
  13. #############################################################################################################################
  14. # Crystal Ball Function
  15. def CrystalBall(x, A, aL, aR, nL, nR, mCB, sCB):
  16. return np.piecewise(x, [(x - mCB) / sCB <= -aL,
  17. (x - mCB) / sCB >= aR],
  18. [lambda x: A * (nL / np.abs(aL))**nL * np.exp(-aL**2 / 2) * (nL / np.abs(aL) - np.abs(aL) - (x - mCB) / sCB)**(-nL),
  19. lambda x: A * (nR / np.abs(aR))**nR * np.exp(-aR**2 / 2) * (nR / np.abs(aR) - np.abs(aR) + (x - mCB) / sCB)**(-nR),
  20. lambda x: A * np.exp(-(x - mCB)**2 / (2 * sCB**2))
  21. ])
  22. #############################################################################################################################
  23. # Fit
  24. popt, pcov = curve_fit(CrystalBall, bin_centers, bin_values, sigma=bin_errors, p0=[133., 1.5, 1.5, 3.7, 9.6, 124.5, 3.])
  25. perr = np.sqrt(np.diag(pcov))
  26. A, aL, aR, nL, nR, mCB, sCB = popt
  27. xs = np.linspace(110, 160, 501)
  28. my_fit = np.array(CrystalBall(xs, A, aL, aR, nL, nR, mCB, sCB))
  29. xerrs = 0.5 * (bin_edges[1:] - bin_edges[:-1])
  30. plt.figure(figsize=(8, 4.5))
  31. plt.errorbar(bin_centers, bin_values, bin_errors, xerrs, marker='o', markersize=5, color='k', ecolor='k', ls='', label='Original histogram')
  32. plt.plot(xs, my_fit, 'r-', label='signal fit CB')
  33. plt.xlabel(r'$m_{\mu \mu}$', fontsize=fontsize)
  34. plt.ylabel('Number of events', fontsize=fontsize)
  35. plt.legend(fontsize=fontsize)
  36. plt.xticks(size=fontsize)
  37. plt.yticks(size=fontsize)
  38. plt.tight_layout()
  39. plt.grid()
  40. plt.show()
  41. # Print out the parameters
  42. print("A: {A:.3f}\naL: {aL:.3f}\naR: {aR:.3f}\nnL: {nL:.3f}\nnR: {nR:.3f}\nmCB: {mCB:.3f}\nsCB: {sCB:.3f}".format(A=A, aL=aL, aR=aR, nL=nL, nR=nR, mCB=mCB, sCB=sCB))