import torch import pathlib import utils.models.cnn as c from typing import Tuple, List import xarray as xr type ModelPair = Tuple[c.CNN, str] type ModelPredictionData = xr.DataArray type InputData = Tuple[torch.Tensor, torch.Tensor] # This file contains functions to ensemble a folder of models and evaluate them on a test set, with included uncertainty estimation. def load_models(folder: pathlib.Path, device: str) -> List[ModelPair]: model_files = folder.glob("*.pt") model_pairs: List[ModelPair] = [] for model_file in model_files: model: c.CNN = torch.load(model_file, map_location=device, weights_only=False) # Extract model description from filename model_pairs.append((model, model_file.stem)) return model_pairs def prepare_datasets(data: Tuple[torch.Tensor, torch.Tensor]) -> InputData: # Ensure the data is in the correct format mri_data.unsqueeze(0) xls_data.unsqueeze(0) # Combine MRI and XLS data into a tuple return (mri_data, xls_data) def get_model_names(models: List[ModelPair]) -> List[str]: # Extract model names from the model pairs return [model_pair[1] for model_pair in models] def get_model_objects(models: List[ModelPair]) -> List[c.CNN]: # Extract model objects from the model pairs return [model_pair[0] for model_pair in models] def ensemble_predict(models: List[c.CNN], input: InputData): predictions = [] for model in models: model.eval() with torch.no_grad(): # Apply model and extract positive class predictions output = model(input)[:, 1] predictions.append(output) # Calculate mean and variance of predictions predictions = torch.stack(predictions) mean = predictions.mean() variance = predictions.var() return mean, variance def ensemble_predict_strict_classes(models, input): predictions = [] for model in models: model.eval() with torch.no_grad(): # Apply model and extract prediction output = model(input) _, predicted = torch.max(output.data, 1) predictions.append(predicted.item()) pos_votes = len([p for p in predictions if p == 1]) neg_votes = len([p for p in predictions if p == 0]) return pos_votes / len(models), pos_votes, neg_votes # Prune the ensemble by removing models with low accuracy on the test set, as determined in their tes_acc.txt files def prune_models(models, model_descs, folder, threshold): new_models = [] new_descs = [] for model, desc in zip(models, model_descs): with open(folder + desc + "_test_acc.txt", "r") as f: acc = float(f.read()) if acc >= threshold: new_models.append(model) new_descs.append(desc) return new_models, new_descs