"""Calculate and plot accuracy vs. minimum confidence percentile threshold."""

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from utils.config import config
import pathlib as pl
import os

# Load the evaluation results
os.chdir(pl.Path(__file__).parent)
model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
    config["analysis"]["evaluation_name"].strip()
).with_suffix(".nc")
print(f"Loading evaluation results from {model_dataset_path}")
array = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]

# Make plots directory if it doesn't exist
plots_dir = (
    pl.Path("../output") / pl.Path(config["analysis"]["evaluation_name"]) / "plots"
)
plots_dir.mkdir(parents=True, exist_ok=True)

# Average predictions across models
avg_predictions = predictions.mean(dim="model")

# Get confidence scores for the positive class
confidence_scores = avg_predictions.sel(img_class=1).values
true_labels = labels.sel(label=1).values

# Absolute confidence: distance of the score from 0.5, rescaled to [0, 1].
# It does not depend on the percentile, so compute it once outside the loop.
absolute_confidences = 2 * np.abs(confidence_scores - 0.5)

# Calculate accuracy at different confidence percentile thresholds
percentiles = np.linspace(0, 100, num=21)
accuracies: list[float] = []
for p in percentiles:
    threshold = np.percentile(absolute_confidences, p)
    # Keep only predictions whose absolute confidence is at or above the threshold
    selected_indices = np.where(absolute_confidences >= threshold)[0]
    if len(selected_indices) == 0:
        accuracies.append(0.0)
        continue
    selected_confidences = confidence_scores[selected_indices]
    selected_true_labels = true_labels[selected_indices]
    predicted_positive = selected_confidences >= 0.5
    true_positive = selected_true_labels == 1
    # A prediction is correct when the predicted class matches the true class
    correct_predictions = (predicted_positive == true_positive).sum().item()
    total_predictions = len(selected_confidences)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    accuracies.append(accuracy)

# Plot accuracy vs confidence percentile threshold
plt.figure(figsize=(10, 6))
plt.plot(percentiles, accuracies, marker="o")
plt.title("Accuracy vs Confidence Percentile Threshold")
plt.xlabel("Confidence Percentile Threshold")
plt.ylabel("Accuracy")
plt.grid()
plt.xticks(percentiles)
plt.savefig(plots_dir / "accuracy_vs_confidence_percentile_threshold.png")
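

# ---------------------------------------------------------------------------
# Hedged sketch (illustrative only, not part of the evaluation pipeline):
# a minimal, self-contained demo of the percentile-thresholding logic above,
# run on synthetic scores rather than the real dataset. All names here
# (_demo_percentile_accuracy, demo_scores, demo_labels) are hypothetical,
# and the function is defined but never called by this script.
# ---------------------------------------------------------------------------
def _demo_percentile_accuracy(n: int = 1000) -> None:
    rng = np.random.default_rng(0)
    # Synthetic binary labels and confidence scores loosely correlated with them
    demo_labels = rng.integers(0, 2, size=n)
    demo_scores = np.clip(
        0.5 + (demo_labels - 0.5) * 0.4 + rng.normal(0.0, 0.2, size=n), 0.0, 1.0
    )
    abs_conf = 2 * np.abs(demo_scores - 0.5)
    for p in (0, 50, 90):
        # Keep only samples at or above the p-th percentile of absolute confidence
        keep = abs_conf >= np.percentile(abs_conf, p)
        acc = ((demo_scores[keep] >= 0.5) == (demo_labels[keep] == 1)).mean()
        print(f"percentile {p:>3}: kept {int(keep.sum()):4d} samples, accuracy {acc:.3f}")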