# threshold_refac.py

import glob
import os
import pickle as pk
import warnings

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.metrics as met

# Promote warnings to errors so numerical issues surface immediately
warnings.filterwarnings('error')


def plot_image_grid(image_ids, dataset, rows, path, titles=None):
    fig, axs = plt.subplots(rows, len(image_ids) // rows)
    for i, ax in enumerate(axs.flat):
        image_id = image_ids[i]
        image = dataset[image_id][0][0].squeeze().cpu().numpy()
        # The image is a 3D volume of shape (91, 109, 91); take a slice from
        # the middle of the volume
        image = image[:, :, 45]
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        if titles is not None:
            ax.set_title(titles[i])
    plt.savefig(path)
    plt.close()


def plot_single_image(image_id, dataset, path, title=None):
    fig, ax = plt.subplots()
    image = dataset[image_id][0][0].squeeze().cpu().numpy()
    # The image is a 3D volume of shape (91, 109, 91); take a slice from
    # the middle of the volume
    image = image[:, :, 45]
    ax.imshow(image, cmap='gray')
    ax.axis('off')
    if title is not None:
        ax.set_title(title)
    plt.savefig(path)
    plt.close()


# Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)},
# plot the two statistics against each other and color by correctness
def plot_statistics_versus(
    stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
):
    # Split the dataframe into correct and incorrect predictions
    corr_df = dataframe[dataframe['correct']]
    incorr_df = dataframe[~dataframe['correct']]

    # Plot the correct and incorrect predictions
    fig, ax = plt.subplots()
    ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
    ax.legend()
    ax.set_xlabel(xaxis_name)
    ax.set_ylabel(yaxis_name)
    ax.set_title(title)

    if annotate:
        # Label every point with its data id
        for row in dataframe[[stat_1, stat_2]].itertuples():
            plt.text(row[1], row[2], row[0], fontsize=6, color='black')

    plt.savefig(path)
    plt.close()


# models is a dictionary with the model ids as keys and the model objects as values
def get_model_predictions(models, data):
    predictions = {}
    for model_id, model in models.items():
        model.eval()
        with torch.no_grad():
            # Get the predictions
            output = model(data)
            predictions[model_id] = output.detach().cpu().numpy()
    return predictions


def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}
    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        # The model id is the filename prefix before the first underscore
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model
    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)
    return model_dict
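
# Example (hypothetical filename, assuming the '<id>_*.pt' naming convention
# implied by the split above): a model saved as
#     torch.save(model, f'{folder}/3_model.pt')
# is loaded back into model_dict under the id '3'.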


# Ensure that both the mri and xls tensors have a leading batch dimension
# (unsqueeze) and live on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)
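
# A minimal shape sketch, assuming a (1, 91, 109, 91) MRI volume and a flat
# vector of tabular (xls) features; the exact shapes depend on the dataset:
#     >>> mri, xls = preprocess_data((mri, xls), device)
#     >>> mri.shape   # (1, 1, 91, 109, 91) after adding the batch dimension
#     >>> xls.shape   # (1, n_features)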


def ensemble_dataset_predictions(models, dataset, device):
    # For each datapoint, get the predictions of each model
    predictions = {}
    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        # Preprocess data
        data = preprocess_data(data, device)
        # predictions maps each image id to a tuple: the target as the first
        # element and the per-model predictions dictionary as the second
        predictions[i] = (
            target.detach().cpu().numpy(),
            get_model_predictions(models, data),
        )
    return predictions
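
# The returned structure, sketched with hypothetical values for a 2-class
# ensemble of two models keyed '1' and '2':
#     {
#         0: (np.array([1., 0.]),                  # one-hot target
#             {'1': np.array([[0.9, 0.1]]),        # per-model class scores
#              '2': np.array([[0.7, 0.3]])}),
#         ...
#     }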


# Given a dictionary of predictions, select one model and eliminate the rest
def select_individual_model(predictions, model_id):
    selected_model_predictions = {}
    for key, value in predictions.items():
        selected_model_predictions[key] = (
            value[0],
            {model_id: value[1][str(model_id)]},
        )
    return selected_model_predictions


# Given a dictionary of predictions, select a subset of models and eliminate the rest
# predictions is a dictionary of the form {data_id: (target, {model_id: prediction})}
def select_subset_models(predictions, model_ids):
    selected_model_predictions = {}
    for key, value in predictions.items():
        target = value[0]
        model_predictions = value[1]
        # Filter the model predictions, keeping only the selected models;
        # the prediction keys are 1-based strings, so shift the 0-based ids
        selected_model_predictions[key] = (
            target,
            {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
        )
    return selected_model_predictions
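
# Example: because of the 1-based key shift, selecting the 0-based model ids
# [0, 2] keeps the predictions stored under the keys '1' and '3':
#     >>> subset = select_subset_models(predictions, [0, 2])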


# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
# Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
def calculate_statistics(predictions):
    # Create a DataFrame with columns for each statistic
    stats_df = pd.DataFrame(
        columns=[
            'mean',
            'stdev',
            'entropy',
            'confidence',
            'correct',
            'predicted',
            'actual',
        ]
    )
    # Loop through each prediction
    for key, value in predictions.items():
        target = value[0]
        model_predictions = list(value[1].values())
        # Calculate the mean and stdev of the predictions across models;
        # for the stdev, keep only the class-1 column
        mean = np.squeeze(np.mean(model_predictions, axis=0))
        stdev = np.squeeze(np.std(model_predictions, axis=0))[1]
        # Calculate the entropy of the mean prediction
        entropy = met.entropy(mean)
        # Calculate confidence by rescaling the max class probability from
        # [0.5, 1] to [0, 1]
        confidence = (np.max(mean) - 0.5) * 2
        # Calculate predicted and actual classes
        predicted = np.argmax(mean)
        actual = np.argmax(target)
        # Determine if the prediction is correct
        correct = predicted == actual
        # Add the statistics to the dataframe
        stats_df.loc[key] = [
            mean,
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]
    return stats_df
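
# A worked example of the per-datapoint arithmetic, assuming two models and a
# 2-class softmax (values hypothetical):
#     model outputs [[0.9, 0.1]] and [[0.7, 0.3]]
#     mean       -> [0.8, 0.2]
#     stdev      -> std([0.1, 0.3]) = 0.1        (class-1 column only)
#     confidence -> (max(mean) - 0.5) * 2 = 0.6  (rescales [0.5, 1] to [0, 1])
#     predicted  -> argmax(mean) = 0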


# Takes in a dataframe of the form {data_id: statistic, ...} and sweeps
# percentile thresholds over the named statistic
# Output of the form DataFrame(index=percentile, columns=[accuracy, f1])
def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
    # 18 evenly spaced quantiles of the statistic, from the 5th to the 95th
    # percentile
    percentile_df = statistics[statistic_name].quantile(
        q=np.linspace(0.05, 0.95, num=18)
    )
    # DataFrame of the form {percentile: {metric: value}}
    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
    for percentile, value in percentile_df.items():
        # Keep only the datapoints on the 'certain' side of the threshold
        if low_to_high:
            filtered_statistics = statistics[statistics[statistic_name] < value]
        else:
            filtered_statistics = statistics[statistics[statistic_name] >= value]
        # Calculate accuracy and F1 score on the retained subset
        accuracy = filtered_statistics['correct'].mean()
        predicted = filtered_statistics['predicted'].values
        actual = filtered_statistics['actual'].values
        f1 = metrics.f1_score(actual, predicted)
        # Add the metrics to the dataframe
        thresholds_pd.loc[percentile] = [accuracy, f1]
    return thresholds_pd
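
# Example reading of the output, with hypothetical numbers: for 'stdev' with
# low_to_high=True, a row near percentile 0.5 with accuracy 0.95 means that
# keeping only the half of the datapoints whose ensemble stdev falls below the
# median yields 95% accuracy on the retained subset.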


# Takes a Series of the form {threshold: value} for a given statistic's metric
# and plots the metric against the threshold.
# Can plot an additional line if given (used for individual model results)
def plot_threshold_analysis(
    thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
):
    # Initialize the plot
    fig, ax = plt.subplots()
    # Get the thresholds and metric values
    thresholds = list(thresholds_metric.index)
    metric = list(thresholds_metric.values)
    # Plot the metric against the threshold
    plt.plot(thresholds, metric, 'bo-', label='Ensemble')
    if additional_set is not None:
        # Same again for the additional (individual model) set
        thresholds = list(additional_set.index)
        metric = list(additional_set.values)
        plt.plot(thresholds, metric, 'rx-', label='Individual')
    if flip:
        ax.invert_xaxis()
    # Add labels; the x axis holds percentile levels, so format it as percent
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    plt.savefig(path)
    plt.close()


# Code from https://stackoverflow.com/questions/16458340
# Yields, for each key shared by all dictionaries, a tuple of the key followed
# by its value in each dictionary
def common_entries(*dcts):
    if not dcts:
        return
    for i in set(dcts[0]).intersection(*dcts[1:]):
        yield (i,) + tuple(d[i] for d in dcts)
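
# Doctest-style example (this helper is not called in this script, but is
# kept as a general utility):
#     >>> list(common_entries({'a': 1, 'b': 2}, {'b': 3, 'c': 4}))
#     [('b', 2, 3)]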


# Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
def calculate_overall_statistics(ensemble_statistics):
    predicted = ensemble_statistics['predicted']
    actual = ensemble_statistics['actual']
    # New dataframe to store the statistics
    stats_df = pd.DataFrame(
        columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
    ).set_index('stat')
    # Loop through and calculate the ECE, MCE, Brier Score, and NLL
    for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
        ece = met.ECE(predicted, ensemble_statistics[stat], actual)
        mce = met.MCE(predicted, ensemble_statistics[stat], actual)
        brier = met.brier_binary(ensemble_statistics[stat], actual)
        nll = met.nll_binary(ensemble_statistics[stat], actual)
        stats_df.loc[stat] = [ece, mce, brier, nll]
    return stats_df
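
# For reference, the standard expected calibration error that met.ECE is
# presumed to implement (the exact binning lives in utils.metrics):
#     ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)|
# where the B_m are confidence bins, acc(B_m) is the accuracy within a bin,
# and conf(B_m) is the mean confidence within it; MCE takes the max over bins
# instead of the weighted sum.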


# CONFIGURATION
def load_config():
    # Read the config file pointed to by ADL_CONFIG_PATH, falling back to the
    # local config.toml
    config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
    with open(config_path, 'rb') as f:
        config = toml.load(f)
    return config
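
# A minimal config.toml sketch covering the keys read in main() below; the
# values are placeholders:
#     [paths]
#     model_output = 'models/output/'
#
#     [training]
#     device = 'cuda:0'
#
#     [ensemble]
#     name = 'ensemble_1'
#     run_models = true
#     individual_id = 1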


def prune_dataset(dataset, pruned_ids):
    pruned_dataset = []
    for i, (data, target) in enumerate(dataset):
        if i not in pruned_ids:
            pruned_dataset.append((data, target))
    return pruned_dataset
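
# Example (hypothetical ids): drop two datapoints by their positional index:
#     >>> pruned = prune_dataset(dataset, {279, 202})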


def main():
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + '/v3'

    # Create the output directory if it does not exist
    if not os.path.exists(V3_PATH):
        os.makedirs(V3_PATH)

    # Load the models
    device = torch.device(config['training']['device'])
    models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

    # Load the dataset (test and validation sets combined)
    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    if config['ensemble']['run_models']:
        # Get the predictions of the ensemble
        ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)

        # Save to file using pickle
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
            ensemble_predictions = pk.load(f)

    # Get the statistics and thresholds of the ensemble
    ensemble_statistics = calculate_statistics(ensemble_predictions)
    stdev_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'stdev', low_to_high=True
    )
    entropy_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'entropy', low_to_high=True
    )
    confidence_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'confidence', low_to_high=False
    )

    # Invert the confidence rescaling to recover the raw max class probability
    raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
    ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)

    # Plot confidence vs standard deviation
    plot_statistics_versus(
        'raw_confidence',
        'stdev',
        'Confidence',
        'Standard Deviation',
        'Confidence vs Standard Deviation',
        ensemble_statistics,
        f'{V3_PATH}/confidence_vs_stdev.png',
        annotate=True,
    )

    # Plot images - 3 weird and 3 normal, selected from the confidence vs
    # stdev plot
    plot_image_grid(
        [279, 202, 28, 107, 27, 121],
        dataset,
        2,
        f'{V3_PATH}/image_grid.png',
        titles=[
            'Weird: 279',
            'Weird: 202',
            'Weird: 28',
            'Normal: 107',
            'Normal: 27',
            'Normal: 121',
        ],
    )

    # Split the dataset into 'weird' results (confidence < .7 and stdev < .1)
    # and everything else ('normal')
    weird_mask = (ensemble_statistics['raw_confidence'] < 0.7) & (
        ensemble_statistics['stdev'] < 0.1
    )
    weird_results = ensemble_statistics.loc[weird_mask]
    normal_results = ensemble_statistics.loc[~weird_mask]

    # Plot the images, making sure both output directories exist
    os.makedirs(f'{V3_PATH}/images/weird', exist_ok=True)
    os.makedirs(f'{V3_PATH}/images/normal', exist_ok=True)

    for row in weird_results.itertuples():
        data_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev
        plot_single_image(
            data_id,
            dataset,
            f'{V3_PATH}/images/weird/{data_id}.png',
            title=f'ID: {data_id}, Confidence: {conf}, Stdev: {stdev}',
        )

    for row in normal_results.itertuples():
        data_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev
        plot_single_image(
            data_id,
            dataset,
            f'{V3_PATH}/images/normal/{data_id}.png',
            title=f'ID: {data_id}, Confidence: {conf}, Stdev: {stdev}',
        )

    # Calculate and print overall statistics
    overall_statistics = calculate_overall_statistics(ensemble_statistics)
    print(overall_statistics)

    # Print overall ensemble statistics
    print('Ensemble Statistics')
    print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
    print(
        f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
    )

    # Get the predictions, statistics and thresholds of an individual model
    indv_id = config['ensemble']['individual_id']
    indv_predictions = select_individual_model(ensemble_predictions, indv_id)
    indv_statistics = calculate_statistics(indv_predictions)

    # Calculate entropy and confidence thresholds for the individual model
    indv_entropy_thresholds = conduct_threshold_analysis(
        indv_statistics, 'entropy', low_to_high=True
    )
    indv_confidence_thresholds = conduct_threshold_analysis(
        indv_statistics, 'confidence', low_to_high=False
    )

    # Plot the threshold analysis for standard deviation
    plot_threshold_analysis(
        stdev_thresholds['accuracy'],
        'Stdev Threshold Analysis for Accuracy',
        'Stdev Threshold',
        'Accuracy',
        f'{V3_PATH}/stdev_threshold_analysis.png',
        flip=True,
    )
    plot_threshold_analysis(
        stdev_thresholds['f1'],
        'Stdev Threshold Analysis for F1 Score',
        'Stdev Threshold',
        'F1 Score',
        f'{V3_PATH}/stdev_threshold_analysis_f1.png',
        flip=True,
    )

    # Plot the threshold analysis for entropy
    plot_threshold_analysis(
        entropy_thresholds['accuracy'],
        'Entropy Threshold Analysis for Accuracy',
        'Entropy Threshold',
        'Accuracy',
        f'{V3_PATH}/entropy_threshold_analysis.png',
        indv_entropy_thresholds['accuracy'],
        flip=True,
    )
    plot_threshold_analysis(
        entropy_thresholds['f1'],
        'Entropy Threshold Analysis for F1 Score',
        'Entropy Threshold',
        'F1 Score',
        f'{V3_PATH}/entropy_threshold_analysis_f1.png',
        indv_entropy_thresholds['f1'],
        flip=True,
    )

    # Plot the threshold analysis for confidence
    plot_threshold_analysis(
        confidence_thresholds['accuracy'],
        'Confidence Threshold Analysis for Accuracy',
        'Confidence Threshold',
        'Accuracy',
        f'{V3_PATH}/confidence_threshold_analysis.png',
        indv_confidence_thresholds['accuracy'],
    )
    plot_threshold_analysis(
        confidence_thresholds['f1'],
        'Confidence Threshold Analysis for F1 Score',
        'Confidence Threshold',
        'F1 Score',
        f'{V3_PATH}/confidence_threshold_analysis_f1.png',
        indv_confidence_thresholds['f1'],
    )


if __name__ == '__main__':
    main()