threshold.py

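"""Selective-prediction ("coverage") analysis for the trained ensemble.

Runs the ensemble over the pooled validation and test sets, records each
prediction together with its uncertainty (the standard deviation returned by
ens.ensemble_predict), and then reports accuracy, F1, and AUC when only the
most-confident samples are kept, using one uncertainty threshold per decile.
A single ensemble member is evaluated the same way as a baseline. Results are
written to CSV and plotted as metric-vs-coverage curves.
"""
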
import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm

RUN = True

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
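
# Config keys used by this script:
#   paths.mri_data, paths.xls_data, paths.model_output,
#   dataset.validation_split, training.device, ensemble.name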


# Returns a DataFrame of accuracy/F1/AUC at each uncertainty-quantile
# threshold, plus a dict of metrics for a single ensemble member.
def threshold(config):
    # First, get the model data
    ts, vs, test_set = prepare_datasets(
        config['paths']['mri_data'],
        config['paths']['xls_data'],
        config['dataset']['validation_split'],
        944,
        config['training']['device'],
    )

    # Pool the validation and test sets for the coverage analysis
    test_set = test_set + vs

    models, _ = ens.load_models(
        config['paths']['model_output'] + config['ensemble']['name'] + '/',
        config['training']['device'],
    )

    # The first ensemble member doubles as the single-model baseline
    indv_model = models[0]

    predictions = []
    indv_predictions = []
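
    # ens.ensemble_predict is assumed here to return the mean and variance of
    # the ensemble members' positive-class outputs for one sample; the standard
    # deviation derived from that variance is used as the uncertainty score.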
    # Evaluate the ensemble and the single-model baseline on the test set
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)

        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()

        # The target appears to be one-hot encoded; keep the positive-class entry
        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                'Prediction': prediction,
                'Actual': target.item(),
                'Stdev': stdev.item(),
                'Correct': correct,
            }
        )

        # Positive-class output of the single baseline model
        i_mean = indv_model(mdata)[:, 1].item()
        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
            i_mean >= 0.5 and int(target.item()) == 1
        )

        indv_predictions.append(
            {
                'Prediction': i_mean,
                'Actual': target.item(),
                'Stdev': 0,  # a single model provides no uncertainty estimate
                'Correct': i_correct,
            }
        )

    # Sort the predictions by the uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')

    # Calculate the metrics for the individual model
    indv_predictions = pd.DataFrame(indv_predictions)
    indv_correct = indv_predictions['Correct'].sum()
    indv_accuracy = indv_correct / len(indv_predictions)
    indv_true_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    indv_false_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
        ]
    )
    indv_false_neg = len(
        indv_predictions[
            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    # F1 = 2 * TP / (2 * TP + FP + FN)
    indv_f1 = 2 * indv_true_pos / (
        2 * indv_true_pos + indv_false_pos + indv_false_neg
    )
    indv_auc = metrics.roc_auc_score(
        indv_predictions['Actual'], indv_predictions['Prediction']
    )

    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}
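
    # Sweep thresholds at the deciles of the ensemble uncertainty: keeping the
    # samples at or below the q-th quantile covers roughly a fraction q of the
    # test set, which is the "coverage" axis in the plots below.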
    thresholds = []
    quantiles = np.arange(0.1, 1, 0.1)
    # Get the uncertainty quantiles
    for quantile in quantiles:
        thresholds.append(predictions['Stdev'].quantile(quantile))

    # Calculate the metrics of the ensemble at each threshold
    accuracies = []
    for threshold, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions['Stdev'] <= threshold]
        correct = filtered['Correct'].sum()
        total = len(filtered)
        accuracy = correct / total
        true_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 1)]
        )
        false_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
        )
        false_negatives = len(
            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
        )
        # F1 = 2 * TP / (2 * TP + FP + FN)
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )
        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
        accuracies.append(
            {
                'Threshold': threshold,
                'Accuracy': accuracy,
                'Quantile': quantile,
                'F1': f1,
                'AUC': auc,
            }
        )
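
    # Persist the per-sample predictions so the results can be inspected (and
    # the plots regenerated) without re-running the ensemble.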
    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )
    indv_predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
    )

    return pd.DataFrame(accuracies), indv_metrics
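

# With RUN = True the coverage analysis is recomputed and written to CSV; the
# plotting code below always reads those CSVs back from disk, so RUN can be set
# to False to re-plot previously saved results.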
if RUN:
    result, indv = threshold(config)

    result.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
    )

    indv = pd.DataFrame([indv])
    indv.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
    )

result = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
)
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)
indv = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
)

print(indv)
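
# Plot ensemble accuracy / F1 / AUC against coverage (the uncertainty quantile),
# with the single-model baseline as a dashed horizontal line. The x-axis is
# inverted so that coverage decreases from left to right.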
plt.figure()
plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
plt.plot(
    result['Quantile'],
    [indv['Accuracy'].iloc[0]] * len(result['Quantile']),
    label='Individual Accuracy',
    linestyle='--',
)
plt.legend()
plt.title('Accuracy vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('Accuracy')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
plt.plot(
    result['Quantile'],
    [indv['F1'].iloc[0]] * len(result['Quantile']),
    label='Individual F1',
    linestyle='--',
)
plt.legend()
plt.title('F1 vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('F1')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
plt.plot(
    result['Quantile'],
    [indv['AUC'].iloc[0]] * len(result['Quantile']),
    label='Individual AUC',
    linestyle='--',
)
plt.legend()
plt.title('AUC vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('AUC')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Create a histogram of the incorrect predictions vs. the uncertainty
plt.figure()
plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
plt.xlabel('Uncertainty')
plt.ylabel('Number of incorrect predictions')
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)