# threshold_refac.py

import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm
import utils.metrics as met
import itertools as it
import matplotlib.ticker as ticker
import glob
import pickle as pk
import warnings
import random as rand

warnings.filterwarnings('error')

def plot_image_grid(image_ids, dataset, rows, path, titles=None):
    fig, axs = plt.subplots(rows, len(image_ids) // rows)
    for i, ax in enumerate(axs.flat):
        image_id = image_ids[i]
        image = dataset[image_id][0][0].squeeze().cpu().numpy()

        # The image is 3D with shape (91, 109, 91); take a slice from the middle
        image = image[:, :, 45]

        ax.imshow(image, cmap='gray')
        ax.axis('off')
        if titles is not None:
            ax.set_title(titles[i])

    plt.savefig(path)
    plt.close()

def plot_single_image(image_id, dataset, path, title=None):
    fig, ax = plt.subplots()
    image = dataset[image_id][0][0].squeeze().cpu().numpy()

    # The image is 3D with shape (91, 109, 91); take a slice from the middle
    image = image[:, :, 45]

    ax.imshow(image, cmap='gray')
    ax.axis('off')
    if title is not None:
        ax.set_title(title)

    plt.savefig(path)
    plt.close()

# Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)}, plot the
# two statistics against each other, colored by correctness
def plot_statistics_versus(
    stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
):
    # Split into correct and incorrect predictions
    corr_df = dataframe[dataframe['correct']]
    incorr_df = dataframe[~dataframe['correct']]

    # Plot the correct and incorrect predictions
    fig, ax = plt.subplots()
    ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
    ax.legend()
    ax.set_xlabel(xaxis_name)
    ax.set_ylabel(yaxis_name)
    ax.set_title(title)

    if annotate:
        # Label each point with its data id
        for row in dataframe[[stat_1, stat_2]].itertuples():
            plt.text(row[1], row[2], row[0], fontsize=6, color='black')

    plt.savefig(path)
    plt.close()

# models is a dictionary with the model ids as keys and the models as values
def get_model_predictions(models, data):
    predictions = {}
    for model_id, model in models.items():
        model.eval()
        with torch.no_grad():
            output = model(data)
            predictions[model_id] = output.detach().cpu().numpy()

    return predictions

def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        # The model id is the file name prefix before the first underscore
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)

    return model_dict

# Ensures that both the mri and xls tensors are unsqueezed (given a leading
# batch dimension) and moved to the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)

def ensemble_dataset_predictions(models, dataset, device):
    # For each datapoint, get the predictions of each model
    predictions = {}
    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        # Preprocess data
        data = preprocess_data(data, device)

        # predictions is a dictionary keyed by image id; each value is a tuple of
        # the target and the per-model predictions dictionary
        predictions[i] = (
            target.detach().cpu().numpy(),
            get_model_predictions(models, data),
        )

    return predictions

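# An illustrative sketch (not executed) of the structure returned by
# ensemble_dataset_predictions, assuming two models on a binary task; the
# model ids ('1', '2') and the class count are placeholders for this example:
#
#   {0: (np.array([0.0, 1.0]),                 # one-hot target
#        {'1': np.array([[0.3, 0.7]]),         # model 1 output
#         '2': np.array([[0.4, 0.6]])}),       # model 2 output
#    ...}
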
# Given a dictionary of predictions, select one model and eliminate the rest
def select_individual_model(predictions, model_id):
    selected_model_predictions = {}
    for key, value in predictions.items():
        selected_model_predictions[key] = (
            value[0],
            {model_id: value[1][str(model_id)]},
        )

    return selected_model_predictions

# Given a dictionary of predictions, select a subset of models and eliminate the rest
# predictions is a dictionary of the form {data_id: (target, {model_id: prediction})}
def select_subset_models(predictions, model_ids):
    selected_model_predictions = {}
    for key, value in predictions.items():
        target = value[0]
        model_predictions = value[1]

        # Filter the model predictions, keeping only the selected models
        # (the prediction dictionary is keyed by 1-based string ids, hence the offset)
        selected_model_predictions[key] = (
            target,
            {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
        )

    return selected_model_predictions

# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
# Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
def calculate_statistics(predictions):
    # Create DataFrame with columns for each statistic
    stats_df = pd.DataFrame(
        columns=[
            'mean',
            'stdev',
            'entropy',
            'confidence',
            'correct',
            'predicted',
            'actual',
        ]
    )

    # Loop through each prediction
    for key, value in predictions.items():
        target = value[0]
        model_predictions = list(value[1].values())

        # Mean prediction across models, and stdev of the positive-class output
        mean = np.squeeze(np.mean(model_predictions, axis=0))
        stdev = np.squeeze(np.std(model_predictions, axis=0))[1]

        # Entropy of the mean prediction
        entropy = met.entropy(mean)

        # Confidence: rescale the winning class probability from [0.5, 1] to [0, 1]
        confidence = (np.max(mean) - 0.5) * 2

        # Predicted and actual class labels
        predicted = np.argmax(mean)
        actual = np.argmax(target)

        # Determine if the prediction is correct
        correct = predicted == actual

        # Add the statistics to the dataframe
        stats_df.loc[key] = [
            mean,
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return stats_df

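# A worked sketch of the statistics, assuming the two-model binary example
# above (mean = [0.35, 0.65], positive-class outputs 0.7 and 0.6); the numbers
# are illustrative, not outputs of the real pipeline:
#
#   predicted  = argmax([0.35, 0.65]) = 1
#   stdev      = std([0.7, 0.6]) = 0.05
#   confidence = (0.65 - 0.5) * 2 = 0.3
#   correct    = (predicted == argmax(target))
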
# Takes in a dataframe of the form {data_id: statistic, ...} and calculates the thresholds for the statistic
# Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
    # Series of percentile values of the statistic, from the 5th to the 95th percentile
    percentile_df = statistics[statistic_name].quantile(
        q=np.linspace(0.05, 0.95, num=18)
    )

    # DataFrame of the form {threshold: {metric: value}}
    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
    for percentile, value in percentile_df.items():
        # Keep only the datapoints on the "good" side of the threshold
        if low_to_high:
            filtered_statistics = statistics[statistics[statistic_name] < value]
        else:
            filtered_statistics = statistics[statistics[statistic_name] >= value]

        # Calculate accuracy and F1 score on the retained datapoints
        accuracy = filtered_statistics['correct'].mean()

        predicted = filtered_statistics['predicted'].values
        actual = filtered_statistics['actual'].values
        f1 = metrics.f1_score(actual, predicted)

        # Add the metrics to the dataframe
        thresholds_pd.loc[percentile] = [accuracy, f1]

    return thresholds_pd

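# Illustrative usage (a sketch; `stats` is assumed to come from
# calculate_statistics):
#
#   stats = calculate_statistics(ensemble_predictions)
#   stdev_thresholds = conduct_threshold_analysis(stats, 'stdev', low_to_high=True)
#   # stdev_thresholds.loc[0.05, 'accuracy'] is then the accuracy over the
#   # (roughly) 5% of datapoints with the lowest standard deviation.
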
# Takes a series of metric values indexed by threshold (one column of the
# conduct_threshold_analysis output) and plots the metric against the threshold.
# Can plot an additional line if given (used for individual results)
def plot_threshold_analysis(
    thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
):
    # Initialize the plot
    fig, ax = plt.subplots()

    # Get the thresholds and metrics
    thresholds = list(thresholds_metric.index)
    metric = list(thresholds_metric.values)

    # Plot the metric against the threshold
    plt.plot(thresholds, metric, 'bo-', label='Ensemble')

    if additional_set is not None:
        # Plot the individual-model line alongside the ensemble
        thresholds = list(additional_set.index)
        metric = list(additional_set.values)
        plt.plot(thresholds, metric, 'rx-', label='Individual')

    if flip:
        ax.invert_xaxis()

    # Add labels
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    plt.savefig(path)
    plt.close()

# Code from https://stackoverflow.com/questions/16458340
# Yields (key, value_1, ..., value_n) tuples for the keys common to all dictionaries
def common_entries(*dcts):
    if not dcts:
        return
    for i in set(dcts[0]).intersection(*dcts[1:]):
        yield (i,) + tuple(d[i] for d in dcts)

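# Illustrative usage (a sketch):
#
#   list(common_entries({'a': 1, 'b': 2}, {'a': 3}))  # -> [('a', 1, 3)]
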
# Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
def calculate_overall_statistics(ensemble_statistics):
    predicted = ensemble_statistics['predicted']
    actual = ensemble_statistics['actual']

    # New dataframe to store the statistics
    stats_df = pd.DataFrame(
        columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
    ).set_index('stat')

    # Calculate the ECE, MCE, Brier Score, and NLL for each uncertainty statistic
    for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
        ece = met.ECE(predicted, ensemble_statistics[stat], actual)
        mce = met.MCE(predicted, ensemble_statistics[stat], actual)
        brier = met.brier_binary(ensemble_statistics[stat], actual)
        nll = met.nll_binary(ensemble_statistics[stat], actual)

        stats_df.loc[stat] = [ece, mce, brier, nll]

    return stats_df

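# Illustrative shape of the returned table (values are placeholders, one row
# per uncertainty statistic):
#
#                    ECE   MCE   Brier Score   NLL
#   confidence       ...   ...       ...       ...
#   entropy          ...   ...       ...       ...
#   stdev            ...   ...       ...       ...
#   raw_confidence   ...   ...       ...       ...
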
# CONFIGURATION
def load_config():
    # Use the config file pointed to by ADL_CONFIG_PATH, falling back to ./config.toml
    config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
    with open(config_path, 'rb') as f:
        config = toml.load(f)

    return config

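# A minimal config.toml sketch covering only the keys this script reads; the
# values are placeholders, not the real experiment configuration:
#
#   [paths]
#   model_output = 'out/'
#
#   [ensemble]
#   name = 'ensemble_1'
#   run_models = true
#   individual_id = 1
#
#   [training]
#   device = 'cuda'
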
def prune_dataset(dataset, pruned_ids):
    pruned_dataset = []
    for i, (data, target) in enumerate(dataset):
        if i not in pruned_ids:
            pruned_dataset.append((data, target))

    return pruned_dataset

def main():
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + '/v3'

    # Create the output directory if it does not exist
    if not os.path.exists(V3_PATH):
        os.makedirs(V3_PATH)

    # Load the models
    device = torch.device(config['training']['device'])
    models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

    # Load the dataset (test and validation splits combined)
    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    if config['ensemble']['run_models']:
        # Get the predictions of the ensemble
        ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)

        # Save to file using pickle
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
            ensemble_predictions = pk.load(f)

    # Get the statistics and thresholds of the ensemble
    ensemble_statistics = calculate_statistics(ensemble_predictions)
    stdev_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'stdev', low_to_high=True
    )
    entropy_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'entropy', low_to_high=True
    )
    confidence_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'confidence', low_to_high=False
    )

    # Invert the confidence rescaling to recover the raw winning-class probability
    raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
    ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)

    # Plot confidence vs standard deviation
    plot_statistics_versus(
        'raw_confidence',
        'stdev',
        'Confidence',
        'Standard Deviation',
        'Confidence vs Standard Deviation',
        ensemble_statistics,
        f'{V3_PATH}/confidence_vs_stdev.png',
        annotate=True,
    )

    # Plot images - 3 weird and 3 normal
    # Selected from the confidence vs stdev plot
    plot_image_grid(
        [279, 202, 28, 107, 27, 121],
        dataset,
        2,
        f'{V3_PATH}/image_grid.png',
        titles=[
            'Weird: 279',
            'Weird: 202',
            'Weird: 28',
            'Normal: 107',
            'Normal: 27',
            'Normal: 121',
        ],
    )

    # Split the dataset on confidence < .7 and stdev < .1
    weird_results = ensemble_statistics.loc[
        (
            (ensemble_statistics['raw_confidence'] < 0.7)
            & (ensemble_statistics['stdev'] < 0.1)
        )
    ]
    normal_results = ensemble_statistics.loc[
        ~(
            (ensemble_statistics['raw_confidence'] < 0.7)
            & (ensemble_statistics['stdev'] < 0.1)
        )
    ]

    # Plot the images
    if not os.path.exists(f'{V3_PATH}/images'):
        os.makedirs(f'{V3_PATH}/images/weird')
        os.makedirs(f'{V3_PATH}/images/normal')

    for row in weird_results.itertuples():
        data_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev

        plot_single_image(
            data_id,
            dataset,
            f'{V3_PATH}/images/weird/{data_id}.png',
            title=f'ID: {data_id}, Confidence: {conf}, Stdev: {stdev}',
        )

    for row in normal_results.itertuples():
        data_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev

        plot_single_image(
            data_id,
            dataset,
            f'{V3_PATH}/images/normal/{data_id}.png',
            title=f'ID: {data_id}, Confidence: {conf}, Stdev: {stdev}',
        )

    # Calculate and print overall statistics
    overall_statistics = calculate_overall_statistics(ensemble_statistics)
    print(overall_statistics)

    # Print overall ensemble statistics
    print('Ensemble Statistics')
    print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
    print(
        f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
    )

    # Get the predictions, statistics and thresholds of an individual model
    indv_id = config['ensemble']['individual_id']
    indv_predictions = select_individual_model(ensemble_predictions, indv_id)
    indv_statistics = calculate_statistics(indv_predictions)

    # Calculate entropy and confidence thresholds for the individual model
    indv_entropy_thresholds = conduct_threshold_analysis(
        indv_statistics, 'entropy', low_to_high=True
    )
    indv_confidence_thresholds = conduct_threshold_analysis(
        indv_statistics, 'confidence', low_to_high=False
    )

    # Plot the threshold analysis for standard deviation
    plot_threshold_analysis(
        stdev_thresholds['accuracy'],
        'Stdev Threshold Analysis for Accuracy',
        'Stdev Threshold',
        'Accuracy',
        f'{V3_PATH}/stdev_threshold_analysis.png',
        flip=True,
    )
    plot_threshold_analysis(
        stdev_thresholds['f1'],
        'Stdev Threshold Analysis for F1 Score',
        'Stdev Threshold',
        'F1 Score',
        f'{V3_PATH}/stdev_threshold_analysis_f1.png',
        flip=True,
    )

    # Plot the threshold analysis for entropy
    plot_threshold_analysis(
        entropy_thresholds['accuracy'],
        'Entropy Threshold Analysis for Accuracy',
        'Entropy Threshold',
        'Accuracy',
        f'{V3_PATH}/entropy_threshold_analysis.png',
        indv_entropy_thresholds['accuracy'],
        flip=True,
    )
    plot_threshold_analysis(
        entropy_thresholds['f1'],
        'Entropy Threshold Analysis for F1 Score',
        'Entropy Threshold',
        'F1 Score',
        f'{V3_PATH}/entropy_threshold_analysis_f1.png',
        indv_entropy_thresholds['f1'],
        flip=True,
    )

    # Plot the threshold analysis for confidence
    plot_threshold_analysis(
        confidence_thresholds['accuracy'],
        'Confidence Threshold Analysis for Accuracy',
        'Confidence Threshold',
        'Accuracy',
        f'{V3_PATH}/confidence_threshold_analysis.png',
        indv_confidence_thresholds['accuracy'],
    )
    plot_threshold_analysis(
        confidence_thresholds['f1'],
        'Confidence Threshold Analysis for F1 Score',
        'Confidence Threshold',
        'F1 Score',
        f'{V3_PATH}/confidence_threshold_analysis_f1.png',
        indv_confidence_thresholds['f1'],
    )

if __name__ == '__main__':
    main()