metrics.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # pyright: basic
  2. from __future__ import annotations
  3. import numpy as np
  4. def binary_confusion(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, int]:
  5. y_true_int = y_true.astype(int)
  6. y_pred = (y_prob >= threshold).astype(int)
  7. tp = int(np.logical_and(y_pred == 1, y_true_int == 1).sum())
  8. fp = int(np.logical_and(y_pred == 1, y_true_int == 0).sum())
  9. tn = int(np.logical_and(y_pred == 0, y_true_int == 0).sum())
  10. fn = int(np.logical_and(y_pred == 0, y_true_int == 1).sum())
  11. return {"tp": tp, "fp": fp, "tn": tn, "fn": fn}
  12. def _safe_div(num: float, den: float) -> float:
  13. if den == 0:
  14. return 0.0
  15. return num / den
  16. def performance_at_threshold(
  17. y_true: np.ndarray,
  18. y_prob: np.ndarray,
  19. threshold: float,
  20. ) -> dict[str, float]:
  21. c = binary_confusion(y_true, y_prob, threshold)
  22. tp = c["tp"]
  23. fp = c["fp"]
  24. tn = c["tn"]
  25. fn = c["fn"]
  26. total = tp + fp + tn + fn
  27. accuracy = _safe_div(tp + tn, total)
  28. precision = _safe_div(tp, tp + fp)
  29. recall = _safe_div(tp, tp + fn)
  30. f1 = _safe_div(2 * precision * recall, precision + recall)
  31. return {
  32. "threshold": float(threshold),
  33. "accuracy": float(accuracy),
  34. "precision": float(precision),
  35. "recall": float(recall),
  36. "f1": float(f1),
  37. "tp": float(tp),
  38. "fp": float(fp),
  39. "tn": float(tn),
  40. "fn": float(fn),
  41. }
  42. def threshold_sweep(
  43. y_true: np.ndarray,
  44. y_prob: np.ndarray,
  45. thresholds: np.ndarray,
  46. ) -> list[dict[str, float]]:
  47. return [performance_at_threshold(y_true, y_prob, float(t)) for t in thresholds]
  48. def calibration_stats(
  49. y_true: np.ndarray,
  50. y_prob: np.ndarray,
  51. bins: int = 10,
  52. ) -> tuple[dict[str, float], np.ndarray]:
  53. y_true_int = y_true.astype(int)
  54. y_prob_f = y_prob.astype(float)
  55. edges = np.linspace(0.0, 1.0, bins + 1)
  56. bin_data: list[tuple[float, float, int]] = []
  57. ece = 0.0
  58. mce = 0.0
  59. n = len(y_prob_f)
  60. for i in range(bins):
  61. lo = edges[i]
  62. hi = edges[i + 1]
  63. if i == bins - 1:
  64. mask = (y_prob_f >= lo) & (y_prob_f <= hi)
  65. else:
  66. mask = (y_prob_f >= lo) & (y_prob_f < hi)
  67. count = int(mask.sum())
  68. if count == 0:
  69. bin_data.append((float((lo + hi) / 2.0), np.nan, 0))
  70. continue
  71. mean_conf = float(y_prob_f[mask].mean())
  72. frac_pos = float(y_true_int[mask].mean())
  73. gap = abs(frac_pos - mean_conf)
  74. ece += (count / n) * gap
  75. mce = max(mce, gap)
  76. bin_data.append((mean_conf, frac_pos, count))
  77. brier = float(np.mean((y_prob_f - y_true_int) ** 2))
  78. summary = {
  79. "ece": float(ece),
  80. "mce": float(mce),
  81. "brier": brier,
  82. "bins": float(bins),
  83. }
  84. arr = np.array(bin_data, dtype=float)
  85. return summary, arr