# ensemble_predict.py — evaluate a pruned model ensemble on the held-out test set
  1. import utils.ensemble as ens
  2. import os
  3. import tomli as toml
  4. from utils.system import force_init_cudnn
  5. from utils.data.datasets import prepare_datasets
  6. import math
  7. import torch
  8. # CONFIGURATION
  9. if os.getenv("ADL_CONFIG_PATH") is None:
  10. with open("config.toml", "rb") as f:
  11. config = toml.load(f)
  12. else:
  13. with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
  14. config = toml.load(f)
  15. # Force cuDNN initialization
  16. force_init_cudnn(config["training"]["device"])
  17. ensemble_folder = config["paths"]["model_output"] + config["ensemble"]["name"] + "/"
  18. models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
  19. models, model_descs = ens.prune_models(
  20. models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
  21. )
  22. # Load test data
  23. test_dataset = prepare_datasets(
  24. config["paths"]["mri_data"],
  25. config["paths"]["xls_data"],
  26. config["dataset"]["validation_split"],
  27. 0,
  28. config["training"]["device"],
  29. )[2]
  30. # Evaluate ensemble and uncertainty test set
  31. correct = 0
  32. total = 0
  33. predictions = []
  34. actual = []
  35. stdevs = []
  36. yes_votes = []
  37. no_votes = []
  38. for data, target in test_dataset:
  39. mri, xls = data
  40. mri = mri.unsqueeze(0)
  41. xls = xls.unsqueeze(0)
  42. data = (mri, xls)
  43. mean, variance = ens.ensemble_predict(models, data)
  44. _, yes_votes, no_votes = ens.ensemble_predict_strict_classes(models, data)
  45. stdevs.append(math.sqrt(variance.item()))
  46. predicted = torch.round(mean)
  47. expected = target[1]
  48. total += 1
  49. correct += (predicted == expected).item()
  50. out = mean.tolist()
  51. predictions.append(out)
  52. act = target[1].tolist()
  53. actual.append(act)
  54. accuracy = correct / total
  55. with open(
  56. ensemble_folder
  57. + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
  58. "w",
  59. ) as f:
  60. f.write("Accuracy: " + str(accuracy) + "\n")
  61. f.write("Correct: " + str(correct) + "\n")
  62. f.write("Total: " + str(total) + "\n")
  63. for exp, pred, stdev in zip(actual, predictions, stdevs):
  64. f.write(
  65. str(exp)
  66. + ", "
  67. + str(pred)
  68. + ", "
  69. + str(stdev)
  70. + ", "
  71. + str(yes_votes)
  72. + ", "
  73. + str(no_votes)
  74. + "\n"
  75. )