# fitting_example_GPR-logartihm.py (3.8 KB)
  1. import numpy as np
  2. from sklearn.gaussian_process import GaussianProcessRegressor
  3. # Different kernels available in sklearn:
  4. from sklearn.gaussian_process.kernels import RBF, ConstantKernel, Matern, WhiteKernel, ExpSineSquared, RationalQuadratic
  5. from matplotlib import pyplot as plt
  6. plt.rcParams.update({
  7. 'axes.labelsize': 16,
  8. 'xtick.labelsize': 14,
  9. 'ytick.labelsize': 14
  10. })
  11. # Set random seed for same results
  12. np.random.seed(12345)
  13. # Load data
  14. inFileName = 'DATA/original_histograms/mass_mm_higgs_Background.npz'
  15. with np.load(inFileName) as data:
  16. bin_centers = data['bin_centers']
  17. bin_values = np.log(data['bin_values'])
  18. bin_errors = data['bin_errors'] / data['bin_values'] # How does the error scale when taking f(x) = ln(x)?
  19. # Mask
  20. signal_region = (119, 131)
  21. mask = (bin_centers < signal_region[0]) | (bin_centers > signal_region[1])
  22. bin_centers_masked = bin_centers[mask]
  23. bin_values_masked = bin_values[mask]
  24. bin_errors_masked = bin_errors[mask]
  25. # Set hyper-parameter bounds for ConstantKernel
  26. nEvts = np.max(bin_values)
  27. const0 = 1.
  28. const_low = 1e-1
  29. const_hi = 1e3
  30. # Set hyper-parameter bounds for RBF kernel
  31. RBF0 = 1.
  32. RBF_low = 1e-1
  33. RBF_high = 1e2
  34. # A) Define kernel: ConstantKernel * RBF:
  35. kernel_RBF = ConstantKernel(const0, constant_value_bounds=(const_low, const_hi)) * RBF(RBF0, length_scale_bounds=(RBF_low, RBF_high))
  36. # B) Define kernel: ConstantKernel * Matern:
  37. kernel_Matern = ConstantKernel(const0, constant_value_bounds=(const_low, const_hi)) * Matern(RBF0, length_scale_bounds=(RBF_low, RBF_high), nu=1.5)
  38. # Transform x data into 2d vector!
  39. X = np.atleast_2d(bin_centers_masked).T # true datapoints
  40. X_to_predict = np.atleast_2d(np.linspace(110, 160, 1000)).T # what to predict
  41. y = bin_values_masked
  42. # Initialize Gaussian Process Regressor: !!! alpha = bin_errors, 2*bin_errors or bin_errors**2? Your task to figure out!!!
  43. gp = GaussianProcessRegressor(kernel=kernel_RBF, n_restarts_optimizer=1, alpha=bin_errors_masked**2)
  44. # Fit on X with values y
  45. gp.fit(X, y)
  46. print('Final kernel combination:\n', gp.kernel_)
  47. # Predict
  48. y_pred, sigma = gp.predict(X_to_predict, return_std=True)
  49. y_pred_sparse, sigma_sparse = gp.predict(np.atleast_2d(bin_centers).T, return_std=True)
  50. fig, axes = plt.subplot_mosaic([['main'],['main'],['main'],['ratio']], sharex=True, figsize=(8, 8))
  51. # Main pad
  52. axes['main'].set_title('Example GPR with RBF kernel', fontsize=20, fontweight='bold')
  53. axes['main'].fill_between(X_to_predict.ravel(), y_pred - sigma, y_pred + sigma)
  54. axes['main'].scatter(bin_centers_masked, bin_values_masked, color='r', linewidth=0.5, marker='o', s=25, label='Data')
  55. axes['main'].scatter(bin_centers[~mask], bin_values[~mask], color='g', marker='+', s=100, label='Blinded data (not used in the fit)')
  56. axes['main'].plot(X_to_predict, y_pred, color='k', label='GPR Prediction')
  57. axes['main'].set_ylabel('ln(events/bin)', fontsize=16)
  58. axes['main'].set_ylim((8, 12))
  59. axes['main'].legend(fontsize=16)
  60. # Ratio pad
  61. axes['ratio'].errorbar(bin_centers, bin_values/y_pred_sparse, yerr=sigma_sparse, color='k', linewidth=0., elinewidth=0.5, marker='.')
  62. axes['ratio'].axhline(1, c='k', lw=1, alpha=0.7)
  63. axes['ratio'].set_xlabel(r'$m_{\mu\mu}$ [GeV]', fontsize=16)
  64. axes['ratio'].set_ylabel('Data/Pred.', fontsize=16)
  65. # Make ratio plot labels symmetric around 1.
  66. max = np.max(np.abs(axes['ratio'].get_yticks() - 1.)) / 1.5
  67. axes['ratio'].set_ylim((1. - max, 1. + max))
  68. axes['ratio'].grid()
  69. # Make an inner plot
  70. axes['main'].plot([113, 119], [9.9, 10.9], 'k--')
  71. axes['main'].plot([138, 131], [9.9, 10.3], 'k--')
  72. ax_inner = fig.add_axes([0.2125, 0.3625, 0.4, 0.25])
  73. ax_inner.set_title('zoom-in', fontsize=20)
  74. ax_inner.scatter(bin_centers[~mask], bin_values[~mask], color='g', marker='+', s=100)
  75. ax_inner.plot(X_to_predict, y_pred, color='k')
  76. ax_inner.set_xlim(signal_region)
  77. ax_inner.set_ylim((10.3, 11))
  78. ax_inner.grid()
  79. plt.tight_layout()
  80. plt.show()
  81. plt.savefig('GPR_simple.pdf')