# threshold_refac.py

import glob
import os
import pickle as pk
import warnings

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import pandas._libs.internals
import pandas.core.frame
import pandas.core.internals.managers
import sklearn.metrics as metrics
import tomli as toml
import torch
import torch.serialization
from tqdm import tqdm

import utils.data.datasets
import utils.metrics as met
import utils.models.cnn
import utils.models.layers

# Allowlist every class needed to unpickle the saved models and datasets when
# torch.load runs in weights_only safe-loading mode
torch.serialization.add_safe_globals(
    [
        torch.nn.modules.linear.Linear,
        torch.nn.modules.batchnorm.BatchNorm3d,
        torch.nn.modules.container.Sequential,
        torch.nn.modules.activation.ELU,
        utils.models.layers.ConvBlock,
        utils.models.layers.FullConnBlock,
        utils.models.layers.SplitConvBlock,
        utils.models.layers.SepConv3d,
        utils.models.cnn.CNN_Image_Section,
        torch.nn.modules.dropout.Dropout,
        torch.nn.modules.conv.Conv3d,
        torch.nn.modules.batchnorm.BatchNorm1d,
        utils.models.layers.MidFlowBlock,
        torch.nn.modules.activation.Softmax,
        utils.models.layers.SepConvBlock,
        utils.models.cnn.CNN,
        utils.data.datasets.ADNIDataset,
        pandas.core.frame.DataFrame,
        pandas.core.internals.managers.BlockManager,
        pandas._libs.internals._unpickle_block,
    ]
)

# Promote warnings to errors so that numerical issues surface immediately
warnings.filterwarnings("error")

def plot_image_grid(image_ids, dataset, rows, path, titles=None):
    fig, axs = plt.subplots(rows, len(image_ids) // rows)
    for i, ax in enumerate(axs.flat):
        image_id = image_ids[i]
        image = dataset[image_id][0][0].squeeze().cpu().numpy()

        # The image is a 3D volume of shape (91, 109, 91); take a slice from the middle
        image = image[:, :, 45]
        ax.imshow(image, cmap="gray")
        ax.axis("off")
        if titles is not None:
            ax.set_title(titles[i])

    plt.savefig(path)
    plt.close()

def plot_single_image(image_id, dataset, path, title=None):
    fig, ax = plt.subplots()
    image = dataset[image_id][0][0].squeeze().cpu().numpy()

    # The image is a 3D volume of shape (91, 109, 91); take a slice from the middle
    image = image[:, :, 45]
    ax.imshow(image, cmap="gray")
    ax.axis("off")
    if title is not None:
        ax.set_title(title)

    plt.savefig(path)
    plt.close()

# Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)}, plot the
# two statistics against each other and color by correctness
def plot_statistics_versus(
    stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
):
    # Split the dataframe into correct and incorrect predictions
    corr_df = dataframe[dataframe["correct"]]
    incorr_df = dataframe[~dataframe["correct"]]

    # Plot the correct and incorrect predictions
    fig, ax = plt.subplots()
    ax.scatter(corr_df[stat_1], corr_df[stat_2], c="green", label="Correct")
    ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c="red", label="Incorrect")
    ax.legend()
    ax.set_xlabel(xaxis_name)
    ax.set_ylabel(yaxis_name)
    ax.set_title(title)

    if annotate:
        # Label every point with its data id
        for row in dataframe[[stat_1, stat_2]].itertuples():
            plt.text(row[1], row[2], row[0], fontsize=6, color="black")

    plt.savefig(path)
    plt.close()

# models is a dictionary with the model ids as keys and the model objects as values
def get_model_predictions(models, data):
    predictions = {}
    for model_id, model in models.items():
        model.eval()
        with torch.no_grad():
            # Run the preprocessed (mri, xls) tuple through the model
            output = model(data)
            predictions[model_id] = output.detach().cpu().numpy()
    return predictions

def load_models_v2(folder, device):
    glob_path = os.path.join(folder, "*.pt")
    model_files = glob.glob(glob_path)
    model_dict = {}
    for model_file in model_files:
        print(model_file)
        model = torch.load(model_file, map_location=device)
        model_id = os.path.basename(model_file).split("_")[0]
        model_dict[model_id] = model
    if len(model_dict) == 0:
        raise FileNotFoundError("No models found in the specified directory: " + folder)
    return model_dict

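# Note on load_models_v2: the model id is everything in the checkpoint file name up
# to the first underscore, so a file named e.g. "1_model.pt" (hypothetical naming)
# is keyed by the string "1", not the int 1.
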
# Ensure that both the mri and xls tensors gain a leading batch dimension and are
# moved to the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)

def ensemble_dataset_predictions(models, dataset, device):
    # For each datapoint, get the predictions of each model
    predictions = {}
    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        # Preprocess data
        data = preprocess_data(data, device)

        # predictions maps each image id to a tuple of the target and the
        # dictionary of per-model predictions
        predictions[i] = (
            target.detach().cpu().numpy(),
            get_model_predictions(models, data),
        )
    return predictions

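# The returned structure looks like (hypothetical values, one-hot binary targets):
#   {0: (array([1., 0.]), {"1": array([[0.9, 0.1]]), "2": array([[0.8, 0.2]])}), ...}
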
# Given a dictionary of predictions, select one model and eliminate the rest
def select_individual_model(predictions, model_id):
    selected_model_predictions = {}
    for key, value in predictions.items():
        selected_model_predictions[key] = (
            value[0],
            {model_id: value[1][str(model_id)]},
        )
    return selected_model_predictions

# Given a dictionary of predictions, select a subset of models and eliminate the rest
# predictions is a dictionary of the form {data_id: (target, {model_id: prediction})}
def select_subset_models(predictions, model_ids):
    selected_model_predictions = {}
    for key, value in predictions.items():
        target = value[0]
        model_predictions = value[1]

        # Keep only the selected models; the prediction keys are 1-indexed strings,
        # while the requested model_ids are 0-indexed
        selected_model_predictions[key] = (
            target,
            {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
        )
    return selected_model_predictions

# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
# Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
def calculate_statistics(predictions):
    # Create DataFrame with columns for each statistic
    stats_df = pd.DataFrame(
        columns=[
            "mean",
            "stdev",
            "entropy",
            "confidence",
            "correct",
            "predicted",
            "actual",
        ]
    )

    # Loop through each prediction
    for key, value in predictions.items():
        target = value[0]
        model_predictions = list(value[1].values())

        # Calculate the mean of the predictions and, for the binary case, the
        # stdev of the positive-class probability across models
        mean = np.squeeze(np.mean(model_predictions, axis=0))
        stdev = np.squeeze(np.std(model_predictions, axis=0))[1]

        # Calculate the entropy of the mean prediction
        entropy = met.entropy(mean)

        # Calculate confidence: map the winning class probability from [0.5, 1] to [0, 1]
        confidence = (np.max(mean) - 0.5) * 2

        # Calculate the predicted and actual labels
        predicted = np.argmax(mean)
        actual = np.argmax(target)

        # Determine if the prediction is correct
        correct = predicted == actual

        # Add the statistics to the dataframe
        stats_df.loc[key] = [
            mean,
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return stats_df

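# A worked example of the statistics above, with hypothetical numbers: for three
# binary model outputs [0.8, 0.2], [0.6, 0.4], and [0.7, 0.3], the mean is
# [0.7, 0.3], the positive-class stdev is ~0.082, the predicted class is 0, and
# the confidence is (0.7 - 0.5) * 2 = 0.4.
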
# Takes in a dataframe of the form {data_id: statistic, ...} and calculates metrics
# at a range of thresholds for the statistic
# Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
    # Series of candidate thresholds: 18 evenly spaced quantiles of the statistic,
    # from the 5th to the 95th percentile
    percentile_df = statistics[statistic_name].quantile(
        q=np.linspace(0.05, 0.95, num=18)
    )

    # DataFrame indexed by quantile level, holding accuracy and F1 at each threshold
    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=["accuracy", "f1"])
    for percentile, value in percentile_df.items():
        # Keep only the predictions on the retained side of the threshold
        if low_to_high:
            filtered_statistics = statistics[statistics[statistic_name] < value]
        else:
            filtered_statistics = statistics[statistics[statistic_name] >= value]

        # Calculate accuracy and F1 score on the retained predictions
        accuracy = filtered_statistics["correct"].mean()
        predicted = filtered_statistics["predicted"].values
        actual = filtered_statistics["actual"].values
        f1 = metrics.f1_score(actual, predicted)

        # Add the metrics to the dataframe
        thresholds_pd.loc[percentile] = [accuracy, f1]
    return thresholds_pd

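# Each row of the returned frame is one selective-prediction operating point: e.g.
# with low_to_high=True on "stdev", only datapoints where the ensemble members
# agree closely are retained, so accuracy and F1 are computed over a smaller but
# presumably easier subset of the data.
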
# Takes a Series (indexed by threshold) of a single metric for a given statistic and
# plots the metric against the threshold.
# Can plot an additional line if given (used for individual results)
def plot_threshold_analysis(
    thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
):
    # Initialize the plot
    fig, ax = plt.subplots()

    # Get the thresholds and metrics
    thresholds = list(thresholds_metric.index)
    metric = list(thresholds_metric.values)

    # Plot the metric against the threshold
    plt.plot(thresholds, metric, "bo-", label="Ensemble")

    if additional_set is not None:
        # Get the thresholds and metrics
        thresholds = list(additional_set.index)
        metric = list(additional_set.values)

        # Plot the metric against the threshold
        plt.plot(thresholds, metric, "rx-", label="Individual")

    if flip:
        ax.invert_xaxis()

    # Add labels
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    plt.savefig(path)
    plt.close()

# Code from https://stackoverflow.com/questions/16458340
# Yields the intersection of multiple dictionaries: each common key together with
# its value in every dictionary
def common_entries(*dcts):
    if not dcts:
        return
    for i in set(dcts[0]).intersection(*dcts[1:]):
        yield (i,) + tuple(d[i] for d in dcts)

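# Example: list(common_entries({"a": 1, "b": 2}, {"a": 3})) yields [("a", 1, 3)]
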
# Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
def calculate_overall_statistics(ensemble_statistics):
    predicted = ensemble_statistics["predicted"]
    actual = ensemble_statistics["actual"]

    # New dataframe to store the statistics
    stats_df = pd.DataFrame(
        columns=["stat", "ECE", "MCE", "Brier Score", "NLL"]
    ).set_index("stat")

    # Loop through and calculate the ECE, MCE, Brier Score, and NLL
    for stat in ["confidence", "entropy", "stdev", "raw_confidence"]:
        ece = met.ECE(predicted, ensemble_statistics[stat], actual)
        mce = met.MCE(predicted, ensemble_statistics[stat], actual)
        brier = met.brier_binary(ensemble_statistics[stat], actual)
        nll = met.nll_binary(ensemble_statistics[stat], actual)
        stats_df.loc[stat] = [ece, mce, brier, nll]

    return stats_df

# CONFIGURATION
def load_config():
    # Use the config file pointed to by ADL_CONFIG_PATH if set, otherwise fall back
    # to the local config.toml
    config_path = os.getenv("ADL_CONFIG_PATH") or "config.toml"
    with open(config_path, "rb") as f:
        config = toml.load(f)
    return config

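# A minimal config.toml sketch covering only the keys this script reads (all
# values are hypothetical):
#
#   [paths]
#   model_output = "out/"
#
#   [ensemble]
#   name = "ensemble_1"
#   run_models = true
#   individual_id = 1
#
#   [training]
#   device = "cuda"
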
def prune_dataset(dataset, pruned_ids):
    pruned_dataset = []
    for i, (data, target) in enumerate(dataset):
        if i not in pruned_ids:
            pruned_dataset.append((data, target))
    return pruned_dataset

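# Usage sketch (hypothetical ids): prune_dataset(dataset, {279, 202}) drops
# datapoints 279 and 202 (e.g. ids flagged in the confidence-vs-stdev plot)
# before re-running the analysis.
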
def main():
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + "/v3"

    # Create the directory if it does not exist
    if not os.path.exists(V3_PATH):
        os.makedirs(V3_PATH)

    # Load the models
    device = torch.device(config["training"]["device"])
    models = load_models_v2(f"{ENSEMBLE_PATH}/models/", device)

    # Load the dataset
    dataset = torch.load(
        f"{ENSEMBLE_PATH}/test_dataset.pt", weights_only=False
    ) + torch.load(f"{ENSEMBLE_PATH}/val_dataset.pt", weights_only=False)

    if config["ensemble"]["run_models"]:
        # Get the predictions of the ensemble
        ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)

        # Save to file using pickle
        with open(f"{V3_PATH}/ensemble_predictions.pk", "wb") as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f"{V3_PATH}/ensemble_predictions.pk", "rb") as f:
            ensemble_predictions = pk.load(f)

    # Get the statistics and thresholds of the ensemble
    ensemble_statistics = calculate_statistics(ensemble_predictions)
    stdev_thresholds = conduct_threshold_analysis(
        ensemble_statistics, "stdev", low_to_high=True
    )
    entropy_thresholds = conduct_threshold_analysis(
        ensemble_statistics, "entropy", low_to_high=True
    )
    confidence_thresholds = conduct_threshold_analysis(
        ensemble_statistics, "confidence", low_to_high=False
    )

    # Invert the (x - 0.5) * 2 mapping to recover the raw winning-class probability
    raw_confidence = ensemble_statistics["confidence"].apply(lambda x: (x / 2) + 0.5)
    ensemble_statistics.insert(4, "raw_confidence", raw_confidence)

    # Plot confidence vs standard deviation
    plot_statistics_versus(
        "raw_confidence",
        "stdev",
        "Confidence",
        "Standard Deviation",
        "Confidence vs Standard Deviation",
        ensemble_statistics,
        f"{V3_PATH}/confidence_vs_stdev.png",
        annotate=True,
    )

    # Plot images - 3 weird and 3 normal, selected from the confidence vs stdev plot
    plot_image_grid(
        [279, 202, 28, 107, 27, 121],
        dataset,
        2,
        f"{V3_PATH}/image_grid.png",
        titles=[
            "Weird: 279",
            "Weird: 202",
            "Weird: 28",
            "Normal: 107",
            "Normal: 27",
            "Normal: 121",
        ],
    )

    # Split the results into "weird" (confidence < .7 and stdev < .1) and normal
    weird_results = ensemble_statistics.loc[
        (
            (ensemble_statistics["raw_confidence"] < 0.7)
            & (ensemble_statistics["stdev"] < 0.1)
        )
    ]
    normal_results = ensemble_statistics.loc[
        ~(
            (ensemble_statistics["raw_confidence"] < 0.7)
            & (ensemble_statistics["stdev"] < 0.1)
        )
    ]

    # Plot the images
    os.makedirs(f"{V3_PATH}/images/weird", exist_ok=True)
    os.makedirs(f"{V3_PATH}/images/normal", exist_ok=True)

    for row in weird_results.itertuples():
        image_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev
        plot_single_image(
            image_id,
            dataset,
            f"{V3_PATH}/images/weird/{image_id}.png",
            title=f"ID: {image_id}, Confidence: {conf}, Stdev: {stdev}",
        )
    for row in normal_results.itertuples():
        image_id = row.Index
        conf = row.raw_confidence
        stdev = row.stdev
        plot_single_image(
            image_id,
            dataset,
            f"{V3_PATH}/images/normal/{image_id}.png",
            title=f"ID: {image_id}, Confidence: {conf}, Stdev: {stdev}",
        )

    # Calculate and print overall statistics
    overall_statistics = calculate_overall_statistics(ensemble_statistics)
    print(overall_statistics)

    # Print overall ensemble statistics
    print("Ensemble Statistics")
    print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
    print(
        f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
    )

    # Get the predictions, statistics and thresholds of an individual model
    indv_id = config["ensemble"]["individual_id"]
    indv_predictions = select_individual_model(ensemble_predictions, indv_id)
    indv_statistics = calculate_statistics(indv_predictions)

    # Calculate entropy and confidence thresholds for the individual model
    indv_entropy_thresholds = conduct_threshold_analysis(
        indv_statistics, "entropy", low_to_high=True
    )
    indv_confidence_thresholds = conduct_threshold_analysis(
        indv_statistics, "confidence", low_to_high=False
    )

    # Plot the threshold analysis for standard deviation
    plot_threshold_analysis(
        stdev_thresholds["accuracy"],
        "Stdev Threshold Analysis for Accuracy",
        "Stdev Threshold",
        "Accuracy",
        f"{V3_PATH}/stdev_threshold_analysis.png",
        flip=True,
    )
    plot_threshold_analysis(
        stdev_thresholds["f1"],
        "Stdev Threshold Analysis for F1 Score",
        "Stdev Threshold",
        "F1 Score",
        f"{V3_PATH}/stdev_threshold_analysis_f1.png",
        flip=True,
    )

    # Plot the threshold analysis for entropy
    plot_threshold_analysis(
        entropy_thresholds["accuracy"],
        "Entropy Threshold Analysis for Accuracy",
        "Entropy Threshold",
        "Accuracy",
        f"{V3_PATH}/entropy_threshold_analysis.png",
        indv_entropy_thresholds["accuracy"],
        flip=True,
    )
    plot_threshold_analysis(
        entropy_thresholds["f1"],
        "Entropy Threshold Analysis for F1 Score",
        "Entropy Threshold",
        "F1 Score",
        f"{V3_PATH}/entropy_threshold_analysis_f1.png",
        indv_entropy_thresholds["f1"],
        flip=True,
    )

    # Plot the threshold analysis for confidence
    plot_threshold_analysis(
        confidence_thresholds["accuracy"],
        "Confidence Threshold Analysis for Accuracy",
        "Confidence Threshold",
        "Accuracy",
        f"{V3_PATH}/confidence_threshold_analysis.png",
        indv_confidence_thresholds["accuracy"],
    )
    plot_threshold_analysis(
        confidence_thresholds["f1"],
        "Confidence Threshold Analysis for F1 Score",
        "Confidence Threshold",
        "F1 Score",
        f"{V3_PATH}/confidence_threshold_analysis_f1.png",
        indv_confidence_thresholds["f1"],
    )


if __name__ == "__main__":
    main()