ensemble_predict.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import utils.ensemble as ens
  2. import os
  3. import tomli as toml
  4. from utils.system import force_init_cudnn
  5. from utils.data.datasets import prepare_datasets
  6. import math
  7. import torch
  8. # CONFIGURATION
  9. if os.getenv("ADL_CONFIG_PATH") is None:
  10. with open("config.toml", "rb") as f:
  11. config = toml.load(f)
  12. else:
  13. with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
  14. config = toml.load(f)
  15. # Force cuDNN initialization
  16. force_init_cudnn(config["training"]["device"])
  17. ensemble_folder = (
  18. config["paths"]["model_output"] + config["ensemble"]["name"] + "/models/"
  19. )
  20. models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
  21. models, model_descs = ens.prune_models(
  22. models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
  23. )
  24. # Load test data
  25. test_dataset = torch.load(
  26. config["paths"]["model_output"] + config["ensemble"]["name"] + "/test_dataset.pt",
  27. weights_only=False,
  28. )
  29. # Evaluate ensemble and uncertainty test set
  30. correct = 0
  31. total = 0
  32. predictions = []
  33. actual = []
  34. stdevs = []
  35. yes_votes = []
  36. no_votes = []
  37. for data, target in test_dataset:
  38. mri, xls = data
  39. mri = mri.unsqueeze(0)
  40. xls = xls.unsqueeze(0)
  41. data = (mri, xls)
  42. mean, variance = ens.ensemble_predict(models, data)
  43. _, yes_votes, no_votes = ens.ensemble_predict_strict_classes(models, data)
  44. stdevs.append(math.sqrt(variance.item()))
  45. predicted = torch.round(mean)
  46. expected = target[1]
  47. total += 1
  48. correct += (predicted == expected).item()
  49. out = mean.tolist()
  50. predictions.append(out)
  51. act = target[1].tolist()
  52. actual.append(act)
  53. accuracy = correct / total
  54. with open(
  55. ensemble_folder
  56. + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
  57. "w",
  58. ) as f:
  59. f.write("Accuracy: " + str(accuracy) + "\n")
  60. f.write("Correct: " + str(correct) + "\n")
  61. f.write("Total: " + str(total) + "\n")
  62. for exp, pred, stdev in zip(actual, predictions, stdevs):
  63. f.write(
  64. str(exp)
  65. + ", "
  66. + str(pred)
  67. + ", "
  68. + str(stdev)
  69. + ", "
  70. + str(yes_votes)
  71. + ", "
  72. + str(no_votes)
  73. + "\n"
  74. )