| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 | 
							- import utils.ensemble as ens
 
- import os
 
- import tomli as toml
 
- from utils.system import force_init_cudnn
 
- from utils.data.datasets import prepare_datasets
 
- import math
 
- import torch
 
# CONFIGURATION
# Resolve the config file path: honor ADL_CONFIG_PATH when set,
# otherwise fall back to the local config.toml.
config_path = os.getenv('ADL_CONFIG_PATH')
if config_path is None:
    config_path = 'config.toml'
with open(config_path, 'rb') as f:
    config = toml.load(f)

# Force cuDNN initialization
force_init_cudnn(config['training']['device'])
 
# Locate the trained ensemble's output directory and load its member models.
output_root = config['paths']['model_output'] + config['ensemble']['name']
ensemble_folder = output_root + '/models/'
device = config['training']['device']

models, model_descs = ens.load_models(ensemble_folder, device)
models, model_descs = ens.prune_models(
    models, model_descs, ensemble_folder, config['ensemble']['prune_threshold']
)

# Load test data
test_dataset = torch.load(output_root + '/test_dataset.pt')
 
# Evaluate ensemble accuracy and per-sample uncertainty on the test set.
#
# BUG FIX: the original code overwrote the `yes_votes` / `no_votes` lists
# with the scalars returned by ens.ensemble_predict_strict_classes each
# iteration, so every row of the results file reported the vote counts of
# the LAST sample only. Per-sample counts are now appended and written
# row-by-row alongside the matching prediction.
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
yes_votes = []
no_votes = []

for data, target in test_dataset:
    mri, xls = data
    # Add a leading batch dimension of 1 for single-sample inference.
    mri = mri.unsqueeze(0)
    xls = xls.unsqueeze(0)
    data = (mri, xls)

    mean, variance = ens.ensemble_predict(models, data)
    _, sample_yes, sample_no = ens.ensemble_predict_strict_classes(models, data)
    yes_votes.append(sample_yes)
    no_votes.append(sample_no)

    # Ensemble spread (std dev) as the uncertainty measure for this sample.
    stdevs.append(math.sqrt(variance.item()))
    predicted = torch.round(mean)
    # target[1] holds the positive-class label — presumably a one-hot pair;
    # TODO(review): confirm target layout against prepare_datasets.
    expected = target[1]

    total += 1
    correct += (predicted == expected).item()

    predictions.append(mean.tolist())
    actual.append(target[1].tolist())

accuracy = correct / total

# Write summary stats, then one row per sample:
# expected, predicted, stdev, yes votes, no votes
with open(
    ensemble_folder
    + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
    'w',
) as f:
    f.write('Accuracy: ' + str(accuracy) + '\n')
    f.write('Correct: ' + str(correct) + '\n')
    f.write('Total: ' + str(total) + '\n')

    for exp, pred, stdev, yes, no in zip(
        actual, predictions, stdevs, yes_votes, no_votes
    ):
        f.write(
            str(exp)
            + ', '
            + str(pred)
            + ', '
            + str(stdev)
            + ', '
            + str(yes)
            + ', '
            + str(no)
            + '\n'
        )
 
 
  |