generate_statistics.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import xarray as xr
  2. from utils.config import config
  3. import pathlib as pl
  4. import numpy as np
  5. import os
  6. # Load the evaluation results
  7. os.chdir(pl.Path(__file__).parent)
  8. model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
  9. config["analysis"]["evaluation_name"].strip()
  10. ).with_suffix(".nc")
  11. print(f"Loading evaluation results from {model_dataset_path}")
  12. array = xr.open_dataset(model_dataset_path) # type: ignore
  13. # This dataset includes two dataarrays: 'predictions' and 'labels'
  14. # For the first analysis, the goal is to average the predictions across all models for each image, then to determine the accuracy of these averaged predictions against the true labels, graphing accurac vs confidence threshold.
  15. predictions: xr.DataArray = array["predictions"]
  16. labels: xr.DataArray = array["labels"]
  17. # Average predictions across models
  18. avg_predictions = predictions.mean(dim="model")
  19. # Loop through different confidence thresholds and calculate accuracy
  20. thresholds = np.linspace(0.5, 1.0, num=10) # From 0.5 to 1.0
  21. accuracies: list[float] = []
  22. for i, threshold in enumerate(thresholds):
  23. # pick the positive class for the lables and predictions
  24. predicted_positive = avg_predictions.sel(img_class=1) >= threshold
  25. true_positive = labels.sel(label=1) == 1
  26. # Calculate accuracy
  27. correct_predictions = (predicted_positive == true_positive).sum().item()
  28. # For debugging, print list of predictions, labels and correctness
  29. total_predictions = len(avg_predictions.img_id)
  30. accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
  31. accuracies.append(accuracy)
  32. # Print the accuracies for each threshold
  33. for threshold, accuracy in zip(thresholds, accuracies):
  34. print(f"Threshold: {threshold:.2f}, Accuracy: {accuracy:.4f}")