- import os
- import glob
- import pickle as pk
- import warnings
- import numpy as np
- import pandas as pd
- import tomli as toml
- import torch
- import matplotlib.pyplot as plt
- import matplotlib.ticker as ticker
- import sklearn.metrics as metrics
- from tqdm import tqdm
- from utils.data.datasets import prepare_datasets
- import utils.ensemble as ens
- import utils.metrics as met
- # Treat all warnings as errors so numerical issues fail loudly
- warnings.filterwarnings('error')
- def plot_image_grid(image_ids, dataset, rows, path, titles=None):
- fig, axs = plt.subplots(rows, len(image_ids) // rows)
- for i, ax in enumerate(axs.flat):
- image_id = image_ids[i]
- image = dataset[image_id][0][0].squeeze().cpu().numpy()
- # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
- image = image[:, :, 45]
- ax.imshow(image, cmap='gray')
- ax.axis('off')
- if titles is not None:
- ax.set_title(titles[i])
- plt.savefig(path)
- plt.close()
- def plot_single_image(image_id, dataset, path, title=None):
- fig, ax = plt.subplots()
- image = dataset[image_id][0][0].squeeze().cpu().numpy()
- # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
- image = image[:, :, 45]
- ax.imshow(image, cmap='gray')
- ax.axis('off')
- if title is not None:
- ax.set_title(title)
- plt.savefig(path)
- plt.close()
- # Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)}, plot the two statistics against each other and color by correctness
- def plot_statistics_versus(
- stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
- ):
- # Get correct predictions and incorrect predictions dataframes
- corr_df = dataframe[dataframe['correct']]
- incorr_df = dataframe[~dataframe['correct']]
- # Plot the correct and incorrect predictions
- fig, ax = plt.subplots()
- ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
- ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
- ax.legend()
- ax.set_xlabel(xaxis_name)
- ax.set_ylabel(yaxis_name)
- ax.set_title(title)
- if annotate:
- # Label every point with its data id
- for row in dataframe[[stat_1, stat_2]].itertuples():
- plt.text(row[1], row[2], row[0], fontsize=6, color='black')
- plt.savefig(path)
- plt.close()
- # models is a dictionary with model ids as keys and the loaded models as values
- def get_model_predictions(models, data):
- predictions = {}
- for model_id, model in models.items():
- model.eval()
- with torch.no_grad():
- # Get the predictions
- output = model(data)
- predictions[model_id] = output.detach().cpu().numpy()
- return predictions
- def load_models_v2(folder, device):
- glob_path = os.path.join(folder, '*.pt')
- model_files = glob.glob(glob_path)
- model_dict = {}
- for model_file in model_files:
- model = torch.load(model_file, map_location=device)
- # Model files are assumed to be named '<model_id>_*.pt'
- model_id = os.path.basename(model_file).split('_')[0]
- model_dict[model_id] = model
- if len(model_dict) == 0:
- raise FileNotFoundError('No models found in the specified directory: ' + folder)
- return model_dict
- # Adds a batch dimension to both the mri and xls tensors and moves them to the given device
- def preprocess_data(data, device):
- mri, xls = data
- mri = mri.unsqueeze(0).to(device)
- xls = xls.unsqueeze(0).to(device)
- return (mri, xls)
- def ensemble_dataset_predictions(models, dataset, device):
- # For each datapoint, get the predictions of each model
- predictions = {}
- for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
- # Preprocess data
- data = preprocess_data(data, device)
- # predictions maps each image id to a tuple of the target and the
- # per-model predictions dictionary: {data_id: (target, {model_id: prediction})}
- predictions[i] = (
- target.detach().cpu().numpy(),
- get_model_predictions(models, data),
- )
- return predictions
- # Given a dictionary of predictions, select one model and eliminate the rest
- def select_individual_model(predictions, model_id):
- selected_model_predictions = {}
- for key, value in predictions.items():
- selected_model_predictions[key] = (
- value[0],
- {model_id: value[1][str(model_id)]},
- )
- return selected_model_predictions
- # Given a dictionary of predictions, select a subset of models and eliminate the rest
- # predictions is a dictionary of the form {data_id: (target, {model_id: prediction})}
- def select_subset_models(predictions, model_ids):
- selected_model_predictions = {}
- for key, value in predictions.items():
- target = value[0]
- model_predictions = value[1]
- # Filter the model predictions, keeping only the selected models
- # The saved model keys are assumed to be 1-based strings, hence the +1
- selected_model_predictions[key] = (
- target,
- {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
- )
- return selected_model_predictions
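- # Usage sketch (hypothetical ids): with models saved as '1_model.pt', '2_model.pt',
- # '3_model.pt', select_subset_models(ensemble_predictions, [0, 2]) keeps the
- # predictions of models '1' and '3' for every data id.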
- # Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
- # Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
- def calculate_statistics(predictions):
- # Create DataFrame with columns for each statistic
- stats_df = pd.DataFrame(
- columns=[
- 'mean',
- 'stdev',
- 'entropy',
- 'confidence',
- 'correct',
- 'predicted',
- 'actual',
- ]
- )
- # First, loop through each prediction
- for key, value in predictions.items():
- target = value[0]
- model_predictions = list(value[1].values())
- # Calculate the mean and stdev of the predictions, which are assumed to be
- # softmax outputs of shape (1, 2)
- mean = np.squeeze(np.mean(model_predictions, axis=0))
- # Standard deviation of the positive-class probability across models
- stdev = np.squeeze(np.std(model_predictions, axis=0))[1]
- # Calculate the entropy of the mean prediction
- entropy = met.entropy(mean)
- # Confidence: rescale the max mean probability from [0.5, 1] to [0, 1]
- confidence = (np.max(mean) - 0.5) * 2
- # Calculate predicted and actual
- predicted = np.argmax(mean)
- actual = np.argmax(target)
- # Determine if the prediction is correct
- correct = predicted == actual
- # Add the statistics to the dataframe
- stats_df.loc[key] = [
- mean,
- stdev,
- entropy,
- confidence,
- correct,
- predicted,
- actual,
- ]
- return stats_df
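- # Worked example (hypothetical numbers, assuming met.entropy is the Shannon
- # entropy of the mean distribution):
- #   model A -> [[0.9, 0.1]], model B -> [[0.7, 0.3]]
- #   mean       = [0.8, 0.2]
- #   stdev      = std([0.1, 0.3]) = 0.1            (positive class only)
- #   entropy    = -(0.8*ln 0.8 + 0.2*ln 0.2) ~ 0.50 nats
- #   confidence = (0.8 - 0.5) * 2 = 0.6
- #   predicted  = argmax(mean) = 0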
- # Takes the statistics dataframe and sweeps percentile thresholds over the given statistic
- # Output of the form DataFrame(index=percentile, columns=[accuracy, f1])
- def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
- # Threshold values at every 5th percentile, from the 5th through the 95th
- percentile_df = statistics[statistic_name].quantile(
- q=np.linspace(0.05, 0.95, num=19)
- )
- # DataFrame of the form {percentile: {metric: value}}
- thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
- for percentile, value in percentile_df.items():
- # Filter the statistics
- if low_to_high:
- filtered_statistics = statistics[statistics[statistic_name] < value]
- else:
- filtered_statistics = statistics[statistics[statistic_name] >= value]
- # Calculate accuracy and F1 score on the retained subset
- accuracy = filtered_statistics['correct'].mean()
- predicted = filtered_statistics['predicted'].values
- actual = filtered_statistics['actual'].values
- f1 = metrics.f1_score(actual, predicted)
- # Add the metrics to the dataframe
- thresholds_pd.loc[percentile] = [accuracy, f1]
- return thresholds_pd
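- # Reading the output (illustrative): with low_to_high=True, the 'accuracy' at
- # percentile 0.25 for 'stdev' is the accuracy over the quarter of samples on
- # which the ensemble members disagree least.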
- # Takes a Series of metric values indexed by percentile threshold and plots the metric against the threshold.
- # Can plot an additional line if given (used for individual model results)
- def plot_threshold_analysis(
- thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
- ):
- # Initialize the plot
- fig, ax = plt.subplots()
- # Get the thresholds and metrics
- thresholds = list(thresholds_metric.index)
- metric = list(thresholds_metric.values)
- # Plot the metric against the threshold
- plt.plot(thresholds, metric, 'bo-', label='Ensemble')
- if additional_set is not None:
- # Get the thresholds and metrics
- thresholds = list(additional_set.index)
- metric = list(additional_set.values)
- # Plot the metric against the threshold
- plt.plot(thresholds, metric, 'rx-', label='Individual')
- if flip:
- ax.invert_xaxis()
- # Add labels
- plt.title(title)
- plt.xlabel(x_label)
- plt.ylabel(y_label)
- plt.legend()
- ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
- plt.savefig(path)
- plt.close()
- # Code from https://stackoverflow.com/questions/16458340
- # Yields (key, value_1, ..., value_n) tuples for every key common to all dictionaries
- def common_entries(*dcts):
- if not dcts:
- return
- for i in set(dcts[0]).intersection(*dcts[1:]):
- yield (i,) + tuple(d[i] for d in dcts)
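- # Usage sketch (hypothetical dictionaries):
- #   list(common_entries({'a': 1, 'b': 2}, {'a': 3, 'c': 4})) -> [('a', 1, 3)]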
- # Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
- def calculate_overall_statistics(ensemble_statistics):
- predicted = ensemble_statistics['predicted']
- actual = ensemble_statistics['actual']
- # New dataframe to store the statistics
- stats_df = pd.DataFrame(
- columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
- ).set_index('stat')
- # Loop through and calculate the ECE, MCE, Brier Score, and NLL
- for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
- ece = met.ECE(predicted, ensemble_statistics[stat], actual)
- mce = met.MCE(predicted, ensemble_statistics[stat], actual)
- brier = met.brier_binary(ensemble_statistics[stat], actual)
- nll = met.nll_binary(ensemble_statistics[stat], actual)
- stats_df.loc[stat] = [ece, mce, brier, nll]
- return stats_df
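- # For reference, a minimal sketch of the standard binned expected calibration
- # error that met.ECE is assumed to compute; the actual utils.metrics
- # implementation may differ. Illustrative only and unused by the pipeline.
- def _ece_sketch(predicted, confidence, actual, n_bins=10):
-     confidence = np.asarray(confidence, dtype=float)
-     correct = (np.asarray(predicted) == np.asarray(actual)).astype(float)
-     bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
-     ece = 0.0
-     for low, high in zip(bin_edges[:-1], bin_edges[1:]):
-         in_bin = (confidence > low) & (confidence <= high)
-         if in_bin.any():
-             # Gap between mean accuracy and mean confidence, weighted by bin mass
-             ece += in_bin.mean() * abs(correct[in_bin].mean() - confidence[in_bin].mean())
-     return ece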
- # CONFIGURATION
- def load_config():
- # Use ADL_CONFIG_PATH if set, falling back to ./config.toml
- config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
- with open(config_path, 'rb') as f:
- config = toml.load(f)
- return config
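- # The config file is expected to provide at least the following keys
- # (a minimal sketch; the values shown are illustrative placeholders):
- #   [paths]
- #   model_output = 'out/'
- #   [ensemble]
- #   name = 'ensemble_1'
- #   run_models = true
- #   individual_id = 1
- #   [training]
- #   device = 'cuda'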
- def prune_dataset(dataset, pruned_ids):
- pruned_dataset = []
- for i, (data, target) in enumerate(dataset):
- if i not in pruned_ids:
- pruned_dataset.append((data, target))
- return pruned_dataset
- def main():
- config = load_config()
- ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
- V3_PATH = ENSEMBLE_PATH + '/v3'
- # Create the output directory if it does not exist
- os.makedirs(V3_PATH, exist_ok=True)
- # Load the models
- device = torch.device(config['training']['device'])
- models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
- # Load and concatenate the test and validation sets (Dataset.__add__ returns a ConcatDataset)
- dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
- f'{ENSEMBLE_PATH}/val_dataset.pt'
- )
- if config['ensemble']['run_models']:
- # Get the predictions of the ensemble
- ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
- # Save to file using pickle
- with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
- pk.dump(ensemble_predictions, f)
- else:
- # Load the predictions from file
- with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
- ensemble_predictions = pk.load(f)
- # Get the statistics and thresholds of the ensemble
- ensemble_statistics = calculate_statistics(ensemble_predictions)
- stdev_thresholds = conduct_threshold_analysis(
- ensemble_statistics, 'stdev', low_to_high=True
- )
- entropy_thresholds = conduct_threshold_analysis(
- ensemble_statistics, 'entropy', low_to_high=True
- )
- confidence_thresholds = conduct_threshold_analysis(
- ensemble_statistics, 'confidence', low_to_high=False
- )
- # Map confidence from [0, 1] back to the raw max softmax probability in [0.5, 1]
- raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
- ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
- # Plot confidence vs standard deviation
- plot_statistics_versus(
- 'raw_confidence',
- 'stdev',
- 'Confidence',
- 'Standard Deviation',
- 'Confidence vs Standard Deviation',
- ensemble_statistics,
- f'{V3_PATH}/confidence_vs_stdev.png',
- annotate=True,
- )
- # Plot images - 3 weird and 3 normal
- # Selected from confidence vs stdev plot
- plot_image_grid(
- [279, 202, 28, 107, 27, 121],
- dataset,
- 2,
- f'{V3_PATH}/image_grid.png',
- titles=[
- 'Weird: 279',
- 'Weird: 202',
- 'Weird: 28',
- 'Normal: 107',
- 'Normal: 27',
- 'Normal: 121',
- ],
- )
- # Split into 'weird' (raw confidence < 0.7 and stdev < 0.1) and 'normal' results
- weird_results = ensemble_statistics.loc[
- (
- (ensemble_statistics['raw_confidence'] < 0.7)
- & (ensemble_statistics['stdev'] < 0.1)
- )
- ]
- normal_results = ensemble_statistics.loc[
- ~(
- (ensemble_statistics['raw_confidence'] < 0.7)
- & (ensemble_statistics['stdev'] < 0.1)
- )
- ]
- # Plot the weird and normal images into separate directories
- os.makedirs(f'{V3_PATH}/images/weird', exist_ok=True)
- os.makedirs(f'{V3_PATH}/images/normal', exist_ok=True)
- for row in weird_results.itertuples():
- img_id = row.Index
- conf = row.raw_confidence
- stdev = row.stdev
- plot_single_image(
- img_id,
- dataset,
- f'{V3_PATH}/images/weird/{img_id}.png',
- title=f'ID: {img_id}, Confidence: {conf:.3f}, Stdev: {stdev:.3f}',
- )
- for row in normal_results.itertuples():
- img_id = row.Index
- conf = row.raw_confidence
- stdev = row.stdev
- plot_single_image(
- img_id,
- dataset,
- f'{V3_PATH}/images/normal/{img_id}.png',
- title=f'ID: {img_id}, Confidence: {conf:.3f}, Stdev: {stdev:.3f}',
- )
- # Calculate overall statistics
- overall_statistics = calculate_overall_statistics(ensemble_statistics)
- # Print overall statistics
- print(overall_statistics)
- # Print overall ensemble statistics
- print('Ensemble Statistics')
- print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
- print(
- f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
- )
- # Get the predictions, statistics, and thresholds of an individual model
- indv_id = config['ensemble']['individual_id']
- indv_predictions = select_individual_model(ensemble_predictions, indv_id)
- indv_statistics = calculate_statistics(indv_predictions)
- # Calculate entropy and confidence thresholds for individual model
- indv_entropy_thresholds = conduct_threshold_analysis(
- indv_statistics, 'entropy', low_to_high=True
- )
- indv_confidence_thresholds = conduct_threshold_analysis(
- indv_statistics, 'confidence', low_to_high=False
- )
- # Plot the threshold analysis for standard deviation
- plot_threshold_analysis(
- stdev_thresholds['accuracy'],
- 'Stdev Threshold Analysis for Accuracy',
- 'Stdev Threshold',
- 'Accuracy',
- f'{V3_PATH}/stdev_threshold_analysis.png',
- flip=True,
- )
- plot_threshold_analysis(
- stdev_thresholds['f1'],
- 'Stdev Threshold Analysis for F1 Score',
- 'Stdev Threshold',
- 'F1 Score',
- f'{V3_PATH}/stdev_threshold_analysis_f1.png',
- flip=True,
- )
- # Plot the threshold analysis for entropy
- plot_threshold_analysis(
- entropy_thresholds['accuracy'],
- 'Entropy Threshold Analysis for Accuracy',
- 'Entropy Threshold',
- 'Accuracy',
- f'{V3_PATH}/entropy_threshold_analysis.png',
- indv_entropy_thresholds['accuracy'],
- flip=True,
- )
- plot_threshold_analysis(
- entropy_thresholds['f1'],
- 'Entropy Threshold Analysis for F1 Score',
- 'Entropy Threshold',
- 'F1 Score',
- f'{V3_PATH}/entropy_threshold_analysis_f1.png',
- indv_entropy_thresholds['f1'],
- flip=True,
- )
- # Plot the threshold analysis for confidence
- plot_threshold_analysis(
- confidence_thresholds['accuracy'],
- 'Confidence Threshold Analysis for Accuracy',
- 'Confidence Threshold',
- 'Accuracy',
- f'{V3_PATH}/confidence_threshold_analysis.png',
- indv_confidence_thresholds['accuracy'],
- )
- plot_threshold_analysis(
- confidence_thresholds['f1'],
- 'Confidence Threshold Analysis for F1 Score',
- 'Confidence Threshold',
- 'F1 Score',
- f'{V3_PATH}/confidence_threshold_analysis_f1.png',
- indv_confidence_thresholds['f1'],
- )
- if __name__ == '__main__':
- main()