import os
import pathlib as pl

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from utils.config import config

# Load the evaluation results
os.chdir(pl.Path(__file__).parent)
model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
    config["analysis"]["evaluation_name"].strip()
).with_suffix(".nc")
print(f"Loading evaluation results from {model_dataset_path}")
dataset = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = dataset["predictions"]
labels: xr.DataArray = dataset["labels"]
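# Assumed dataset layout (inferred from the .sel(...) calls below, not verified
# here): "predictions" has a "model" dimension and an "img_class" coordinate
# holding per-class probabilities, and "labels" has a "label" coordinate with
# one-hot targets. Adjust the selections if your evaluation file differs.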

# Make the plots directory if it doesn't exist
plots_dir = (
    pl.Path("../output")
    / pl.Path(config["analysis"]["evaluation_name"].strip())
    / "plots"
)
plots_dir.mkdir(parents=True, exist_ok=True)

# This script calculates and plots accuracy against a minimum confidence
# percentile threshold: as the threshold rises, only increasingly confident
# predictions are kept.

# Average the predictions across the model ensemble
avg_predictions = predictions.mean(dim="model")

# Get confidence scores for the positive class and the matching binary labels
confidence_scores = avg_predictions.sel(img_class=1).values
true_labels = labels.sel(label=1).values
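# Both arrays are assumed to be 1-D and aligned on the same sample axis; this
# cheap guard surfaces accidental mismatches before the analysis runs.
assert confidence_scores.shape == true_labels.shape, "predictions/labels misaligned"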

# Calculate accuracy at different confidence percentiles
percentiles = np.linspace(0, 100, num=21)
accuracies: list[float] = []

# Map each score in [0, 1] to its distance from the 0.5 decision boundary,
# rescaled to [0, 1]: 0 is maximally uncertain, 1 is maximally confident.
absolute_confidences = 2 * np.abs(confidence_scores - 0.5)

for p in percentiles:
    threshold = np.percentile(absolute_confidences, p)
    # Keep only the predictions whose absolute confidence meets the threshold
    selected_indices = np.where(absolute_confidences >= threshold)[0]
    if len(selected_indices) == 0:
        accuracies.append(0.0)
        continue
    selected_confidences = confidence_scores[selected_indices]
    selected_true_labels = true_labels[selected_indices]
    # A prediction is correct when its thresholded class matches the label
    predicted_positive = selected_confidences >= 0.5
    actual_positive = selected_true_labels == 1
    correct_predictions = (predicted_positive == actual_positive).sum().item()
    total_predictions = len(selected_confidences)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    accuracies.append(accuracy)
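
# Optional sanity check (prints only; does not affect the plot): the first
# entry keeps every sample (0th percentile) and the last keeps only the most
# confident ones, so the two ends of the curve are easy to eyeball here.
print(f"Accuracy over all samples: {accuracies[0]:.3f}")
print(f"Accuracy over the most confident samples: {accuracies[-1]:.3f}")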

# Plot accuracy vs confidence percentile threshold
plt.figure(figsize=(10, 6))
plt.plot(percentiles, accuracies, marker="o")
plt.title("Accuracy vs Confidence Percentile Threshold")
plt.xlabel("Confidence Percentile Threshold")
plt.ylabel("Accuracy")
plt.grid()
plt.xticks(percentiles)
plt.savefig(plots_dir / "accuracy_vs_confidence_percentile_threshold.png")
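# Close the figure so repeated runs in one interpreter session don't pile up
# open figures, and echo the output location (this mirrors the load message).
plt.close()
print(f"Saved plot to {plots_dir / 'accuracy_vs_confidence_percentile_threshold.png'}")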