# This module contains functions to ensemble a folder of models and evaluate
# them on a test set, with uncertainty estimation included.

import pathlib
from typing import List, Tuple

import torch
import xarray as xr

import utils.models.cnn as c

type ModelPair = Tuple[c.CNN, str]
type ModelPredictionData = xr.DataArray
type InputData = Tuple[torch.Tensor, torch.Tensor]

def load_models(folder: pathlib.Path, device: str) -> List[ModelPair]:
    model_files = folder.glob("*.pt")
    model_pairs: List[ModelPair] = []
    for model_file in model_files:
        # weights_only=False loads the full pickled CNN object, not just a state dict
        model: c.CNN = torch.load(model_file, map_location=device, weights_only=False)
        # Use the filename stem as the model description
        model_pairs.append((model, model_file.stem))
    return model_pairs

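# Example (sketch): loading an ensemble from disk. The folder name and device
# below are illustrative assumptions, not values used elsewhere in this module.
#
#     pairs = load_models(pathlib.Path("ensemble_models"), "cpu")
#     # -> [(CNN, "model_1"), (CNN, "model_2"), ...] keyed by filename stem
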
def prepare_datasets(data: Tuple[torch.Tensor, torch.Tensor]) -> InputData:
    # Unpack the MRI and XLS tensors and add a batch dimension to each
    # (Tensor.unsqueeze is not in-place, so the results must be reassigned)
    mri_data, xls_data = data
    mri_data = mri_data.unsqueeze(0)
    xls_data = xls_data.unsqueeze(0)
    # Combine MRI and XLS data into a tuple
    return (mri_data, xls_data)

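# Example (sketch; mri_tensor and xls_tensor are hypothetical placeholders):
#
#     inputs = prepare_datasets((mri_tensor, xls_tensor))
#     # each tensor gains a leading batch dimension of size 1
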
def get_model_names(models: List[ModelPair]) -> List[str]:
    # Extract model names from the model pairs
    return [model_pair[1] for model_pair in models]


def get_model_objects(models: List[ModelPair]) -> List[c.CNN]:
    # Extract model objects from the model pairs
    return [model_pair[0] for model_pair in models]

def ensemble_predict(models: List[c.CNN], input: InputData) -> Tuple[torch.Tensor, torch.Tensor]:
    predictions = []
    for model in models:
        model.eval()
        with torch.no_grad():
            # Apply model and extract the positive-class prediction
            output = model(input)[:, 1]
            predictions.append(output)
    # Mean and variance across the ensemble: the mean is the ensemble
    # prediction, the variance serves as an uncertainty estimate
    predictions = torch.stack(predictions)
    mean = predictions.mean()
    variance = predictions.var()
    return mean, variance

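# Example (sketch; "pairs" and "inputs" are assumed to come from load_models
# and prepare_datasets respectively):
#
#     mean, variance = ensemble_predict(get_model_objects(pairs), inputs)
#     # mean: average positive-class output across the ensemble
#     # variance: ensemble disagreement, usable as an uncertainty estimate
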
def ensemble_predict_strict_classes(models: List[c.CNN], input: InputData) -> Tuple[float, int, int]:
    predictions = []
    for model in models:
        model.eval()
        with torch.no_grad():
            # Apply model and take the argmax as the predicted class
            output = model(input)
            _, predicted = torch.max(output, 1)
            predictions.append(predicted.item())
    # Majority vote: fraction of positive votes plus the raw vote counts
    pos_votes = len([p for p in predictions if p == 1])
    neg_votes = len([p for p in predictions if p == 0])
    return pos_votes / len(models), pos_votes, neg_votes

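# Example (sketch, with the same assumed "pairs" and "inputs" as above):
#
#     pos_frac, pos_votes, neg_votes = ensemble_predict_strict_classes(
#         get_model_objects(pairs), inputs
#     )
#     # pos_frac is the fraction of models voting for the positive class
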
# Prune the ensemble by removing models with low accuracy on the test set,
# as recorded in their <desc>_test_acc.txt files
def prune_models(
    models: List[c.CNN], model_descs: List[str], folder: pathlib.Path, threshold: float
) -> Tuple[List[c.CNN], List[str]]:
    new_models = []
    new_descs = []
    for model, desc in zip(models, model_descs):
        # Each model's test accuracy is stored next to it as a plain-text file
        with open(folder / f"{desc}_test_acc.txt", "r") as f:
            acc = float(f.read())
        if acc >= threshold:
            new_models.append(model)
            new_descs.append(desc)
    return new_models, new_descs

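# Minimal end-to-end sketch of how these helpers compose. The folder name,
# device, accuracy threshold, and placeholder input tensors below are
# illustrative assumptions, not values defined by this module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_folder = pathlib.Path("ensemble_models")  # assumed: *.pt files + *_test_acc.txt files

    pairs = load_models(model_folder, device)
    cnns = get_model_objects(pairs)
    descs = get_model_names(pairs)

    # Drop models below 70% test accuracy (threshold is an example value)
    cnns, descs = prune_models(cnns, descs, model_folder, threshold=0.7)

    # In the real pipeline these tensors would come from the MRI/XLS data loader;
    # the shapes here are placeholders only
    mri_tensor = torch.randn(91, 109, 91).to(device)
    xls_tensor = torch.randn(10).to(device)
    inputs = prepare_datasets((mri_tensor, xls_tensor))

    mean, variance = ensemble_predict(cnns, inputs)
    print(f"ensemble positive-class output: {mean.item():.3f} (variance {variance.item():.3f})")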