import utils.ensemble as ens import os import tomli as tml from utils.system import force_init_cudnn import torch import pathlib as pl from utils.data.datasets import ADNIDataset import xarray as xr # CONFIGURATION with open(os.getenv("ADL_CONFIG_PATH", "config.toml"), "rb") as f: config = tml.load(f) force_init_cudnn(config["training"]["device"]) # INIT DATA AND MODELS ensemble_folder: pl.Path = ( config["paths"]["model_output"] + config["ensemble"]["name"] + "/models/" ) # Load test data test_dataset: ADNIDataset = torch.load( config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_dataset.pt", weights_only=False, ) models = ens.load_models(pl.Path(ensemble_folder), config["training"]["device"]) # We are generating a large matrix, with the dimensions of the models, the test set, and the number of classes # Therefore we are capturing the output of every model for every item in the test set and storing it in a matrix type ResultsMatrix = xr.DataArray type ActualMatrix = xr.DataArray results: ResultsMatrix = xr.DataArray( data=0, dims=["model", "test_item", "class"], coords={ "model": ens.get_model_names(models), "test_item": range(len(test_dataset)), "class": [0, 1], }, ) actual: ActualMatrix = xr.DataArray( data=0, dims=["test_item", "class"], coords={ "test_item": range(len(test_dataset)), "class": [0, 1], }, ) final: xr.Dataset = xr.Dataset( data_vars={ "evaluated": results, "actual": actual, }, ) # Iterate over the test set and get the predictions for each model for i, (unp_data, target) in enumerate(test_dataset): data = ens.prepare_datasets(unp_data) for j, (model_obj, model_name) in enumerate(models): model_obj.eval() with torch.no_grad(): output: torch.Tensor = model_obj(data) final.results.loc[dict(model=model_name, test_item=i)] = output.numpy() # type: ignore final.actual.loc[dict(test_item=i)] = target.numpy() # type: ignore # Save the results to a file final.to_netcdf( # type: ignore config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_results.nc", mode="w", format="NETCDF4", )