generate_statistics.py

import pathlib as pl

import numpy as np
import pandas as pd
import xarray as xr

from utils.config import config
# Load the evaluation results
ds = xr.open_dataset(pl.Path(config["output"]["path"]) / "model_evaluation_results.nc")  # type: ignore
# The dataset contains two DataArrays: 'predictions' and 'labels'.
# First analysis: average the predictions across all models for each image,
# compute the accuracy of the averaged predictions against the true labels,
# and graph accuracy vs. confidence threshold.
predictions: xr.DataArray = ds["predictions"]
labels: xr.DataArray = ds["labels"]
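# Assumed layout, inferred from the .sel calls below rather than verified
# against the file: 'predictions' has dims (model, img_id, img_class) holding
# per-class confidence scores; 'labels' has dims (img_id, label) holding
# one-hot ground truth.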
# Average predictions across models
avg_predictions = predictions.mean(dim="model")

# Sweep confidence thresholds and calculate accuracy at each one
thresholds = np.linspace(0.5, 1.0, num=10)  # 10 evenly spaced thresholds from 0.5 to 1.0
accuracies = []
for threshold in thresholds:
    # Pick the positive class for the labels and predictions
    predicted_positive = avg_predictions.sel(img_class=1) >= threshold
    true_positive = labels.sel(label=1) == 1
    # Accuracy: fraction of images where the thresholded prediction agrees with the label
    correct_predictions = (predicted_positive == true_positive).sum().item()
    total_predictions = len(avg_predictions.img_id)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    accuracies.append(accuracy)
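
# An equivalent vectorized sweep (a sketch, not part of the original logic):
# broadcasting the thresholds against the averaged positive-class scores gives
# an (img_id, threshold) boolean array whose mean over images is the accuracy
# at each threshold.
thr = xr.DataArray(thresholds, dims="threshold")
accuracies_vec = (
    (avg_predictions.sel(img_class=1) >= thr) == (labels.sel(label=1) == 1)
).mean(dim="img_id")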

# Print the accuracy at each threshold
for threshold, accuracy in zip(thresholds, accuracies):
    print(f"Threshold: {threshold:.2f}, Accuracy: {accuracy:.4f}")
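
# The analysis goal above calls for graphing accuracy vs. confidence threshold.
# A minimal plotting sketch, assuming matplotlib is available (it is not
# imported anywhere else in this script) and using a hypothetical output
# filename:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(thresholds, accuracies, marker="o")
ax.set_xlabel("Confidence threshold")
ax.set_ylabel("Accuracy")
ax.set_title("Accuracy of model-averaged predictions vs. confidence threshold")
fig.savefig(pl.Path(config["output"]["path"]) / "accuracy_vs_threshold.png")  # hypothetical filename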