# threshold.py

import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm
import utils.metrics as met

RUN = True

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
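
# The keys read below are taken from how this script uses the config; the exact
# file layout is an assumption, but a minimal compatible config.toml would look
# roughly like:
#
#   [paths]
#   model_output = 'out/'
#
#   [ensemble]
#   name = 'ensemble_1'
#
#   [training]
#   device = 'cuda'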


# Returns a DataFrame of ensemble metrics (accuracy, F1, AUC) at each
# uncertainty threshold, plus the metrics of a single ensemble member.
def threshold(config):
    # First, get the model data
    test_set = torch.load(
        config['paths']['model_output']
        + config['ensemble']['name']
        + '/test_dataset.pt'
    )
    vs = torch.load(
        config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
    )
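    # Fold the validation split into the test split, presumably to give the
    # coverage analysis more held-out samples to threshold over.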
    test_set = test_set + vs

    models, _ = ens.load_models(
        config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
        config['training']['device'],
    )
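    # ens.load_models is assumed to return the list of member networks plus a
    # second value that is unused here; the first member doubles as the
    # single-model baseline.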
    indv_model = models[0]

    predictions = []
    indv_predictions = []

    # Evaluate the ensemble and its uncertainty on the combined test set
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)

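        # ens.ensemble_predict is assumed to average the positive-class output
        # across members and return the per-sample mean and variance, roughly:
        #   outs = torch.stack([m(mdata)[:, 1] for m in models])
        #   mean, variance = outs.mean(dim=0), outs.var(dim=0)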
        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()
        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                'Prediction': prediction,
                'Actual': target.item(),
                'Stdev': stdev.item(),
                'Correct': correct,
            }
        )

        i_mean = indv_model(mdata)[:, 1].item()
        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
            i_mean >= 0.5 and int(target.item()) == 1
        )

        indv_predictions.append(
            {
                'Prediction': i_mean,
                'Actual': target.item(),
                'Stdev': 0,
                'Correct': i_correct,
            }
        )

    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')

    # Calculate the metrics for the individual model
    indv_predictions = pd.DataFrame(indv_predictions)
    indv_correct = indv_predictions['Correct'].sum()
    indv_accuracy = indv_correct / len(indv_predictions)
    # Confusion-matrix counts for F1 = 2 * TP / (2 * TP + FP + FN)
    indv_true_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    indv_false_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
        ]
    )
    indv_false_neg = len(
        indv_predictions[
            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    indv_f1 = 2 * indv_true_pos / (
        2 * indv_true_pos + indv_false_pos + indv_false_neg
    )
    indv_auc = metrics.roc_auc_score(
        indv_predictions['Actual'], indv_predictions['Prediction']
    )
    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}

    # Get the uncertainty quantiles to use as thresholds
    thresholds = []
    quantiles = np.arange(0.1, 1, 0.1)
    for quantile in quantiles:
        thresholds.append(predictions['Stdev'].quantile(quantile))

    # Calculate the ensemble metrics at each threshold
    accuracies = []
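    # Keeping only samples with Stdev at or below the q-th quantile retains
    # roughly a fraction q of the data, so 'Quantile' doubles as coverage.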
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions['Stdev'] <= thresh]
        correct = filtered['Correct'].sum()
        total = len(filtered)
        accuracy = correct / total

        true_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 1)]
        )
        false_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
        )
        false_negatives = len(
            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
        )
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )
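        # Note: roc_auc_score raises a ValueError if only one class survives
        # the filtering, which can happen at very low coverage on some datasets.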
        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])

        accuracies.append(
            {
                'Threshold': thresh,
                'Accuracy': accuracy,
                'Quantile': quantile,
                'F1': f1,
                'AUC': auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )
    indv_predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
    )

    return pd.DataFrame(accuracies), indv_metrics


if RUN:
    result, indv = threshold(config)
    result.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
    )
    indv = pd.DataFrame([indv])
    indv.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
    )

# Reload the CSVs, presumably so plotting can also run from a previous run's
# outputs without recomputing the ensemble predictions
result = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
)
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)
indv = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
)

print(indv)

plt.figure()
plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
plt.plot(
    result['Quantile'],
    [indv['Accuracy'].iloc[0]] * len(result['Quantile']),
    label='Individual Accuracy',
    linestyle='--',
)
plt.legend()
plt.title('Accuracy vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('Accuracy')
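# Invert the x-axis so that moving right corresponds to keeping fewer,
# more certain predictions (lower coverage)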
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
plt.plot(
    result['Quantile'],
    [indv['F1'].iloc[0]] * len(result['Quantile']),
    label='Individual F1',
    linestyle='--',
)
plt.legend()
plt.title('F1 vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('F1')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
plt.plot(
    result['Quantile'],
    [indv['AUC'].iloc[0]] * len(result['Quantile']),
    label='Individual AUC',
    linestyle='--',
)
plt.legend()
plt.title('AUC vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('AUC')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Create a histogram of the incorrect predictions vs the uncertainty
plt.figure()
plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
plt.xlabel('Uncertainty')
plt.ylabel('Number of incorrect predictions')
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)

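# Expected calibration error of the ensemble; met.ECE is this repo's own helper
# and is assumed to take (predicted probabilities, binary labels) in that order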
ece = met.ECE(predictions['Prediction'], predictions['Actual'])
print(f'ECE: {ece}')

with open(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/summary.txt", 'a'
) as f:
    f.write(f'ECE: {ece}\n')