12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
# Evaluate a pruned model ensemble on the saved test set and write a report
# (accuracy summary plus one line per sample: expected label, mean prediction,
# standard deviation, yes votes, no votes) next to the ensemble's models.
import math
import os

import tomli as toml
import torch

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets
from utils.system import force_init_cudnn

# CONFIGURATION
# ADL_CONFIG_PATH, when set, overrides the default config.toml in the
# current working directory.
config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
with open(config_path, 'rb') as f:
    config = toml.load(f)

device = config['training']['device']
prune_threshold = config['ensemble']['prune_threshold']
ensemble_root = config['paths']['model_output'] + config['ensemble']['name']
ensemble_folder = ensemble_root + '/models/'

# Force cuDNN initialization
force_init_cudnn(device)

models, model_descs = ens.load_models(ensemble_folder, device)
models, model_descs = ens.prune_models(
    models, model_descs, ensemble_folder, prune_threshold
)

# Load test data
# NOTE(review): torch.load unpickles arbitrary Python objects; only point this
# at dataset files produced by this project, never at untrusted input.
test_dataset = torch.load(ensemble_root + '/test_dataset.pt')

# Evaluate ensemble and uncertainty test set
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
yes_votes = []
no_votes = []

for data, target in test_dataset:
    mri, xls = data
    # Add a leading dimension of size 1 to each input before prediction.
    batch = (mri.unsqueeze(0), xls.unsqueeze(0))

    mean, variance = ens.ensemble_predict(models, batch)
    # Keep the votes per sample. Rebinding yes_votes/no_votes here (as the
    # previous version did) made every report row show the LAST sample's votes.
    _, sample_yes, sample_no = ens.ensemble_predict_strict_classes(models, batch)

    # presumably target[1] holds the label the rounded mean is compared
    # against -- TODO confirm against the dataset definition
    expected = target[1]

    total += 1
    correct += (torch.round(mean) == expected).item()

    predictions.append(mean.tolist())
    actual.append(expected.tolist())
    stdevs.append(math.sqrt(variance.item()))
    yes_votes.append(sample_yes)
    no_votes.append(sample_no)

# An empty test set would otherwise raise ZeroDivisionError; report NaN instead.
accuracy = correct / total if total else float('nan')

results_path = ensemble_folder + f"ensemble_test_results_{prune_threshold}.txt"
with open(results_path, 'w') as f:
    f.write('Accuracy: ' + str(accuracy) + '\n')
    f.write('Correct: ' + str(correct) + '\n')
    f.write('Total: ' + str(total) + '\n')
    # One line per sample. str() is applied explicitly (rather than f-string
    # formatting) so each value is rendered exactly as str(value) would be.
    f.writelines(
        ', '.join(map(str, row)) + '\n'
        for row in zip(actual, predictions, stdevs, yes_votes, no_votes)
    )
|