Plots.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. import itertools
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import matplotlib.tri as tri
  5. import matplotlib.cm as cm
  6. from matplotlib import gridspec
  7. from matplotlib.backends.backend_pdf import PdfPages
  8. from math import sqrt, pi, ceil
  9. from scipy.stats import ks_2samp
  10. from scipy.special import erf
  11. from Tools.Decomp import TruncatedSeries
# Module-level switch read by the gsPlot decorator: when True, each figure is
# also displayed interactively with plt.show() after it has been saved.
_show = False
  13. # Integrate a function over bins as specified by edges. Npt points per bin.
  14. def _integrate(func, edges, Npt=100):
  15. N = len(edges) - 1
  16. t = np.zeros((N,))
  17. x = [ np.linspace(edges[n], edges[n+1], Npt + 1) for n in range (N) ]
  18. x = np.asarray(x)
  19. s = x.shape
  20. z = func( x.flatten() ).reshape(s)
  21. for n in range(N):
  22. t[n] = np.trapz( z[n], dx = 1. / Npt)
  23. return t
  24. # Gaussian CDF
  25. def _cdf(z):
  26. return (1./2) * ( 1 - erf(z/sqrt(2)) )
  27. # Interleave list items
  28. def _flip(items, ncol):
  29. return list(itertools.chain(*[items[i::ncol] for i in range(ncol)]))
  30. #########
  31. # A plot decorator providing some common code for the remaining plots.
  32. def gsPlot(*fargs, **fkw):
  33. def outer(func):
  34. def wrap(*args, **kwargs):
  35. fname = kwargs.pop("fname", None)
  36. pdf = kwargs.pop("pdf", None)
  37. print fname
  38. fig = plt.figure()
  39. gs = gridspec.GridSpec(*fargs, **fkw)
  40. # Run the function
  41. r = func(fig, gs, *args, **kwargs)
  42. # Do some layout cleanup and save.
  43. gs.tight_layout(plt.gcf(), rect=[0, 0, 1, 0.97])
  44. gs.update(wspace=0.00, hspace=0.1)
  45. if pdf is not None:
  46. plt.savefig(pdf, format='pdf')
  47. if fname is not None:
  48. plt.savefig(fname)
  49. if _show:
  50. plt.show()
  51. plt.gcf().clear()
  52. return r
  53. return wrap
  54. return outer
  55. ######## CUTFLOW #######
  56. @gsPlot(1, 1)
  57. def cutflow(fig, gs, cutflow):
  58. ax = [ plt.subplot(g) for g in gs ]
  59. cut, yld = zip(*cutflow)
  60. cuty = np.arange(len(cut))[::-1]
  61. bar = ax[0].barh(cuty, yld, align='center', alpha=0.5)
  62. fig.suptitle ('Cutflow')
  63. ax[0].set_yticks(cuty)
  64. ax[0].set_yticklabels(cut)
  65. ax[0].set_xlabel('Yield (Events)')
  66. # Annotate
  67. for h, b in zip(yld, bar):
  68. bx = h + 0.01*max(yld)
  69. by = b.get_y() + b.get_height()/2.
  70. ax[0].text(bx, by, "%.1f" % h, ha='left', va='center')
  71. ax[0].margins(x=0.175)
  72. ax[0].ticklabel_format(style='sci', axis='x', scilimits=(0,0))
  73. ######## HYPERPARMETER SCAN #######
  74. @gsPlot(1, 2, width_ratios=[19, 1])
  75. def scan(fig, gs, L, A, LLH, LBest, Ld, Ad, fin, maxZ=200, points=False):
  76. ax = [ plt.subplot(g) for g in gs ]
  77. # interpolate between data points for contouring
  78. triI = tri.Triangulation(L.flatten(), A.flatten())
  79. ref = tri.UniformTriRefiner(triI)
  80. dLLH = np.minimum(LLH - LBest, 1.5*maxZ)
  81. triO = ref.refine_field(dLLH.flatten(), subdiv=3)
  82. cmap = cm.get_cmap(name='terrain', lut=None)
  83. levels = np.linspace(0, maxZ, 51)
  84. colors = [ '0.00', '0.25', '0.25', '0.25', '0.25']
  85. lws = [ 0.40, 0.25, 0.25, 0.25, 0.25]
  86. Csf = ax[0].tricontourf(*triO, levels=levels, cmap=cmap)
  87. Cs = ax[0].tricontour (*triO, levels=levels, colors=colors, linewidths=lws)
  88. if points:
  89. sca = ax[0].scatter(L, A, marker='.', s= 1.0, color='k', label="Scan Point")
  90. sca2 = ax[0].scatter(Ld, Ad, marker='o', s=20.0, color='r', label="Initial", zorder=10)
  91. sca3 = ax[0].scatter(*fin, marker='x', s=35.0, color='r', label="Final", zorder=10)
  92. fig.colorbar(Csf, ticks=levels[::5], cax=ax[1])
  93. fig.suptitle("Hyperparameter Scan")
  94. ax[0].legend(loc='upper left')
  95. ax[0].set_xlabel(r'Scale ($\lambda$)')
  96. ax[0].set_ylabel(r'Exponent ($\alpha$)')
  97. ax[1].set_ylabel(r'$\Delta$ Log-Likelihood')
  98. ######## PLOT #######
  99. @gsPlot(3, 1, height_ratios=[3, 1, 1] )
  100. def fit(fig, gs, D, **kwargs):
  101. ax = [ plt.subplot(g) for g in gs ]
  102. # Parameters
  103. Bins = kwargs.get("Bins")
  104. Title = kwargs.get("Title", "Unnamed")
  105. XLabel = kwargs.get("XLabel", "Mass")
  106. YLabel = kwargs.get("YLabel", "Events / Bin")
  107. LogX = kwargs.get("LogX", True)
  108. LogY = kwargs.get("LogY", True)
  109. Style = kwargs.get("Style", "bar")
  110. ResYLim = kwargs.get("ResYLim", (-2.5, 2.5))
  111. YLim = kwargs.get("YLim", None)
  112. # Get some bin-derived quantities
  113. ctr = (Bins[1:] + Bins[:-1])/2
  114. wd = np.diff(Bins)
  115. rn = (Bins[0], Bins[-1])
  116. h, _ = np.histogram(D.x, bins=Bins, range=rn, weights=D.w)
  117. h *= D.Nint / wd
  118. t = np.linspace(rn[0], rn[1], 50*len(Bins))
  119. err = np.sqrt(h*wd, dtype=np.double)/wd
  120. # Make fit comparison
  121. tb = D.Nint * _integrate(D.TestB, Bins)
  122. ts=0
  123. if len(D.GetActive()) > 0:
  124. ts = D.Nint * _integrate(D.TestS, Bins)
  125. res = (h*wd - ts*wd)/np.sqrt(ts*wd)
  126. else:
  127. res = (h*wd - tb*wd)/np.sqrt(tb*wd)
  128. # Histogram and fit
  129. if Style == "bar":
  130. ax[0].bar(ctr, h, width=wd, log=LogX, label='Data', edgecolor="none", lw=0)
  131. elif Style == "errorbar":
  132. ax[0].errorbar(ctr, h, xerr=wd/2, yerr=err, label='Data', color='k', fmt='o')
  133. else:
  134. print "Style key must be 'bar' or 'errorbar'."
  135. ax[0].plot(t, D.Nint*D.TestB(t), ls='--', color='red', label='Background', zorder=10)
  136. if len(D.GetActive()) > 0:
  137. ax[0].plot(t, D.Nint*D.TestS(t), ls='-', color='red', label='Signal+Bkg', zorder=10)
  138. ax[0].legend()
  139. ax[0].yaxis.grid(ls=':')
  140. ax[0].set_ylabel(YLabel)
  141. if LogY: ax[0].set_yscale('log')
  142. if YLim is not None: ax[0].set_ylim(*YLim)
  143. # The background-subtracted data
  144. if Style == "bar":
  145. ax[1].bar (ctr, h-tb, width=wd, edgecolor="none", lw=0)
  146. elif Style == "errorbar":
  147. ax[1].errorbar (ctr, h-tb, xerr=wd/2, yerr=err, color='k', fmt='o')
  148. else:
  149. print "Style key must be 'bar' or 'errorbar'."
  150. ax[1].plot(t, np.zeros_like(t), ls='--', color='red')
  151. if len(D.GetActive()) > 0:
  152. ax[1].plot(t, D.Nint*(D.TestS(t) - D.TestB(t)), ls='-', color='red', zorder=10)
  153. ax[1].ticklabel_format(style='sci', axis='y', scilimits=(-2,2))
  154. ax[1].set_ylabel(r'Data - Bkg')
  155. # The residual plot.
  156. ax[2].bar(ctr, res, width=wd, edgecolor="none", lw=0)
  157. ax[2].plot(t, 0*t, color='black', lw=1.0)
  158. ax[2].set_ylim(*ResYLim)
  159. ax[2].set_ylabel(r'Residual ($\sigma$)')
  160. # Shared formatting
  161. fig.suptitle(Title)
  162. for a in ax:
  163. a.xaxis.grid(ls=':')
  164. a.yaxis.set_label_coords(-0.065, 0.5)
  165. a.set_xlim(*rn)
  166. if LogX: a.set_xscale('log')
  167. for a in ax[:-1]:
  168. a.tick_params(labelbottom='off')
  169. ax[-1].set_xlabel(XLabel)
  170. return h, res
  171. ######## PULL ########
  172. @gsPlot(1, 1)
  173. def pull(fig, gs, data, res):
  174. ax = [ plt.subplot(g) for g in gs ]
  175. kres = np.compress(data > 20, np.nan_to_num(res))
  176. hist, edges = np.histogram(kres, bins=np.linspace(-5, 5, 21))
  177. centers = (edges[1:] + edges[:-1])/2
  178. nrm = np.random.normal(size=20*len(kres))
  179. ks_p = ks_2samp(nrm, kres)[1] if len(kres) > 0 else 1.0
  180. fig.suptitle("Pull Distribution (Bins with $>20$ Events)")
  181. ax[0].set_xlabel(r'Deviation ($\sigma$)')
  182. ax[0].set_ylabel("Number of Bins")
  183. ax[0].bar (centers, hist, width=np.diff(edges), label='Bin Residuals \n $p=%.2g$ (KS)' % ks_p)
  184. ax[0].errorbar(centers, hist, yerr=np.sqrt(hist), color='k', fmt='o')
  185. ax[0].set_xlim(-5, 5)
  186. t = np.linspace(-5, 5, 201)
  187. n = np.exp( -0.5*t**2 ) / sqrt(2*pi)
  188. ax[0].plot(t, 0.5*n*hist.sum(), lw=1.5, color='b', label=r'Standard Normal')
  189. ax[0].legend()
  190. ######## SIGNALS AND ESTIMATORS ########
  191. @gsPlot(1, 1)
  192. def estimators(fig, gs, D, **kwargs):
  193. ax = [ plt.subplot(g) for g in gs ]
  194. # Parameters
  195. Signals = kwargs.get("Signals", [])
  196. Draw = set(kwargs.get("Draw", ["Estimators"]))
  197. Range = kwargs.get("Range")
  198. Title = kwargs.get("Title", "Unnamed")
  199. XLabel = kwargs.get("XLabel", "Mass")
  200. YLabel = kwargs.get("YLabel", "Arbitrary Units")
  201. LogX = kwargs.get("LogX", True)
  202. t = np.linspace(Range[0], Range[1], 1001)
  203. ax[0].plot(t, 0*t, lw=0.75, color='k')
  204. for sigName in Signals:
  205. eName = sigName.replace('%', '\%')
  206. M = np.zeros_like(D[sigName].Sig)
  207. if "Signal" in Draw:
  208. M[:] = D[sigName].Sig
  209. E = TruncatedSeries(D.Factory, M)
  210. ax[0].plot(t, E(t), lw=1.0, label=eName + " (Signal)")
  211. if "Residual" in Draw:
  212. M[:D.N] = 0
  213. M[D.N:] = D[sigName].Res
  214. E = TruncatedSeries(D.Factory, M)
  215. ax[0].plot(t, E(t), lw=1.0, label=eName + " (Residual)")
  216. if "Estimator" in Draw:
  217. M[:D.N] = 0
  218. M[D.N:] = D[sigName].Est
  219. E = TruncatedSeries(D.Factory, M)
  220. ax[0].plot(t, E(t), lw=1.0, label=eName + " (MinVar Estimator)")
  221. fig.suptitle(Title)
  222. ax[0].set_xlabel(XLabel)
  223. ax[0].set_ylabel(YLabel)
  224. ax[0].set_xlim(*Range)
  225. ax[0].legend()
  226. ######## MOMENT LINE/BAR PLOT ########
  227. @gsPlot(1, 1)
  228. def moments(fig, gs, D, **kwargs):
  229. def _bplot(a, x, y, label, style, Num, n):
  230. if style == "line":
  231. a.plot(x, y**2, label=label)
  232. elif style == "bar":
  233. a.bar (Num*x + n, y**2, label=label, lw=0)
  234. ax = [ plt.subplot(g) for g in gs ]
  235. # Parameters
  236. Signals = kwargs.get("Signals", [])
  237. Draw = set(kwargs.get("Draw", ["Estimators"]))
  238. Range = kwargs.get("Range")
  239. Style = kwargs.get("Style", "line")
  240. Title = kwargs.get("Title", "Unnamed")
  241. XLabel = kwargs.get("XLabel", "Moment #")
  242. YLabel = kwargs.get("YLabel", r"$\left|\mathrm{Moment}\right|^2$")
  243. LogX = kwargs.get("LogX", True)
  244. LogY = kwargs.get("LogY", True)
  245. ctr = np.arange(*Range)
  246. Num = 2 + len(Draw) * len(Signals)
  247. _bplot( ax[0], ctr, D.Mom[ctr], "Data", Style, Num, 0)
  248. n = 1
  249. for sigName in Signals:
  250. eName = sigName.replace('%', '\%')
  251. M = np.zeros_like(D[sigName].Sig)
  252. if "Signal" in Draw:
  253. M[:] = D[sigName].Sig
  254. _bplot(ax[0], ctr, M[ctr], eName + " (Signal)", Style, Num, n)
  255. n += 1
  256. if "Residual" in Draw:
  257. M[:D.N] = 0
  258. M[D.N:] = D[sigName].Res
  259. _bplot(ax[0], ctr, M[ctr], eName + " (Residual)", Style, Num, n)
  260. n += 1
  261. if "Estimator" in Draw:
  262. M[:D.N] = 0
  263. M[D.N:] = D[sigName].Est
  264. _bplot(ax[0], ctr, M[ctr], eName + " (MinVar Estimator)", Style, Num, n)
  265. n += 1
  266. fig.suptitle(Title)
  267. ax[0].set_xlabel(XLabel)
  268. ax[0].set_ylabel(YLabel)
  269. ax[0].set_xlim(*Range)
  270. if LogY:
  271. ax[0].set_yscale('log')
  272. if Style == "bar":
  273. tnum = (int(Range[1]) / 8) * np.arange(9)
  274. ax[0].set_xticks ( [(x*Num + Num/2) for x in tnum ])
  275. ax[0].set_xticklabels( [str(x) for x in tnum ])
  276. elif Style == "line":
  277. tnum = (int(Range[1]) / 8) * np.arange(9)
  278. ax[0].set_xticks ( [x for x in tnum ])
  279. ax[0].set_xticklabels( [str(x) for x in tnum ])
  280. ax[0].legend()
  281. ######## MASS SCAN ########
  282. @gsPlot(3, 1, height_ratios=[6, 2, 2])
  283. def mass_scan(fig, gs, scan, **kwargs):
  284. ax = [ plt.subplot(g) for g in gs ]
  285. Units = kwargs.get("Units", "Events")
  286. Title = kwargs.get("Title", "Scan")
  287. XLabel = kwargs.get("XLabel", "Mass")
  288. LogX = kwargs.get("LogX", True)
  289. CMap = kwargs.get("CMap", "copper")
  290. Scans = kwargs.get("Scans", scan.keys())
  291. Bands = kwargs.get("Bands", len(Scans) == 1)
  292. XRange = kwargs.get("XRange",
  293. ( min([x.Mass.min() for x in scan.values()]),
  294. max([x.Mass.max() for x in scan.values()]) ))
  295. sig = { n: x.Sig for n, x in scan.items() }
  296. keep = { n: (x.Mass > XRange[0])*(x.Mass < XRange[1]) for n, x in scan.items() }
  297. sigmax = max([x[keep[name]].max() for name, x in sig.items()])
  298. zero = np.asarray((0,0))
  299. cmap = [ plt.get_cmap(CMap)(i) for i in np.linspace(0, 1, len(Scans)) ]
  300. # Limits
  301. for c, name in zip(cmap, Scans):
  302. e = scan[name].ExpLim
  303. n = name.replace("%", "\%")
  304. if Bands:
  305. ax[0].fill_between(scan[name].Mass, e[:,0], e[:,4], color='y' )
  306. ax[0].fill_between(scan[name].Mass, e[:,1], e[:,3], color='g' )
  307. ax[0].plot(scan[name].Mass, e[:,2], label=n, color=c, ls='--')
  308. ax[0].plot(scan[name].Mass, scan[name].ObsLim, label=n + " ", color=c, ls='-')
  309. extr = plt.Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)
  310. h, l = ax[0].get_legend_handles_labels()
  311. ax[0].legend(_flip([ extr, extr] + h, 2),
  312. _flip(["Exp", "Obs"] + l, 2),
  313. loc='best', ncol=2)
  314. ax[0].set_ylabel(r'CL$_{95}$ (%s)' % Units)
  315. ax[0].set_yscale('log')
  316. # Deviation (sigmas)
  317. ax[1].fill_between(XRange, zero-2, zero+2, color='y' )
  318. ax[1].fill_between(XRange, zero-1, zero+1, color='g' )
  319. ax[1].plot (XRange, zero, color='k', ls=':' )
  320. for c, name in zip(cmap, Scans):
  321. ax[1].plot(scan[name].Mass, scan[name].Sig, label=name, color=c)
  322. ax[1].set_ylabel(r'Dev. ($\sigma$)')
  323. ax[1].set_ylim(-3, 3)
  324. # p-value
  325. for n in np.arange( 1, ceil(sigmax) + 1 ):
  326. t = 1.001*XRange[1] - 0.001*XRange[0]
  327. ax[2].plot(XRange, zero + _cdf(n), ls=':', lw=0.5,color='k')
  328. ax[2].text(t, _cdf(n), '$%d\sigma$' % n, va='center')
  329. for c, name in zip(cmap, Scans):
  330. M = scan[name].Mass [keep[name]]
  331. P = scan[name].PValue[keep[name]]
  332. ax[2].plot(M, P, label=name, color=c)
  333. ax[2].set_ylabel(r'p-value')
  334. ax[2].set_yscale('log')
  335. # Shared formatting
  336. fig.suptitle(Title)
  337. for a in ax:
  338. a.set_xlim( *XRange )
  339. if LogX: a.set_xscale('log')
  340. for a in ax[:-1]:
  341. a.tick_params(labelbottom='off')
  342. ax[-1].set_xlabel(XLabel)
  343. ######## COEFFICIENT TABLES ########
  344. @gsPlot(2, 3, width_ratios=[2, 1, 1])
  345. def summary_table(fig, gs, D):
  346. ax = [ plt.subplot(gs[0,0]),
  347. plt.subplot(gs[1,0]),
  348. plt.subplot(gs[ :,2]) ]
  349. labels = D.GetActive()
  350. yields = [ "%.1f" % D[n].Yield for n in labels ]
  351. uncs = [ "%.1f" % D[n].Unc for n in labels ]
  352. # Signals yields
  353. if len(labels) > 0:
  354. colL = [ "Yield", "Uncertainty" ]
  355. txt = zip(yields, uncs)
  356. ax[0].axis('tight')
  357. ax[0].axis('off')
  358. ax[0].set_title("Extracted Signal")
  359. ax[0].table(cellText=txt, rowLabels=labels, colLabels=colL, loc='center')
  360. # Correlations
  361. if len(labels) > 0:
  362. txt = [ [ "%.3f" % D.Corr[n,m] if n<= m else "" for n in range(len(labels)) ] for m in range(len(labels)) ]
  363. ax[1].axis('tight')
  364. ax[1].axis('off')
  365. ax[1].set_title("Signal Correlations")
  366. ax[1].table(cellText=txt, rowLabels=labels, colLabels=labels, loc='center')
  367. # Moments
  368. colL = [ "Value" ]
  369. rowL = [ r'$\lambda$', r'$\alpha$' ]
  370. rowL += [ r'$c_{%d}$' % n for n in range(D.TestB.Nmax) ]
  371. txt = [ ["%.2f" % D.Factory[x]] for x in ("Lambda", "Alpha") ]
  372. txt += [ ["%.2g" % D.TestB.MomAct[n]] for n in range(D.TestB.Nmax) ]
  373. ax[2].axis('tight')
  374. ax[2].axis('off')
  375. ax[2].set_title("Background Moments")
  376. tab = ax[2].table(cellText=txt, rowLabels=rowL, colLabels=colL, loc='center')