123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- import utils.ensemble as ens
- import os
- import tomli as toml
- from utils.system import force_init_cudnn
- from utils.data.datasets import prepare_datasets
- import math
- import torch
# CONFIGURATION
# Use the path in the ADL_CONFIG_PATH environment variable when it is set;
# otherwise fall back to config.toml in the current working directory.
config_path = os.getenv("ADL_CONFIG_PATH")
if config_path is None:
    config_path = "config.toml"
# tomli's load() expects a binary file object, hence "rb".
with open(config_path, "rb") as config_file:
    config = toml.load(config_file)
# Force cuDNN initialization on the configured training device before the
# ensemble members are loaded.
device = config["training"]["device"]
force_init_cudnn(device)

# The ensemble's models are read from <paths.model_output><ensemble.name>/ .
ensemble_folder = "".join(
    (config["paths"]["model_output"], config["ensemble"]["name"], "/")
)
models, model_descs = ens.load_models(ensemble_folder, device)
# Prune the loaded ensemble using the configured threshold. The pruning
# criterion itself lives in utils.ensemble.prune_models (not visible here).
models, model_descs = ens.prune_models(
    models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
)
# Load test data.
# prepare_datasets returns an indexable collection of splits; element 2 is used
# as the test split (presumably train/validation/test order — confirm).
_dataset_splits = prepare_datasets(
    config["paths"]["mri_data"],
    config["paths"]["xls_data"],
    config["dataset"]["validation_split"],
    0,  # NOTE(review): meaning of this positional 0 is not visible here — confirm
    config["training"]["device"],
)
test_dataset = _dataset_splits[2]
# Only the test split is needed; drop the reference to the other splits.
del _dataset_splits
# Evaluate ensemble and uncertainty on the test set.
# For every test sample, record:
#   - the ensemble mean prediction and its standard deviation
#     (sqrt of the ensemble variance),
#   - the rounded-mean accuracy,
#   - the strict-class yes/no votes.
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
# One entry per sample, so that each results row reports its own sample's
# votes rather than whatever the final sample produced.
yes_votes = []
no_votes = []
for data, target in test_dataset:
    mri, xls = data
    # unsqueeze(0) adds a leading size-1 dimension (presumably the batch
    # dimension) so a single sample can be passed to the models.
    data = (mri.unsqueeze(0), xls.unsqueeze(0))
    mean, variance = ens.ensemble_predict(models, data)
    _, sample_yes, sample_no = ens.ensemble_predict_strict_classes(models, data)
    yes_votes.append(sample_yes)
    no_votes.append(sample_no)
    # .item() requires variance to hold exactly one element (one sample).
    stdevs.append(math.sqrt(variance.item()))
    # assumes target[1] is the label tensor for this sample — confirm with dataset
    expected = target[1]
    predicted = torch.round(mean)
    total += 1
    correct += int((predicted == expected).item())
    predictions.append(mean.tolist())
    actual.append(expected.tolist())

# An empty test set would otherwise raise ZeroDivisionError; record NaN instead
# (Correct/Total lines below still show 0/0).
accuracy = correct / total if total else float("nan")

results_path = (
    ensemble_folder
    + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt"
)
with open(results_path, "w") as f:
    f.write("Accuracy: " + str(accuracy) + "\n")
    f.write("Correct: " + str(correct) + "\n")
    f.write("Total: " + str(total) + "\n")
    # One row per sample: expected, mean prediction, stdev, yes votes, no votes.
    for exp, pred, stdev, yes, no in zip(
        actual, predictions, stdevs, yes_votes, no_votes
    ):
        # str() (not f-string formatting) keeps tensor values rendered exactly
        # as before, e.g. 0-dim tensors stay "tensor(...)".
        f.write(", ".join(str(v) for v in (exp, pred, stdev, yes, no)) + "\n")
|