# confidence_percentile.py
  1. import xarray as xr
  2. import numpy as np
  3. import sys
  4. import os
  5. import matplotlib.pyplot as plt
  6. sys.path.append(
  7. os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
  8. ) # to allow imports from parent directory
  9. from utils.config import config
  10. import pathlib as pl
  11. import colorama as clr
  12. model_dataset_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
  13. array = xr.open_dataset(model_dataset_path) # type: ignore
  14. predictions: xr.DataArray = array["predictions"]
  15. labels: xr.DataArray = array["labels"]
  16. # Make plots directory if it doesn't exist
  17. plots_dir = pl.Path(config["output"]["path"]) / "plots"
  18. plots_dir.mkdir(parents=True, exist_ok=True)
  19. # This script calculates and plots accuracy vs minimum confidence percentile threshold
  20. # Average predictions across models
  21. avg_predictions = predictions.mean(dim="model")
  22. # Get confidence scores for the positive class
  23. confidence_scores = avg_predictions.sel(img_class=1).values
  24. true_labels = labels.sel(label=1).values
  25. # Calculate accuracy at different confidence percentiles
  26. percentiles = np.linspace(0, 100, num=21)
  27. accuracies = []
  28. for p in percentiles:
  29. absolute_confidences = 2 * np.abs(confidence_scores - 0.5)
  30. threshold = np.percentile(absolute_confidences, p)
  31. # Filter the predictions such that only those with absolute confidence above the threshold are considered
  32. selected_indices = np.where(absolute_confidences >= threshold)[0]
  33. if len(selected_indices) == 0:
  34. accuracies.append(0.0)
  35. continue
  36. selected_confidences = confidence_scores[selected_indices]
  37. selected_true_labels = true_labels[selected_indices]
  38. predicted_positive = selected_confidences >= 0.5
  39. true_positive = selected_true_labels == 1
  40. correct_predictions = (predicted_positive == true_positive).sum().item()
  41. total_predictions = len(selected_confidences)
  42. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
  43. accuracies.append(accuracy)
  44. # Plot accuracy vs confidence percentile threshold
  45. plt.figure(figsize=(10, 6))
  46. plt.plot(percentiles, accuracies, marker="o")
  47. plt.title("Accuracy vs Confidence Percentile Threshold")
  48. plt.xlabel("Confidence Percentile Threshold")
  49. plt.ylabel("Accuracy")
  50. plt.grid()
  51. plt.xticks(percentiles)
  52. plt.savefig(
  53. pl.Path(config["output"]["path"])
  54. / "plots"
  55. / "accuracy_vs_confidence_percentile.png"
  56. )