metrics.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # pyright: basic
  2. from __future__ import annotations
  3. import numpy as np
  4. def binary_confusion(
  5. y_true: np.ndarray, y_prob: np.ndarray, threshold: float
  6. ) -> dict[str, int]:
  7. y_true_int = y_true.astype(int)
  8. y_pred = (y_prob >= threshold).astype(int)
  9. tp = int(np.logical_and(y_pred == 1, y_true_int == 1).sum())
  10. fp = int(np.logical_and(y_pred == 1, y_true_int == 0).sum())
  11. tn = int(np.logical_and(y_pred == 0, y_true_int == 0).sum())
  12. fn = int(np.logical_and(y_pred == 0, y_true_int == 1).sum())
  13. return {"tp": tp, "fp": fp, "tn": tn, "fn": fn}
  14. def _safe_div(num: float, den: float) -> float:
  15. if den == 0:
  16. return 0.0
  17. return num / den
  18. def performance_at_threshold(
  19. y_true: np.ndarray,
  20. y_prob: np.ndarray,
  21. threshold: float,
  22. ) -> dict[str, float]:
  23. c = binary_confusion(y_true, y_prob, threshold)
  24. tp = c["tp"]
  25. fp = c["fp"]
  26. tn = c["tn"]
  27. fn = c["fn"]
  28. total = tp + fp + tn + fn
  29. accuracy = _safe_div(tp + tn, total)
  30. precision = _safe_div(tp, tp + fp)
  31. recall = _safe_div(tp, tp + fn)
  32. f1 = _safe_div(2 * precision * recall, precision + recall)
  33. return {
  34. "threshold": float(threshold),
  35. "accuracy": float(accuracy),
  36. "precision": float(precision),
  37. "recall": float(recall),
  38. "f1": float(f1),
  39. "tp": float(tp),
  40. "fp": float(fp),
  41. "tn": float(tn),
  42. "fn": float(fn),
  43. }
  44. def threshold_sweep(
  45. y_true: np.ndarray,
  46. y_prob: np.ndarray,
  47. thresholds: np.ndarray,
  48. ) -> list[dict[str, float]]:
  49. return [performance_at_threshold(y_true, y_prob, float(t)) for t in thresholds]
  50. def calibration_stats(
  51. y_true: np.ndarray,
  52. y_prob: np.ndarray,
  53. bins: int = 10,
  54. ) -> tuple[dict[str, float], np.ndarray]:
  55. y_true_int = y_true.astype(int)
  56. y_prob_f = y_prob.astype(float)
  57. edges = np.linspace(0.0, 1.0, bins + 1)
  58. bin_data: list[tuple[float, float, int]] = []
  59. mce = 0.0
  60. n = len(y_prob_f)
  61. for i in range(bins):
  62. lo = edges[i]
  63. hi = edges[i + 1]
  64. if i == bins - 1:
  65. mask = (y_prob_f >= lo) & (y_prob_f <= hi)
  66. else:
  67. mask = (y_prob_f >= lo) & (y_prob_f < hi)
  68. count = int(mask.sum())
  69. if count == 0:
  70. bin_data.append((float((lo + hi) / 2.0), np.nan, 0))
  71. continue
  72. mean_conf = float(y_prob_f[mask].mean())
  73. frac_pos = float(y_true_int[mask].mean())
  74. gap = abs(frac_pos - mean_conf)
  75. mce = max(mce, gap)
  76. bin_data.append((mean_conf, frac_pos, count))
  77. brier = float(np.mean((y_prob_f - y_true_int) ** 2))
  78. summary = {
  79. "mce": float(mce),
  80. "brier": brier,
  81. "bins": float(bins),
  82. }
  83. arr = np.array(bin_data, dtype=float)
  84. return summary, arr