# threshold.py

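"""Coverage analysis for an ensemble of binary classifiers.

Loads the saved test and validation sets, evaluates every ensemble member, and
uses the standard deviation of the ensemble prediction as an uncertainty score.
Predictions are then filtered at successive uncertainty quantiles to measure
accuracy, F1, and AUC as a function of coverage, with a single ensemble member
as a baseline. Results are written to CSV and plotted against coverage.
"""
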
import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm

RUN = True

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
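
# The ADL_CONFIG_PATH environment variable, if set, points to an alternative
# TOML configuration file; otherwise the local config.toml is used.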


# Returns a DataFrame of ensemble accuracy, F1, and AUC at each uncertainty
# threshold, along with the metrics of a single ensemble member
def threshold(config):
    # First, get the model data
    test_set = torch.load(
        config['paths']['model_output']
        + config['ensemble']['name']
        + '/test_dataset.pt'
    )
    vs = torch.load(
        config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
    )
    test_set = test_set + vs

    models, _ = ens.load_models(
        config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
        config['training']['device'],
    )
    indv_model = models[0]

    predictions = []
    indv_predictions = []

    # Evaluate the ensemble and its uncertainty on the test set
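    # For each sample, the (MRI, XLS) input pair is given a batch dimension and
    # passed to the ensemble. ens.ensemble_predict is assumed here to return the
    # mean and variance of the members' positive-class outputs; the standard
    # deviation of that output is used as the uncertainty score, and the first
    # ensemble member (indv_model) is evaluated as a single-model baseline.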
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)
        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()

        target = target[1]
        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                'Prediction': prediction,
                'Actual': target.item(),
                'Stdev': stdev.item(),
                'Correct': correct,
            }
        )

        i_mean = indv_model(mdata)[:, 1].item()
        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
            i_mean >= 0.5 and int(target.item()) == 1
        )

        indv_predictions.append(
            {
                'Prediction': i_mean,
                'Actual': target.item(),
                'Stdev': 0,
                'Correct': i_correct,
            }
        )

    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')

    # Calculate the metrics for the individual model
    indv_predictions = pd.DataFrame(indv_predictions)
    indv_correct = indv_predictions['Correct'].sum()
    indv_accuracy = indv_correct / len(indv_predictions)
    indv_false_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
        ]
    )
    indv_false_neg = len(
        indv_predictions[
            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    # Note: this F1-style score uses the total number of correct predictions
    # (true positives and true negatives) in place of the true-positive count,
    # so it is not the standard F1 definition.
    indv_f1 = 2 * indv_correct / (2 * indv_correct + indv_false_pos + indv_false_neg)
    indv_auc = metrics.roc_auc_score(
        indv_predictions['Actual'], indv_predictions['Prediction']
    )

    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}
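
    # Sweep the deciles of the uncertainty distribution: at quantile q, only the
    # q fraction of samples with the lowest ensemble standard deviation is kept,
    # so q plays the role of coverage in the curves plotted below.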
    thresholds = []
    quantiles = np.arange(0.1, 1, 0.1)
    # Get the uncertainty quantiles
    for quantile in quantiles:
        thresholds.append(predictions['Stdev'].quantile(quantile))

    # Calculate the accuracy, F1, and AUC of the ensemble at each threshold
    accuracies = []
    for threshold, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions['Stdev'] <= threshold]
        correct = filtered['Correct'].sum()
        total = len(filtered)
        accuracy = correct / total
        false_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
        )
        false_negatives = len(
            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
        )
        # Same F1-style score as for the individual model above
        f1 = 2 * correct / (2 * correct + false_positives + false_negatives)
        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
        accuracies.append(
            {
                'Threshold': threshold,
                'Accuracy': accuracy,
                'Quantile': quantile,
                'F1': f1,
                'AUC': auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )
    indv_predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
    )

    return pd.DataFrame(accuracies), indv_metrics


if RUN:
    result, indv = threshold(config)
    result.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
    )
    indv = pd.DataFrame([indv])
    indv.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
    )
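
# The results are read back from the CSVs so the plots below can be regenerated
# without re-running the ensemble (set RUN = False above).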
result = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
)
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)
indv = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
)

print(indv)
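
# Plot the ensemble metrics against coverage (the uncertainty quantile), with
# the single-model metrics as a flat baseline. The x-axis is inverted so that
# coverage decreases from left to right.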
plt.figure()
plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
plt.plot(
    result['Quantile'],
    [indv['Accuracy']] * len(result['Quantile']),
    label='Individual Accuracy',
    linestyle='--',
)
plt.legend()
plt.title('Accuracy vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('Accuracy')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
plt.plot(
    result['Quantile'],
    [indv['F1']] * len(result['Quantile']),
    label='Individual F1',
    linestyle='--',
)
plt.legend()
plt.title('F1 vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('F1')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
plt.plot(
    result['Quantile'],
    [indv['AUC']] * len(result['Quantile']),
    label='Individual AUC',
    linestyle='--',
)
plt.legend()
plt.title('AUC vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('AUC')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Histogram of the ensemble uncertainty for the incorrect predictions
plt.figure()
plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
plt.xlabel('Uncertainty')
plt.ylabel('Number of incorrect predictions')
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)