"""Analyze saved model evaluations: rank the averaged (ensemble) predictions by
confidence, report per-model and top-K ensemble accuracies, and list the most
confidently incorrect predictions."""

import os
import pathlib as pl

import colorama as clr
import numpy as np
import xarray as xr

from utils.config import config

os.chdir(pl.Path(__file__).parent)

model_dataset_path = pl.Path("../model_evaluations") / pl.Path(
    config["analysis"]["evaluation_name"].strip()
).with_suffix(".nc")

array = xr.open_dataset(model_dataset_path)  # type: ignore
predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]

# Average predictions across models
avg_predictions = predictions.mean(dim="model")

# Sort from highest to lowest confidence for the positive class (img_class=1)
sorted_indices = np.argsort(-avg_predictions.sel(img_class=1).values)
sorted_avg_predictions = avg_predictions.isel(img_id=sorted_indices)
sorted_labels = labels.isel(img_id=sorted_indices)

# Print out all predictions with their labels
top_n = sorted_avg_predictions.sizes[
    "img_id"
]  # Change this value to print more or fewer
print(
    clr.Fore.CYAN
    + f"Top {top_n} Predictions (Confidence for Positive Class):"
    + clr.Style.RESET_ALL
)
for i in range(top_n):
    confidence = sorted_avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = sorted_labels.isel(img_id=i, label=1).item()
    correctness = (
        "CORRECT"
        if (confidence >= 0.5 and label == 1) or (confidence < 0.5 and label == 0)
        else "INCORRECT"
    )
    color = clr.Fore.GREEN if correctness == "CORRECT" else clr.Fore.RED
    print(
        f"Image ID: {sorted_avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Label: {label:<3}, "
        + color
        + f"{correctness:<9}"
        + clr.Style.RESET_ALL
    )

# Calculate overall accuracy of the averaged predictions
predicted_positive = avg_predictions.sel(img_class=1) >= 0.5
true_positive = labels.sel(label=1) == 1
correct_predictions = (predicted_positive == true_positive).sum().item()
total_predictions = len(avg_predictions.img_id)
overall_accuracy = (
    correct_predictions / total_predictions if total_predictions > 0 else 0.0
)
print(
    clr.Fore.MAGENTA
    + f"\nOverall Accuracy (Threshold 0.5): {overall_accuracy:.4f}"
    + clr.Style.RESET_ALL
)

# Go through all individual models and print their accuracies for comparison,
# sorted from highest to lowest
model_accuracies = []
for model_idx in predictions.coords["model"].values:
    model_preds = predictions.sel(model=model_idx)
    predicted_positive = model_preds.sel(img_class=1) >= 0.5
    correct_predictions = (predicted_positive == true_positive).sum().item()
    accuracy = (
        correct_predictions / total_predictions if total_predictions > 0 else 0.0
    )
    model_accuracies.append((model_idx, accuracy))

# Sort by accuracy, best model first
model_accuracies.sort(key=lambda x: x[1], reverse=True)
print(
    clr.Fore.CYAN
    + "\nIndividual Model Accuracies (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for model_idx, accuracy in model_accuracies:
    print(f"Model {int(model_idx):<3}: Accuracy: {accuracy:.4f}")

# Calculate the accuracy of ensembling the top K models, for K = 1 up to the
# total number of models
total_models = len(predictions.coords["model"].values)
ensemble_accuracies = []
for k in range(1, total_models + 1):
    top_k_models = [ma[0] for ma in model_accuracies[:k]]
    ensemble_preds = predictions.sel(model=top_k_models).mean(dim="model")
    predicted_positive = ensemble_preds.sel(img_class=1) >= 0.5
    correct_predictions = (predicted_positive == true_positive).sum().item()
    accuracy = (
        correct_predictions / total_predictions if total_predictions > 0 else 0.0
    )
    ensemble_accuracies.append((k, accuracy))

print(
    clr.Fore.CYAN
    + "\nEnsemble Accuracies for Top K Models (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for k, accuracy in ensemble_accuracies:
    print(f"Top {k:<3} Models: Ensemble Accuracy: {accuracy:.4f}")

# Finally, identify the top 5 most confidently incorrect predictions
incorrect_predictions = []
for i in range(len(avg_predictions.img_id)):
    confidence = avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = labels.isel(img_id=i, label=1).item()
    predicted_label = 1 if confidence >= 0.5 else 0
    if predicted_label != label:
        incorrect_predictions.append((i, confidence, label))

# Sort by distance from the 0.5 threshold, i.e. most confidently wrong first
incorrect_predictions.sort(key=lambda x: -abs(x[1] - 0.5))
top_incorrect = incorrect_predictions[:5]
print(
    clr.Fore.YELLOW
    + "\nTop 5 Most Confident Incorrect Predictions:"
    + clr.Style.RESET_ALL
)
for i, confidence, label in top_incorrect:
    predicted_label = 1 if confidence >= 0.5 else 0
    print(
        f"Image ID: {avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Predicted Label: {predicted_label:<3}, "
        f"True Label: {label:<3}"
    )