12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- import utils.ensemble as ens
- import os
- import tomli as toml
- from utils.system import force_init_cudnn
- from utils.data.datasets import prepare_datasets
- import math
- import torch
# CONFIGURATION
# Read the TOML configuration from the path given in the ADL_CONFIG_PATH
# environment variable when it is set; otherwise fall back to "config.toml"
# relative to the current working directory.
config_path = os.getenv("ADL_CONFIG_PATH", "config.toml")
with open(config_path, "rb") as config_file:
    config = toml.load(config_file)
# Force cuDNN initialization on the configured training device
device = config["training"]["device"]
force_init_cudnn(device)

# Load the ensemble members from "<model_output><ensemble name>/"
paths = config["paths"]
ensemble_folder = paths["model_output"] + config["ensemble"]["name"] + "/"
models, model_descs = ens.load_models(ensemble_folder, device)

# Load test data: only the element at index 2 of what prepare_datasets
# returns is kept (presumably the test split -- confirm against
# utils.data.datasets). The literal 0 is passed through unchanged; its
# meaning is defined by prepare_datasets.
test_dataset = prepare_datasets(
    paths["mri_data"],
    paths["xls_data"],
    config["dataset"]["validation_split"],
    0,
    device,
)[2]
# Evaluate the ensemble and its uncertainty on the test set.
#
# For every test sample the script records the expected value, the ensemble
# mean, the standard deviation (square root of the ensemble variance) and the
# yes/no vote counts, then writes a summary plus one row per sample to
# "<ensemble_folder>ensemble_test_results.txt".
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
yes_votes = []
no_votes = []

for data, target in test_dataset:
    # Each sample is an (mri, xls) pair; prepend a dimension of size 1 to
    # both parts before handing them to the ensemble (presumably the batch
    # axis the models expect).
    mri, xls = data
    data = (mri.unsqueeze(0), xls.unsqueeze(0))

    mean, variance = ens.ensemble_predict(models, data)
    _, sample_yes, sample_no = ens.ensemble_predict_strict_classes(models, data)

    # Keep the votes per sample. Previously the yes_votes/no_votes lists
    # were rebound to the current sample's values on every iteration, so
    # every output row ended up showing the last sample's votes.
    yes_votes.append(sample_yes.item())
    no_votes.append(sample_no.item())
    stdevs.append(math.sqrt(variance.item()))

    # The rounded ensemble mean is the predicted value; it is compared with
    # the element at index 1 of the target.
    expected = target[1]
    predicted = torch.round(mean)
    total += 1
    correct += (predicted == expected).item()

    predictions.append(mean.tolist())
    actual.append(expected.tolist())

# An empty test set has no defined accuracy; report NaN instead of crashing
# with ZeroDivisionError.
accuracy = correct / total if total else float("nan")

with open(ensemble_folder + "ensemble_test_results.txt", "w") as f:
    f.write(f"Accuracy: {accuracy}\n")
    f.write(f"Correct: {correct}\n")
    f.write(f"Total: {total}\n")
    # One row per sample: expected, predicted mean, stdev, yes votes, no votes
    for exp, pred, stdev, yes, no in zip(
        actual, predictions, stdevs, yes_votes, no_votes
    ):
        f.write(f"{exp}, {pred}, {stdev}, {yes}, {no}\n")
|