# sensitivity_analysis.py
# The purpose of this file is to perform a sensitivity analysis on the model evaluation results and graph the findings.
# The sensitivity analysis will be done by varying the number of models used in the ensemble and observing the effect on overall accuracy.
# We will take 50 different random arrangements of models for each ensemble size (except the full set, which is evaluated once deterministically) to get a distribution of accuracies for each ensemble size.
# Each ensemble size will have associated error bars based on the standard deviation of its accuracies.
  5. import xarray as xr
  6. from utils.config import config
  7. import pathlib as pl
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import os
  11. # Load the evaluation results
  12. os.chdir(pl.Path(__file__).parent)
  13. model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
  14. config["analysis"]["evaluation_name"].strip()
  15. ).with_suffix(".nc")
  16. print(f"Loading evaluation results from {model_dataset_path}")
  17. array = xr.open_dataset(model_dataset_path) # type: ignore
  18. # This section was generated by Github Copilot - 2025-11-04
  19. # Perform sensitivity analysis by varying ensemble size and sampling subsets of models.
  20. predictions: xr.DataArray = array["predictions"]
  21. labels: xr.DataArray = array["labels"]
  22. # Make plots directory if it doesn't exist (matching other scripts)
  23. plots_dir = (
  24. pl.Path("../output") / pl.Path(config["analysis"]["evaluation_name"]) / "plots"
  25. )
  26. plots_dir.mkdir(parents=True, exist_ok=True)
  27. # Configuration for the sensitivity analysis
  28. rng = np.random.default_rng(42)
  29. num_models = int(predictions.sizes["model"])
  30. ensemble_sizes = list(range(1, num_models + 1))
  31. samples_per_size = 50
  32. # Extract true labels for the positive class (assumes same structure as other script)
  33. true_labels = labels.sel(label=1).values # shape: (n_samples,)
  34. # Container for results
  35. mean_accuracies: list[float] = []
  36. std_accuracies: list[float] = []
  37. all_accuracies: dict[int, list[float]] = {k: [] for k in ensemble_sizes}
  38. for k in ensemble_sizes:
  39. accuracies_k = []
  40. # If using the full set, evaluate once deterministically
  41. if k == num_models:
  42. selected_idx = np.arange(num_models)
  43. preds_selected = predictions.isel(model=selected_idx).mean(dim="model")
  44. confs = preds_selected.sel(img_class=1).values
  45. predicted_positive = confs >= 0.5
  46. true_positive = true_labels == 1
  47. acc = (predicted_positive == true_positive).sum().item() / len(confs)
  48. accuracies_k.append(acc)
  49. else:
  50. for _ in range(samples_per_size):
  51. selected_idx = rng.choice(num_models, size=k, replace=False)
  52. preds_selected = predictions.isel(model=selected_idx).mean(dim="model")
  53. confs = preds_selected.sel(img_class=1).values
  54. predicted_positive = confs >= 0.5
  55. true_positive = true_labels == 1
  56. acc = (predicted_positive == true_positive).sum().item() / len(confs)
  57. accuracies_k.append(acc)
  58. all_accuracies[k] = accuracies_k
  59. mean_accuracies.append(float(np.mean(accuracies_k)))
  60. std_accuracies.append(float(np.std(accuracies_k, ddof=0)))
  61. # Plot mean accuracy vs ensemble size with error bars (std)
  62. plt.figure(figsize=(10, 6))
  63. plt.errorbar(ensemble_sizes, mean_accuracies, yerr=std_accuracies, fmt="-o", capsize=3)
  64. plt.title("Sensitivity Analysis: Accuracy vs Ensemble Size")
  65. plt.xlabel("Number of Models in Ensemble")
  66. plt.ylabel("Accuracy")
  67. plt.grid(True)
  68. # Set x-ticks every 5 models (and always include the final model count)
  69. ticks = list(range(1, num_models + 1, 5))
  70. if len(ticks) == 0 or ticks[-1] != num_models:
  71. ticks.append(num_models)
  72. plt.xticks(ticks)
  73. # Optionally overlay raw sample distributions as jittered points
  74. for i, k in enumerate(ensemble_sizes):
  75. y = all_accuracies[k]
  76. x = np.full(len(y), k) + (rng.random(len(y)) - 0.5) * 0.2 # small jitter
  77. plt.scatter(x, y, alpha=0.3, s=8, color="gray")
  78. plt.tight_layout()
  79. plt.savefig(plots_dir / "sensitivity_accuracy_vs_ensemble_size.png")
  80. # End of Copilot section