| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- # pyright: basic
- from __future__ import annotations
- import numpy as np
def binary_confusion(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> dict[str, int]:
    """Count confusion-matrix cells for the rule ``y_prob >= threshold``.

    Args:
        y_true: Ground-truth labels, cast to ``int`` and flattened. Only values
            equal to 1 (positive) or 0 (negative) are counted; any other label
            falls into no cell.
        y_prob: Scores or probabilities, flattened. A sample is predicted
            positive when its value is ``>= threshold`` (NaN compares False and
            is therefore predicted negative). Must have as many elements as
            ``y_true``.
        threshold: Inclusive decision cutoff.

    Returns:
        Dict with integer values under the keys ``"tp"``, ``"fp"``, ``"tn"``, ``"fn"``.

    Raises:
        ValueError: if ``y_true`` and ``y_prob`` contain different numbers of
            elements.
    """
    # Flatten both inputs so a column vector paired with a 1-D array is compared
    # element-wise instead of silently broadcasting to an (n, n) grid.
    actual = np.asarray(y_true).astype(int).ravel()
    predicted_pos = np.asarray(y_prob).ravel() >= threshold
    if actual.size != predicted_pos.size:
        raise ValueError(
            f"y_true has {actual.size} elements but y_prob has {predicted_pos.size}"
        )
    is_pos = actual == 1
    is_neg = actual == 0
    return {
        "tp": int(np.count_nonzero(predicted_pos & is_pos)),
        "fp": int(np.count_nonzero(predicted_pos & is_neg)),
        "tn": int(np.count_nonzero(~predicted_pos & is_neg)),
        "fn": int(np.count_nonzero(~predicted_pos & is_pos)),
    }
- def _safe_div(num: float, den: float) -> float:
- if den == 0:
- return 0.0
- return num / den
def performance_at_threshold(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    threshold: float,
) -> dict[str, float]:
    """Compute classification metrics for a single decision threshold.

    The confusion counts come from ``binary_confusion``; each ratio goes through
    ``_safe_div``, so a metric whose denominator is zero is reported as ``0.0``.

    Returns:
        Dict with keys, in order: ``"threshold"``, ``"accuracy"``,
        ``"precision"``, ``"recall"``, ``"f1"``, ``"tp"``, ``"fp"``, ``"tn"``,
        ``"fn"``. Every value, including the raw counts, is a ``float``.
    """
    counts = binary_confusion(y_true, y_prob, threshold)
    tp, fp, tn, fn = (counts[key] for key in ("tp", "fp", "tn", "fn"))

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)

    metrics = {
        "threshold": threshold,
        "accuracy": _safe_div(tp + tn, tp + fp + tn + fn),
        "precision": precision,
        "recall": recall,
        # harmonic mean of precision and recall
        "f1": _safe_div(2 * precision * recall, precision + recall),
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }
    return {name: float(value) for name, value in metrics.items()}
def threshold_sweep(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    thresholds: np.ndarray,
) -> list[dict[str, float]]:
    """Evaluate ``performance_at_threshold`` at each cutoff in ``thresholds``.

    Returns one metrics dict per threshold, in the same order as
    ``thresholds``. Each threshold is converted to a Python ``float`` first.
    """
    return [
        performance_at_threshold(y_true, y_prob, cutoff)
        for cutoff in map(float, thresholds)
    ]
def calibration_stats(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    bins: int = 10,
) -> tuple[dict[str, float], np.ndarray]:
    """Summarize how well predicted probabilities match observed outcomes.

    Probabilities are grouped into ``bins`` equal-width bins over [0, 1]. Every
    bin is half-open ``[lo, hi)`` except the last, which also includes 1.0.

    Args:
        y_true: Ground-truth labels, cast to ``int`` and flattened. The per-bin
            positive fraction is the mean of these values, so they are
            expected to be 0/1.
        y_prob: Predicted probabilities, flattened. Must have as many elements
            as ``y_true``. Values outside [0, 1] fall into no bin but still
            count toward the ECE denominator.
        bins: Number of equal-width bins; must be >= 1.

    Returns:
        ``(summary, table)``.

        ``summary`` has the following keys:

        - ``"ece"``: count-weighted mean absolute gap between mean confidence
          and positive fraction.
        - ``"mce"``: the largest such gap.
        - ``"brier"``: mean squared difference between probability and label;
          NaN for empty input.
        - ``"bins"``: the bin count, as a float.

        ``table`` is a ``(bins, 3)`` float array whose rows are
        ``(mean_confidence, fraction_positive, count)``. Empty bins report the
        bin midpoint as confidence and NaN as fraction positive.

    Raises:
        ValueError: if ``bins < 1``, or if ``y_true`` and ``y_prob`` contain
            different numbers of elements.
    """
    if bins < 1:
        raise ValueError(f"bins must be >= 1, got {bins}")
    # Flatten so 2-D inputs and column vectors are handled element-wise. Without
    # this, len() counts only rows (inflating ECE for 2-D input), and mismatched
    # shapes broadcast silently in the Brier computation.
    labels = np.asarray(y_true).astype(int).ravel()
    probs = np.asarray(y_prob).astype(float).ravel()
    if labels.size != probs.size:
        raise ValueError(
            f"y_true has {labels.size} elements but y_prob has {probs.size}"
        )
    n = probs.size

    edges = np.linspace(0.0, 1.0, bins + 1)
    rows: list[tuple[float, float, int]] = []
    ece = 0.0
    mce = 0.0
    for i in range(bins):
        lo = edges[i]
        hi = edges[i + 1]
        # The last bin is closed on the right so that a probability of 1.0 is counted.
        upper_ok = (probs <= hi) if i == bins - 1 else (probs < hi)
        mask = (probs >= lo) & upper_ok
        count = int(np.count_nonzero(mask))
        if count == 0:
            rows.append((float((lo + hi) / 2.0), np.nan, 0))
            continue
        mean_conf = float(probs[mask].mean())
        frac_pos = float(labels[mask].mean())
        gap = abs(frac_pos - mean_conf)
        ece += (count / n) * gap
        mce = max(mce, gap)
        rows.append((mean_conf, frac_pos, count))

    # np.mean of an empty array warns and yields NaN; produce the NaN explicitly.
    brier = float(np.mean((probs - labels) ** 2)) if n else float("nan")
    summary = {
        "ece": float(ece),
        "mce": float(mce),
        "brier": brier,
        "bins": float(bins),
    }
    table = np.array(rows, dtype=float).reshape(bins, 3)
    return summary, table
|