# sensitivity_analysis.py
# The purpose of this file is to perform a sensitivity analysis on the model evaluation results and graph the findings.
# The sensitivity analysis will be done by varying the number of models used in the ensemble and observing the effect on overall accuracy.
# We will take 50 different random arrangements of models for each ensemble size (other than 50, which is the full set) to get a distribution of accuracies for each ensemble size.
# These will have associated error bars based on the standard deviation of the accuracies for each ensemble size.
import xarray as xr
from utils.config import config
import pathlib as pl
import numpy as np
import matplotlib.pyplot as plt
import os
  11. # Load the evaluation results
  12. os.chdir(pl.Path(__file__).parent)
  13. model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
  14. config["analysis"]["evaluation_name"].strip()
  15. ).with_suffix(".nc")
  16. print(f"Loading evaluation results from {model_dataset_path}")
  17. array = xr.open_dataset(model_dataset_path) # type: ignore
  18. # This section was generated by Github Copilot - 2025-11-04
  19. # Perform sensitivity analysis by varying ensemble size and sampling subsets of models.
  20. predictions: xr.DataArray = array["predictions"]
  21. labels: xr.DataArray = array["labels"]
  22. # Make plots directory if it doesn't exist (matching other scripts)
  23. plots_dir = (
  24. pl.Path("../output") / pl.Path(config["analysis"]["evaluation_name"]) / "plots"
  25. )
  26. plots_dir.mkdir(parents=True, exist_ok=True)
  27. # Configuration for the sensitivity analysis
  28. rng = np.random.default_rng(777)
  29. num_models = int(predictions.sizes["model"])
  30. ensemble_sizes = list(range(1, num_models + 1))
  31. samples_per_size = 50
  32. # Extract true labels for the positive class (assumes same structure as other script)
  33. true_labels = labels.sel(label=1).values # shape: (n_samples,)
  34. # Container for results
  35. mean_accuracies: list[float] = []
  36. std_accuracies: list[float] = []
  37. all_accuracies: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  38. # Confusion-matrix counts per ensemble size (one entry per sample draw)
  39. all_true_positives: dict[int, list[int]] = {k: [] for k in ensemble_sizes}
  40. all_false_positives: dict[int, list[int]] = {k: [] for k in ensemble_sizes}
  41. all_true_negatives: dict[int, list[int]] = {k: [] for k in ensemble_sizes}
  42. all_false_negatives: dict[int, list[int]] = {k: [] for k in ensemble_sizes}
  43. # Also implement a "confidence accuracy score":
  44. # sum over samples of (confidence_of_predicted_class - 0.5) * (+1 if correct else -1)
  45. confidence_accuracy_scores: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  46. mean_confidence_accuracy_scores: list[float] = []
  47. std_confidence_accuracy_scores: list[float] = []
  48. # Track average confidence of predicted class for correct vs incorrect predictions
  49. correct_prediction_confidences: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  50. incorrect_prediction_confidences: dict[int, list[float]] = {
  51. k: [] for k in ensemble_sizes
  52. }
  53. mean_correct_prediction_confidences: list[float] = []
  54. std_correct_prediction_confidences: list[float] = []
  55. mean_incorrect_prediction_confidences: list[float] = []
  56. std_incorrect_prediction_confidences: list[float] = []
  57. # Track delta (correct - incorrect) confidence per model, averaged over models
  58. confidence_delta_per_model: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  59. mean_confidence_delta_per_model: list[float] = []
  60. std_confidence_delta_per_model: list[float] = []
  61. def confidence_accuracy_score(
  62. positive_class_confidence: np.ndarray, true_labels_binary: np.ndarray
  63. ) -> float:
  64. """Compute confidence accuracy score as described.
  65. Parameters
  66. ----------
  67. positive_class_confidence:
  68. Probability/confidence assigned to the positive class (class 1), shape (n_samples,).
  69. true_labels_binary:
  70. Ground-truth binary labels in {0, 1}, shape (n_samples,).
  71. """
  72. confs = np.asarray(positive_class_confidence, dtype=float)
  73. y = np.asarray(true_labels_binary, dtype=int)
  74. if confs.shape[0] != y.shape[0]:
  75. raise ValueError(
  76. f"Mismatched lengths: confs has {confs.shape[0]} samples, labels has {y.shape[0]}"
  77. )
  78. predicted_positive = confs >= 0.5
  79. true_positive = y == 1
  80. correct = predicted_positive == true_positive
  81. # Confidence of the predicted class (so correct negatives contribute positively)
  82. predicted_class_confidence = np.where(predicted_positive, confs, 1.0 - confs)
  83. sign = np.where(correct, 1.0, -1.0)
  84. # Normalize by number of evaluated images (per-image average score)
  85. # To normalize score to [0, 1], divide by 0.5 * n_samples
  86. return float(
  87. np.sum((predicted_class_confidence - 0.5) * sign) / (0.5 * confs.shape[0])
  88. )
  89. def average_predicted_class_confidence_by_correctness(
  90. positive_class_confidence: np.ndarray, true_labels_binary: np.ndarray
  91. ) -> tuple[float, float]:
  92. """Return (mean_conf_correct, mean_conf_incorrect) for predicted class confidence.
  93. Uses confidence of the predicted class for each sample:
  94. - if predicted positive: confidence = P(class=1)
  95. - if predicted negative: confidence = P(class=0) = 1 - P(class=1)
  96. """
  97. confs = np.asarray(positive_class_confidence, dtype=float)
  98. y = np.asarray(true_labels_binary, dtype=int)
  99. if confs.shape[0] != y.shape[0]:
  100. raise ValueError(
  101. f"Mismatched lengths: confs has {confs.shape[0]} samples, labels has {y.shape[0]}"
  102. )
  103. predicted_positive = confs >= 0.5
  104. true_positive = y == 1
  105. correct = predicted_positive == true_positive
  106. predicted_class_confidence = np.where(predicted_positive, confs, 1.0 - confs)
  107. correct_confs = predicted_class_confidence[correct]
  108. incorrect_confs = predicted_class_confidence[~correct]
  109. mean_correct = float(np.mean(correct_confs)) if correct_confs.size else float("nan")
  110. mean_incorrect = (
  111. float(np.mean(incorrect_confs)) if incorrect_confs.size else float("nan")
  112. )
  113. return mean_correct, mean_incorrect
  114. def average_confidence_delta_per_model(
  115. predictions_selected: xr.DataArray, true_labels_binary: np.ndarray
  116. ) -> float:
  117. """Average (correct - incorrect) predicted-class confidence per model.
  118. For each model independently:
  119. 1) compute predicted-class confidence per sample
  120. 2) split into correct vs incorrect samples
  121. 3) delta = mean(correct) - mean(incorrect)
  122. Returns the mean delta across models (nan-safe).
  123. """
  124. per_model_confs = predictions_selected.sel(img_class=1).values # (k_models, n)
  125. y = np.asarray(true_labels_binary, dtype=int)
  126. if per_model_confs.shape[1] != y.shape[0]:
  127. raise ValueError(
  128. f"Mismatched lengths: predictions have {per_model_confs.shape[1]} samples, labels has {y.shape[0]}"
  129. )
  130. # For each model, determine predicted class and confidence of predicted class
  131. predicted_positive = per_model_confs >= 0.5
  132. true_positive = y == 1
  133. correct = predicted_positive == true_positive # broadcasts to (k_models, n)
  134. predicted_class_confidence = np.where(
  135. predicted_positive, per_model_confs, 1.0 - per_model_confs
  136. )
  137. # Compute per-model means for correct/incorrect, then delta
  138. deltas: list[float] = []
  139. for m in range(predicted_class_confidence.shape[0]):
  140. conf_m = predicted_class_confidence[m]
  141. correct_m = correct[m]
  142. correct_vals = conf_m[correct_m]
  143. incorrect_vals = conf_m[~correct_m]
  144. if correct_vals.size == 0 or incorrect_vals.size == 0:
  145. deltas.append(float("nan"))
  146. continue
  147. deltas.append(float(np.mean(correct_vals) - np.mean(incorrect_vals)))
  148. return float(np.nanmean(np.asarray(deltas, dtype=float)))
  149. for k in ensemble_sizes:
  150. accuracies_k = []
  151. true_positives_k: list[int] = []
  152. false_positives_k: list[int] = []
  153. true_negatives_k: list[int] = []
  154. false_negatives_k: list[int] = []
  155. confidence_scores_k: list[float] = []
  156. correct_confidences_k: list[float] = []
  157. incorrect_confidences_k: list[float] = []
  158. confidence_deltas_per_model_k: list[float] = []
  159. # If using the full set, evaluate once deterministically
  160. if k == num_models:
  161. selected_idx = np.arange(num_models)
  162. preds_selected = predictions.isel(model=selected_idx).mean(dim="model")
  163. per_model_preds_selected = predictions.isel(model=selected_idx)
  164. confs = preds_selected.sel(img_class=1).values
  165. predicted_positive = confs >= 0.5
  166. true_positive = true_labels == 1
  167. confidence_scores_k.append(confidence_accuracy_score(confs, true_labels))
  168. mean_correct_conf, mean_incorrect_conf = (
  169. average_predicted_class_confidence_by_correctness(confs, true_labels)
  170. )
  171. correct_confidences_k.append(mean_correct_conf)
  172. incorrect_confidences_k.append(mean_incorrect_conf)
  173. confidence_deltas_per_model_k.append(
  174. average_confidence_delta_per_model(per_model_preds_selected, true_labels)
  175. )
  176. tp = int((predicted_positive & true_positive).sum().item())
  177. fp = int((predicted_positive & ~true_positive).sum().item())
  178. tn = int((~predicted_positive & ~true_positive).sum().item())
  179. fn = int((~predicted_positive & true_positive).sum().item())
  180. true_positives_k.append(tp)
  181. false_positives_k.append(fp)
  182. true_negatives_k.append(tn)
  183. false_negatives_k.append(fn)
  184. acc = (predicted_positive == true_positive).sum().item() / len(confs)
  185. accuracies_k.append(acc)
  186. else:
  187. for _ in range(samples_per_size):
  188. selected_idx = rng.choice(num_models, size=k, replace=False)
  189. preds_selected = predictions.isel(model=selected_idx).mean(dim="model")
  190. per_model_preds_selected = predictions.isel(model=selected_idx)
  191. confs = preds_selected.sel(img_class=1).values
  192. predicted_positive = confs >= 0.5
  193. true_positive = true_labels == 1
  194. confidence_scores_k.append(confidence_accuracy_score(confs, true_labels))
  195. mean_correct_conf, mean_incorrect_conf = (
  196. average_predicted_class_confidence_by_correctness(confs, true_labels)
  197. )
  198. correct_confidences_k.append(mean_correct_conf)
  199. incorrect_confidences_k.append(mean_incorrect_conf)
  200. confidence_deltas_per_model_k.append(
  201. average_confidence_delta_per_model(
  202. per_model_preds_selected, true_labels
  203. )
  204. )
  205. tp = int((predicted_positive & true_positive).sum().item())
  206. fp = int((predicted_positive & ~true_positive).sum().item())
  207. tn = int((~predicted_positive & ~true_positive).sum().item())
  208. fn = int((~predicted_positive & true_positive).sum().item())
  209. true_positives_k.append(tp)
  210. false_positives_k.append(fp)
  211. true_negatives_k.append(tn)
  212. false_negatives_k.append(fn)
  213. acc = (predicted_positive == true_positive).sum().item() / len(confs)
  214. accuracies_k.append(acc)
  215. all_accuracies[k] = accuracies_k
  216. all_true_positives[k] = true_positives_k
  217. all_false_positives[k] = false_positives_k
  218. all_true_negatives[k] = true_negatives_k
  219. all_false_negatives[k] = false_negatives_k
  220. confidence_accuracy_scores[k] = confidence_scores_k
  221. mean_confidence_accuracy_scores.append(float(np.mean(confidence_scores_k)))
  222. std_confidence_accuracy_scores.append(float(np.std(confidence_scores_k, ddof=0)))
  223. correct_prediction_confidences[k] = correct_confidences_k
  224. incorrect_prediction_confidences[k] = incorrect_confidences_k
  225. mean_correct_prediction_confidences.append(
  226. float(np.nanmean(np.asarray(correct_confidences_k, dtype=float)))
  227. )
  228. std_correct_prediction_confidences.append(
  229. float(np.nanstd(np.asarray(correct_confidences_k, dtype=float), ddof=0))
  230. )
  231. mean_incorrect_prediction_confidences.append(
  232. float(np.nanmean(np.asarray(incorrect_confidences_k, dtype=float)))
  233. )
  234. std_incorrect_prediction_confidences.append(
  235. float(np.nanstd(np.asarray(incorrect_confidences_k, dtype=float), ddof=0))
  236. )
  237. confidence_delta_per_model[k] = confidence_deltas_per_model_k
  238. mean_confidence_delta_per_model.append(
  239. float(np.nanmean(np.asarray(confidence_deltas_per_model_k, dtype=float)))
  240. )
  241. std_confidence_delta_per_model.append(
  242. float(np.nanstd(np.asarray(confidence_deltas_per_model_k, dtype=float), ddof=0))
  243. )
  244. mean_accuracies.append(float(np.mean(accuracies_k)))
  245. std_accuracies.append(float(np.std(accuracies_k, ddof=0)))
  246. # Compute F1 scores per ensemble size from stored confusion counts
  247. mean_f1s: list[float] = []
  248. std_f1s: list[float] = []
  249. all_f1s: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  250. for k in ensemble_sizes:
  251. tp_arr = np.asarray(all_true_positives[k], dtype=float)
  252. fp_arr = np.asarray(all_false_positives[k], dtype=float)
  253. fn_arr = np.asarray(all_false_negatives[k], dtype=float)
  254. denom = 2 * tp_arr + fp_arr + fn_arr
  255. f1_arr = np.divide(
  256. 2 * tp_arr,
  257. denom,
  258. out=np.zeros_like(denom, dtype=float),
  259. where=denom != 0,
  260. )
  261. f1s_k = [float(x) for x in f1_arr.tolist()]
  262. all_f1s[k] = f1s_k
  263. mean_f1s.append(float(np.mean(f1s_k)))
  264. std_f1s.append(float(np.std(f1s_k, ddof=0)))
  265. # Plot mean accuracy vs ensemble size with error bars (std)
  266. plt.figure(figsize=(10, 6))
  267. plt.errorbar(
  268. ensemble_sizes,
  269. mean_accuracies,
  270. yerr=std_accuracies,
  271. fmt="-o",
  272. capsize=3,
  273. color="tab:blue",
  274. ecolor="tab:blue",
  275. )
  276. plt.title("Sensitivity Analysis: Accuracy vs Ensemble Size")
  277. plt.xlabel("Number of Models in Ensemble")
  278. plt.ylabel("Accuracy")
  279. plt.grid(True)
  280. # Set x-ticks every 5 models (and always include the final model count)
  281. ticks = list(range(1, num_models + 1, 5))
  282. if len(ticks) == 0 or ticks[-1] != num_models:
  283. ticks.append(num_models)
  284. plt.xticks(ticks)
  285. # Optionally overlay raw sample distributions as jittered points
  286. for i, k in enumerate(ensemble_sizes):
  287. y = all_accuracies[k]
  288. x = np.full(len(y), k) + (rng.random(len(y)) - 0.5) * 0.2 # small jitter
  289. plt.scatter(x, y, alpha=0.35, s=10, color="lightgray")
  290. plt.tight_layout()
  291. plt.savefig(plots_dir / "sensitivity_accuracy_vs_ensemble_size.png")
  292. # Plot mean F1 vs ensemble size with error bars (std)
  293. plt.figure(figsize=(10, 6))
  294. plt.errorbar(
  295. ensemble_sizes,
  296. mean_f1s,
  297. yerr=std_f1s,
  298. fmt="-o",
  299. capsize=3,
  300. color="tab:orange",
  301. ecolor="tab:orange",
  302. )
  303. plt.title("Sensitivity Analysis: F1 Score vs Ensemble Size")
  304. plt.xlabel("Number of Models in Ensemble")
  305. plt.ylabel("F1 Score")
  306. plt.grid(True)
  307. plt.xticks(ticks)
  308. # Optionally overlay raw sample distributions as jittered points
  309. for i, k in enumerate(ensemble_sizes):
  310. y = all_f1s[k]
  311. x = np.full(len(y), k) + (rng.random(len(y)) - 0.5) * 0.2 # small jitter
  312. plt.scatter(x, y, alpha=0.35, s=10, color="lightgray")
  313. plt.tight_layout()
  314. plt.savefig(plots_dir / "sensitivity_f1_vs_ensemble_size.png")
  315. # Plot mean confidence accuracy score vs ensemble size with error bars (std)
  316. plt.figure(figsize=(10, 6))
  317. plt.errorbar(
  318. ensemble_sizes,
  319. mean_confidence_accuracy_scores,
  320. yerr=std_confidence_accuracy_scores,
  321. fmt="-o",
  322. capsize=3,
  323. color="tab:purple",
  324. ecolor="tab:purple",
  325. )
  326. plt.title("Sensitivity Analysis: Confidence Accuracy Score vs Ensemble Size")
  327. plt.xlabel("Number of Models in Ensemble")
  328. plt.ylabel("Confidence Accuracy Score (per image)")
  329. plt.grid(True)
  330. plt.xticks(ticks)
  331. # Optionally overlay raw sample distributions as jittered points
  332. for k in ensemble_sizes:
  333. y = confidence_accuracy_scores[k]
  334. x = np.full(len(y), k) + (rng.random(len(y)) - 0.5) * 0.2 # small jitter
  335. plt.scatter(x, y, alpha=0.35, s=10, color="lightgray")
  336. plt.tight_layout()
  337. plt.savefig(plots_dir / "sensitivity_confidence_accuracy_vs_ensemble_size.png")
  338. # Plot mean predicted-class confidence for correct vs incorrect predictions
  339. plt.figure(figsize=(10, 6))
  340. plt.errorbar(
  341. ensemble_sizes,
  342. mean_correct_prediction_confidences,
  343. yerr=std_correct_prediction_confidences,
  344. fmt="-o",
  345. capsize=3,
  346. color="tab:green",
  347. label="Correct predictions",
  348. )
  349. plt.errorbar(
  350. ensemble_sizes,
  351. mean_incorrect_prediction_confidences,
  352. yerr=std_incorrect_prediction_confidences,
  353. fmt="-o",
  354. capsize=3,
  355. color="tab:red",
  356. label="Incorrect predictions",
  357. )
  358. plt.title("Sensitivity Analysis: Avg Predicted-Class Confidence (Correct vs Incorrect)")
  359. plt.xlabel("Number of Models in Ensemble")
  360. plt.ylabel("Average Predicted-Class Confidence")
  361. plt.grid(True)
  362. plt.xticks(ticks)
  363. # Overlay raw sample distributions as jittered points
  364. for k in ensemble_sizes:
  365. x = np.full(samples_per_size if k != num_models else 1, k) + (
  366. (rng.random(samples_per_size if k != num_models else 1) - 0.5) * 0.2
  367. )
  368. y_correct = np.asarray(correct_prediction_confidences[k], dtype=float)
  369. y_incorrect = np.asarray(incorrect_prediction_confidences[k], dtype=float)
  370. if y_correct.size:
  371. plt.scatter(
  372. x[: y_correct.size],
  373. y_correct,
  374. alpha=0.45,
  375. s=14,
  376. color="lightgray",
  377. marker="o",
  378. )
  379. if y_incorrect.size:
  380. plt.scatter(
  381. x[: y_incorrect.size],
  382. y_incorrect,
  383. alpha=0.45,
  384. s=18,
  385. color="dimgray",
  386. marker="x",
  387. )
  388. plt.legend()
  389. plt.tight_layout()
  390. plt.savefig(plots_dir / "sensitivity_avg_confidence_correct_vs_incorrect.png")
  391. # Plot mean confidence delta per model (correct - incorrect) vs ensemble size
  392. plt.figure(figsize=(10, 6))
  393. plt.errorbar(
  394. ensemble_sizes,
  395. mean_confidence_delta_per_model,
  396. yerr=std_confidence_delta_per_model,
  397. fmt="-o",
  398. capsize=3,
  399. color="tab:blue",
  400. )
  401. plt.title("Sensitivity Analysis: Confidence Delta Per Model (Correct - Incorrect)")
  402. plt.xlabel("Number of Models in Ensemble")
  403. plt.ylabel("Avg Delta Per Model")
  404. plt.grid(True)
  405. plt.xticks(ticks)
  406. for k in ensemble_sizes:
  407. y = confidence_delta_per_model[k]
  408. x = np.full(len(y), k) + (rng.random(len(y)) - 0.5) * 0.2
  409. plt.scatter(x, y, alpha=0.35, s=12, color="lightgray")
  410. plt.tight_layout()
  411. plt.savefig(plots_dir / "sensitivity_confidence_delta_per_model.png")
  412. # End of Copilot section