ensemble_predict.py

import utils.ensemble as ens
import os
import tomli as toml
from utils.system import force_init_cudnn
from utils.data.datasets import prepare_datasets
import math
import torch

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)

# Force cuDNN initialization
force_init_cudnn(config['training']['device'])
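
# The config.toml layout is not shown in this file. Based on the keys accessed
# by this script, it presumably looks roughly like the sketch below; the
# section and key names come from this script, while the values are
# illustrative placeholders only:
#
#   [paths]
#   model_output = "out/"
#
#   [training]
#   device = "cuda:0"
#
#   [ensemble]
#   name = "ensemble_1"
#   prune_threshold = 0.5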

ensemble_folder = (
    config['paths']['model_output'] + config['ensemble']['name'] + '/models/'
)

# Load the ensemble members, then prune them according to the configured threshold
models, model_descs = ens.load_models(ensemble_folder, config['training']['device'])
models, model_descs = ens.prune_models(
    models, model_descs, ensemble_folder, config['ensemble']['prune_threshold']
)

# Load test data
test_dataset = torch.load(
    config['paths']['model_output'] + config['ensemble']['name'] + '/test_dataset.pt'
)

# Evaluate ensemble accuracy and uncertainty on the test set
correct = 0
total = 0
predictions = []
actual = []
stdevs = []
yes_votes = []
no_votes = []

for data, target in test_dataset:
    mri, xls = data
    mri = mri.unsqueeze(0)
    xls = xls.unsqueeze(0)
    data = (mri, xls)

    # Mean/variance of the ensemble output, plus per-class vote counts,
    # recorded per sample
    mean, variance = ens.ensemble_predict(models, data)
    _, yes, no = ens.ensemble_predict_strict_classes(models, data)
    yes_votes.append(yes)
    no_votes.append(no)
    stdevs.append(math.sqrt(variance.item()))

    predicted = torch.round(mean)
    expected = target[1]

    total += 1
    correct += (predicted == expected).item()

    predictions.append(mean.tolist())
    actual.append(target[1].tolist())

accuracy = correct / total

# Write the overall accuracy followed by one line per test sample:
# expected label, predicted mean, stdev, yes votes, no votes
with open(
    ensemble_folder
    + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
    'w',
) as f:
    f.write('Accuracy: ' + str(accuracy) + '\n')
    f.write('Correct: ' + str(correct) + '\n')
    f.write('Total: ' + str(total) + '\n')

    for exp, pred, stdev, yes, no in zip(
        actual, predictions, stdevs, yes_votes, no_votes
    ):
        f.write(
            str(exp) + ', ' + str(pred) + ', ' + str(stdev) + ', '
            + str(yes) + ', ' + str(no) + '\n'
        )
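
# Note on utils.ensemble: its implementation is not shown in this file. As a
# rough sketch only (an assumption, not the repository's actual code),
# ensemble_predict is used above as if it stacked each member model's output
# for the sample and returned the mean and variance of those outputs, e.g.:
#
#     def ensemble_predict(models, data):
#         with torch.no_grad():
#             outputs = torch.stack([model(*data) for model in models])
#         return outputs.mean(dim=0), outputs.var(dim=0)
#
# and ensemble_predict_strict_classes is used as if it returned, for each
# sample, how many members vote "yes" and how many vote "no" after
# thresholding their individual outputs.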