# ensemble_predict.py
import math
import os

import tomli as toml
import torch

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets
from utils.system import force_init_cudnn
  8. # CONFIGURATION
  9. if os.getenv("ADL_CONFIG_PATH") is None:
  10. with open("config.toml", "rb") as f:
  11. config = toml.load(f)
  12. else:
  13. with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
  14. config = toml.load(f)
  15. # Force cuDNN initialization
  16. force_init_cudnn(config["training"]["device"])
  17. ensemble_folder = config["paths"]["model_output"] + config["ensemble"]["name"] + "/"
  18. models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
  19. # Load test data
  20. test_dataset = prepare_datasets(
  21. config["paths"]["mri_data"],
  22. config["paths"]["xls_data"],
  23. config["dataset"]["validation_split"],
  24. 0,
  25. config["training"]["device"],
  26. )[2]
  27. # Evaluate ensemble and uncertainty test set
  28. correct = 0
  29. total = 0
  30. predictions = []
  31. actual = []
  32. stdevs = []
  33. yes_votes = []
  34. no_votes = []
  35. for data, target in test_dataset:
  36. mri, xls = data
  37. mri = mri.unsqueeze(0)
  38. xls = xls.unsqueeze(0)
  39. data = (mri, xls)
  40. mean, variance = ens.ensemble_predict(models, data)
  41. _, yes_votes, no_votes = ens.ensemble_predict_strict_classes(models, data)
  42. stdevs.append(math.sqrt(variance.item()))
  43. predicted = torch.round(mean)
  44. expected = target[1]
  45. total += 1
  46. correct += (predicted == expected).item()
  47. out = mean.tolist()
  48. predictions.append(out)
  49. act = target[1].tolist()
  50. actual.append(act)
  51. accuracy = correct / total
  52. with open(ensemble_folder + "ensemble_test_results.txt", "w") as f:
  53. f.write("Accuracy: " + str(accuracy) + "\n")
  54. f.write("Correct: " + str(correct) + "\n")
  55. f.write("Total: " + str(total) + "\n")
  56. for exp, pred, stdev in zip(actual, predictions, stdevs):
  57. f.write(
  58. str(exp)
  59. + ", "
  60. + str(pred)
  61. + ", "
  62. + str(stdev)
  63. + ", "
  64. + str(yes_votes.item())
  65. + ", "
  66. + str(no_votes.item())
  67. + "\n"
  68. )