confidence_percentile.py

# This script calculates and plots accuracy vs minimum confidence percentile threshold
import os
import pathlib as pl

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from utils.config import config
# Load the evaluation results
os.chdir(pl.Path(__file__).parent)
model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
    config["analysis"]["evaluation_name"].strip()
).with_suffix(".nc")
print(f"Loading evaluation results from {model_dataset_path}")
dataset = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = dataset["predictions"]
labels: xr.DataArray = dataset["labels"]
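# Expected layout (inferred from the selections below, not asserted here):
# "predictions" carries at least the dims ("model", "img_class", ...) and
# "labels" carries a "label" dim.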
# Make the plots directory if it doesn't exist
plots_dir = (
    pl.Path("../output") / pl.Path(config["analysis"]["evaluation_name"]) / "plots"
)
plots_dir.mkdir(parents=True, exist_ok=True)
# Average predictions across models
avg_predictions = predictions.mean(dim="model")
# Get confidence scores for the positive class and the matching true labels
confidence_scores = avg_predictions.sel(img_class=1).values
true_labels = labels.sel(label=1).values
# Calculate accuracy at different confidence percentile thresholds
# Absolute confidence: distance of the averaged score from the 0.5 decision
# boundary, rescaled to [0, 1]. It is loop-invariant, so compute it once.
absolute_confidences = 2 * np.abs(confidence_scores - 0.5)
percentiles = np.linspace(0, 100, num=21)
accuracies: list[float] = []
for p in percentiles:
    threshold = np.percentile(absolute_confidences, p)
    # Keep only the predictions whose absolute confidence is at or above the threshold
    selected_indices = np.where(absolute_confidences >= threshold)[0]
    if len(selected_indices) == 0:
        accuracies.append(0.0)
        continue
    selected_confidences = confidence_scores[selected_indices]
    selected_true_labels = true_labels[selected_indices]
    predicted_positive = selected_confidences >= 0.5
    true_positive = selected_true_labels == 1
    correct_predictions = (predicted_positive == true_positive).sum().item()
    # The empty-selection case was handled above, so the division is safe
    accuracies.append(correct_predictions / len(selected_confidences))
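# A minimal worked sketch of the thresholding step on toy data, using the same
# 2 * |p - 0.5| transform as above (illustration only, not executed):
#
#   scores = np.array([0.1, 0.45, 0.6, 0.95])   # averaged P(class 1)
#   conf = 2 * np.abs(scores - 0.5)              # -> [0.8, 0.1, 0.2, 0.9]
#   thr = np.percentile(conf, 50)                # median confidence = 0.5
#   kept = np.where(conf >= thr)[0]              # -> indices [0, 3] survive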
# Plot accuracy vs confidence percentile threshold
plt.figure(figsize=(10, 6))
plt.plot(percentiles, accuracies, marker="o")
plt.title("Accuracy vs Confidence Percentile Threshold")
plt.xlabel("Confidence Percentile Threshold")
plt.ylabel("Accuracy")
plt.grid()
plt.xticks(percentiles)
plt.savefig(plots_dir / "accuracy_vs_confidence_percentile_threshold.png")