"""Evaluate a trained model ensemble on the test split.

Reads the TOML configuration from the path in ``ADL_CONFIG_PATH`` (or
``config.toml`` when that variable is unset), loads the ensemble from
``paths.model_output + ensemble.name``, prunes it with
``ensemble.prune_threshold``, runs it over the test dataset, and writes the
accuracy plus one line per test sample to
``ensemble_test_results_<prune_threshold>.txt`` inside the ensemble folder.
"""

import math
import os

import tomli as toml
import torch

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets
from utils.system import force_init_cudnn

# CONFIGURATION
# Read the env var once; fall back to the local config file when it is unset.
config_path = os.getenv("ADL_CONFIG_PATH")
if config_path is None:
    config_path = "config.toml"
with open(config_path, "rb") as f:
    config = toml.load(f)

# Force cuDNN initialization
force_init_cudnn(config["training"]["device"])

ensemble_folder = config["paths"]["model_output"] + config["ensemble"]["name"] + "/"
models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
models, model_descs = ens.prune_models(
    models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
)

# Load test data: element [2] of what prepare_datasets returns is used as the
# test split.
# NOTE(review): meaning of the literal 0 argument (seed?) — confirm against
# prepare_datasets.
test_dataset = prepare_datasets(
    config["paths"]["mri_data"],
    config["paths"]["xls_data"],
    config["dataset"]["validation_split"],
    0,
    config["training"]["device"],
)[2]

# Evaluate ensemble and uncertainty on the test set
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
yes_votes = []  # one entry per test sample
no_votes = []  # one entry per test sample

# Pure inference: no gradients needed, which saves memory and time.
with torch.no_grad():
    for data, target in test_dataset:
        mri, xls = data
        # Add a leading dimension of size 1 so each sample is fed as a batch.
        data = (mri.unsqueeze(0), xls.unsqueeze(0))

        mean, variance = ens.ensemble_predict(models, data)
        _, sample_yes, sample_no = ens.ensemble_predict_strict_classes(models, data)
        # Accumulate per-sample votes; the original rebound the accumulator
        # lists here, so only the last sample's votes were ever written out.
        yes_votes.append(sample_yes)
        no_votes.append(sample_no)

        stdevs.append(math.sqrt(variance.item()))

        predicted = torch.round(mean)
        expected = target[1]
        total += 1
        correct += (predicted == expected).item()

        predictions.append(mean.tolist())
        actual.append(expected.tolist())

# Guard against an empty test set instead of raising ZeroDivisionError.
accuracy = correct / total if total else float("nan")

results_path = (
    ensemble_folder
    + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt"
)
with open(results_path, "w") as f:
    f.write("Accuracy: " + str(accuracy) + "\n")
    f.write("Correct: " + str(correct) + "\n")
    f.write("Total: " + str(total) + "\n")
    # One line per sample: expected, predicted mean, stdev, yes votes, no votes.
    for exp, pred, stdev, yes, no in zip(
        actual, predictions, stdevs, yes_votes, no_votes
    ):
        f.write(f"{exp}, {pred}, {stdev}, {yes}, {no}\n")