# threshold.py

import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm

# CONFIGURATION
if os.getenv("ADL_CONFIG_PATH") is None:
    with open("config.toml", "rb") as f:
        config = toml.load(f)
else:
    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
        config = toml.load(f)
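
# Note: setting the ADL_CONFIG_PATH environment variable points the script at
# an alternate config file; otherwise the local config.toml is used.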


# Compute accuracy, F1, and AUC of the ensemble on the retained predictions at
# each uncertainty-quantile threshold and return the results as a DataFrame
def threshold(config):
    # First, prepare the datasets
    ts, vs, test_set = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        944,
        config["training"]["device"],
    )
    test_set = test_set + vs

    models, _ = ens.load_models(
        config["paths"]["model_output"] + config["ensemble"]["name"] + "/",
        config["training"]["device"],
    )
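
    # Note: ens.ensemble_predict is assumed to return the mean and variance of
    # the individual members' outputs for one sample; the standard deviation of
    # that spread is used below as the per-sample uncertainty score.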

    predictions = []

    # Evaluate the ensemble on the test set and record the uncertainty of each
    # prediction
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        # Add a batch dimension to both inputs
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)
        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()

        # Keep the second element of the target (the positive-class indicator)
        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                "Prediction": prediction,
                "Actual": target.item(),
                "Stdev": stdev.item(),
                "Correct": correct,
            }
        )

    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by="Stdev")

    # Get the uncertainty (Stdev) value at each decile to use as a threshold
    thresholds = []
    quantiles = np.arange(0.1, 1, 0.1)
    for quantile in quantiles:
        thresholds.append(predictions["Stdev"].quantile(quantile))
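
    # Each quantile q defines a coverage level: keeping only the predictions
    # with Stdev at or below the q-th quantile retains roughly a fraction q of
    # the test set (the most certain predictions) and rejects the rest.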

    # Calculate accuracy, F1, and AUC over the retained predictions at each
    # uncertainty threshold
    accuracies = []
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions["Stdev"] <= thresh]
        correct = filtered["Correct"].sum()
        total = len(filtered)
        accuracy = correct / total
        true_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 1)]
        )
        false_positives = len(
            filtered[(filtered["Prediction"] >= 0.5) & (filtered["Actual"] == 0)]
        )
        false_negatives = len(
            filtered[(filtered["Prediction"] < 0.5) & (filtered["Actual"] == 1)]
        )
        # F1 is computed from true positives only: 2 * TP / (2 * TP + FP + FN)
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )
        auc = metrics.roc_auc_score(filtered["Actual"], filtered["Prediction"])
        accuracies.append(
            {
                "Threshold": thresh,
                "Accuracy": accuracy,
                "Quantile": quantile,
                "F1": f1,
                "AUC": auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )

    return pd.DataFrame(accuracies)
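

# Run the threshold sweep, write the per-quantile metrics to coverage.csv, and
# reload both CSVs for plotting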
result = threshold(config)
result.to_csv("coverage.csv")

result = pd.read_csv("coverage.csv")
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)

print(result)
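
# Coverage plots: accuracy, F1, and AUC versus the fraction of predictions
# retained (Quantile); the x-axis is inverted so higher coverage appears on
# the left.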
plt.figure()
plt.plot(result["Quantile"], result["Accuracy"])
plt.xlabel("Coverage")
plt.ylabel("Accuracy")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result["Quantile"], result["F1"])
plt.xlabel("Coverage")
plt.ylabel("F1")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result["Quantile"], result["AUC"])
plt.xlabel("Coverage")
plt.ylabel("AUC")
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Histogram of the incorrect predictions' uncertainties
plt.figure()
plt.hist(predictions[~predictions["Correct"]]["Stdev"], bins=10)
plt.xlabel("Uncertainty")
plt.ylabel("Number of incorrect predictions")
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)