|
@@ -14,24 +14,67 @@ import matplotlib.ticker as ticker
|
|
import glob
|
|
import glob
|
|
import pickle as pk
|
|
import pickle as pk
|
|
import warnings
|
|
import warnings
|
|
|
|
+import random as rand
|
|
|
|
|
|
warnings.filterwarnings('error')
|
|
warnings.filterwarnings('error')
|
|
|
|
|
|
-# CONFIGURATION
|
|
|
|
-if os.getenv('ADL_CONFIG_PATH') is None:
|
|
|
|
- with open('config.toml', 'rb') as f:
|
|
|
|
- config = toml.load(f)
|
|
|
|
-else:
|
|
|
|
- with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
|
|
|
|
- config = toml.load(f)
|
|
|
|
|
|
|
|
-ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
|
|
|
|
|
|
+def plot_image_grid(image_ids, dataset, rows, path, titles=None):
|
|
|
|
+ fig, axs = plt.subplots(rows, len(image_ids) // rows)
|
|
|
|
+ for i, ax in enumerate(axs.flat):
|
|
|
|
+ image_id = image_ids[i]
|
|
|
|
+ image = dataset[image_id][0][0].squeeze().cpu().numpy()
|
|
|
|
+ # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
|
|
|
|
+ image = image[:, :, 45]
|
|
|
|
+
|
|
|
|
+ ax.imshow(image, cmap='gray')
|
|
|
|
+ ax.axis('off')
|
|
|
|
+ if titles is not None:
|
|
|
|
+ ax.set_title(titles[i])
|
|
|
|
+
|
|
|
|
+ plt.savefig(path)
|
|
|
|
+ plt.close()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def plot_single_image(image_id, dataset, path, title=None):
|
|
|
|
+ fig, ax = plt.subplots()
|
|
|
|
+ image = dataset[image_id][0][0].squeeze().cpu().numpy()
|
|
|
|
+ # We now have a 3d image of size (91, 109, 91), and we want to take a slice from the middle of the image
|
|
|
|
+ image = image[:, :, 45]
|
|
|
|
+
|
|
|
|
+ ax.imshow(image, cmap='gray')
|
|
|
|
+ ax.axis('off')
|
|
|
|
+ if title is not None:
|
|
|
|
+ ax.set_title(title)
|
|
|
|
+
|
|
|
|
+ plt.savefig(path)
|
|
|
|
+ plt.close()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Given a dataframe of the form {data_id: (stat_1, stat_2, ..., correct)}, plot the two statistics against each other and color by correctness
|
|
|
|
+def plot_statistics_versus(
|
|
|
|
+ stat_1, stat_2, xaxis_name, yaxis_name, title, dataframe, path, annotate=False
|
|
|
|
+):
|
|
|
|
+ # Get correct predictions and incorrect predictions dataframes
|
|
|
|
+ corr_df = dataframe[dataframe['correct']]
|
|
|
|
+ incorr_df = dataframe[~dataframe['correct']]
|
|
|
|
|
|
-V3_PATH = ENSEMBLE_PATH + '/v3'
|
|
|
|
|
|
+ # Plot the correct and incorrect predictions
|
|
|
|
+ fig, ax = plt.subplots()
|
|
|
|
+ ax.scatter(corr_df[stat_1], corr_df[stat_2], c='green', label='Correct')
|
|
|
|
+ ax.scatter(incorr_df[stat_1], incorr_df[stat_2], c='red', label='Incorrect')
|
|
|
|
+ ax.legend()
|
|
|
|
+ ax.set_xlabel(xaxis_name)
|
|
|
|
+ ax.set_ylabel(yaxis_name)
|
|
|
|
+ ax.set_title(title)
|
|
|
|
+
|
|
|
|
+ if annotate:
|
|
|
|
+ print('DEBUG -- REMOVE: Annotating')
|
|
|
|
+ # label correct points green
|
|
|
|
+ for row in dataframe[[stat_1, stat_2]].itertuples():
|
|
|
|
+ plt.text(row[1], row[2], row[0], fontsize=6, color='black')
|
|
|
|
|
|
-# Create the directory if it does not exist
|
|
|
|
-if not os.path.exists(V3_PATH):
|
|
|
|
- os.makedirs(V3_PATH)
|
|
|
|
|
|
+ plt.savefig(path)
|
|
|
|
|
|
|
|
|
|
# Models is a dictionary with the model ids as keys and the model data as values
|
|
# Models is a dictionary with the model ids as keys and the model data as values
|
|
@@ -99,13 +142,19 @@ def select_individual_model(predictions, model_id):
|
|
|
|
|
|
|
|
|
|
# Given a dictionary of predictions, select a subset of models and eliminate the rest
|
|
# Given a dictionary of predictions, select a subset of models and eliminate the rest
|
|
|
|
+# predictions dictory of the form {data_id: (target, {model_id: prediction})}
|
|
def select_subset_models(predictions, model_ids):
|
|
def select_subset_models(predictions, model_ids):
|
|
selected_model_predictions = {}
|
|
selected_model_predictions = {}
|
|
for key, value in predictions.items():
|
|
for key, value in predictions.items():
|
|
|
|
+ target = value[0]
|
|
|
|
+ model_predictions = value[1]
|
|
|
|
+
|
|
|
|
+ # Filter the model predictions, only keeping selected models
|
|
selected_model_predictions[key] = (
|
|
selected_model_predictions[key] = (
|
|
- value[0],
|
|
|
|
- {model_id: value[1][model_id] for model_id in model_ids},
|
|
|
|
|
|
+ target,
|
|
|
|
+ {model_id: model_predictions[str(model_id + 1)] for model_id in model_ids},
|
|
)
|
|
)
|
|
|
|
+
|
|
return selected_model_predictions
|
|
return selected_model_predictions
|
|
|
|
|
|
|
|
|
|
@@ -238,13 +287,16 @@ def common_entries(*dcts):
|
|
for i in set(dcts[0]).intersection(*dcts[1:]):
|
|
for i in set(dcts[0]).intersection(*dcts[1:]):
|
|
yield (i,) + tuple(d[i] for d in dcts)
|
|
yield (i,) + tuple(d[i] for d in dcts)
|
|
|
|
|
|
-#Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
|
|
|
|
|
|
+
|
|
|
|
+# Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
|
|
def calculate_overall_statistics(ensemble_statistics):
|
|
def calculate_overall_statistics(ensemble_statistics):
|
|
predicted = ensemble_statistics['predicted']
|
|
predicted = ensemble_statistics['predicted']
|
|
actual = ensemble_statistics['actual']
|
|
actual = ensemble_statistics['actual']
|
|
|
|
|
|
# New dataframe to store the statistics
|
|
# New dataframe to store the statistics
|
|
- stats_df = pd.DataFrame(columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']).set_index('stat')
|
|
|
|
|
|
+ stats_df = pd.DataFrame(
|
|
|
|
+ columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']
|
|
|
|
+ ).set_index('stat')
|
|
|
|
|
|
# Loop through and calculate the ECE, MCE, Brier Score, and NLL
|
|
# Loop through and calculate the ECE, MCE, Brier Score, and NLL
|
|
for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
|
|
for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
|
|
@@ -254,12 +306,41 @@ def calculate_overall_statistics(ensemble_statistics):
|
|
nll = met.nll_binary(ensemble_statistics[stat], actual)
|
|
nll = met.nll_binary(ensemble_statistics[stat], actual)
|
|
|
|
|
|
stats_df.loc[stat] = [ece, mce, brier, nll]
|
|
stats_df.loc[stat] = [ece, mce, brier, nll]
|
|
-
|
|
|
|
|
|
+
|
|
return stats_df
|
|
return stats_df
|
|
|
|
|
|
|
|
|
|
|
|
+# CONFIGURATION
|
|
|
|
+def load_config():
|
|
|
|
+ if os.getenv('ADL_CONFIG_PATH') is None:
|
|
|
|
+ with open('config.toml', 'rb') as f:
|
|
|
|
+ config = toml.load(f)
|
|
|
|
+ else:
|
|
|
|
+ with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
|
|
|
|
+ config = toml.load(f)
|
|
|
|
+
|
|
|
|
+ return config
|
|
|
|
+
|
|
|
|
+def prune_dataset(dataset, pruned_ids):
|
|
|
|
+ pruned_dataset = []
|
|
|
|
+ for i, (data, target) in enumerate(dataset):
|
|
|
|
+ if i not in pruned_ids:
|
|
|
|
+ pruned_dataset.append((data, target))
|
|
|
|
+
|
|
|
|
+ return pruned_dataset
|
|
|
|
+
|
|
|
|
|
|
def main():
|
|
def main():
|
|
|
|
+ config = load_config()
|
|
|
|
+
|
|
|
|
+ ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
|
|
|
|
+
|
|
|
|
+ V3_PATH = ENSEMBLE_PATH + '/v3'
|
|
|
|
+
|
|
|
|
+ # Create the directory if it does not exist
|
|
|
|
+ if not os.path.exists(V3_PATH):
|
|
|
|
+ os.makedirs(V3_PATH)
|
|
|
|
+
|
|
# Load the models
|
|
# Load the models
|
|
device = torch.device(config['training']['device'])
|
|
device = torch.device(config['training']['device'])
|
|
models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
|
|
models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
|
|
@@ -269,6 +350,8 @@ def main():
|
|
f'{ENSEMBLE_PATH}/val_dataset.pt'
|
|
f'{ENSEMBLE_PATH}/val_dataset.pt'
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ dataset =
|
|
|
|
+
|
|
if config['ensemble']['run_models']:
|
|
if config['ensemble']['run_models']:
|
|
# Get thre predicitons of the ensemble
|
|
# Get thre predicitons of the ensemble
|
|
ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
|
|
ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
|
|
@@ -283,6 +366,7 @@ def main():
|
|
|
|
|
|
# Get the statistics and thresholds of the ensemble
|
|
# Get the statistics and thresholds of the ensemble
|
|
ensemble_statistics = calculate_statistics(ensemble_predictions)
|
|
ensemble_statistics = calculate_statistics(ensemble_predictions)
|
|
|
|
+
|
|
stdev_thresholds = conduct_threshold_analysis(
|
|
stdev_thresholds = conduct_threshold_analysis(
|
|
ensemble_statistics, 'stdev', low_to_high=True
|
|
ensemble_statistics, 'stdev', low_to_high=True
|
|
)
|
|
)
|
|
@@ -296,6 +380,78 @@ def main():
|
|
raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
|
|
raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
|
|
ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
|
|
ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
|
|
|
|
|
|
|
|
+ # Plot confidence vs standard deviation
|
|
|
|
+ plot_statistics_versus(
|
|
|
|
+ 'raw_confidence',
|
|
|
|
+ 'stdev',
|
|
|
|
+ 'Confidence',
|
|
|
|
+ 'Standard Deviation',
|
|
|
|
+ 'Confidence vs Standard Deviation',
|
|
|
|
+ ensemble_statistics,
|
|
|
|
+ f'{V3_PATH}/confidence_vs_stdev.png',
|
|
|
|
+ annotate=True,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Plot images - 3 weird and 3 normal
|
|
|
|
+ # Selected from confidence vs stdev plot
|
|
|
|
+ plot_image_grid(
|
|
|
|
+ [279, 202, 28, 107, 27, 121],
|
|
|
|
+ dataset,
|
|
|
|
+ 2,
|
|
|
|
+ f'{V3_PATH}/image_grid.png',
|
|
|
|
+ titles=[
|
|
|
|
+ 'Weird: 279',
|
|
|
|
+ 'Weird: 202',
|
|
|
|
+ 'Weird: 28',
|
|
|
|
+ 'Normal: 107',
|
|
|
|
+ 'Normal: 27',
|
|
|
|
+ 'Normal: 121',
|
|
|
|
+ ],
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # Filter dataset for where confidence < .7 and stdev < .1
|
|
|
|
+ weird_results = ensemble_statistics.loc[
|
|
|
|
+ (
|
|
|
|
+ (ensemble_statistics['raw_confidence'] < 0.7)
|
|
|
|
+ & (ensemble_statistics['stdev'] < 0.1)
|
|
|
|
+ )
|
|
|
|
+ ]
|
|
|
|
+ normal_results = ensemble_statistics.loc[
|
|
|
|
+ ~(
|
|
|
|
+ (ensemble_statistics['raw_confidence'] < 0.7)
|
|
|
|
+ & (ensemble_statistics['stdev'] < 0.1)
|
|
|
|
+ )
|
|
|
|
+ ]
|
|
|
|
+ # Get the data ids in a list
|
|
|
|
+ # Plot the images
|
|
|
|
+ if not os.path.exists(f'{V3_PATH}/images'):
|
|
|
|
+ os.makedirs(f'{V3_PATH}/images/weird')
|
|
|
|
+ os.makedirs(f'{V3_PATH}/images/normal')
|
|
|
|
+
|
|
|
|
+ for i in weird_results.itertuples():
|
|
|
|
+ id = i.Index
|
|
|
|
+ conf = i.raw_confidence
|
|
|
|
+ stdev = i.stdev
|
|
|
|
+
|
|
|
|
+ plot_single_image(
|
|
|
|
+ id,
|
|
|
|
+ dataset,
|
|
|
|
+ f'{V3_PATH}/images/weird/{id}.png',
|
|
|
|
+ title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ for i in normal_results.itertuples():
|
|
|
|
+ id = i.Index
|
|
|
|
+ conf = i.raw_confidence
|
|
|
|
+ stdev = i.stdev
|
|
|
|
+
|
|
|
|
+ plot_single_image(
|
|
|
|
+ id,
|
|
|
|
+ dataset,
|
|
|
|
+ f'{V3_PATH}/images/normal/{id}.png',
|
|
|
|
+ title=f'ID: {id}, Confidence: {conf}, Stdev: {stdev}',
|
|
|
|
+ )
|
|
|
|
+
|
|
# Calculate overall statistics
|
|
# Calculate overall statistics
|
|
overall_statistics = calculate_overall_statistics(ensemble_statistics)
|
|
overall_statistics = calculate_overall_statistics(ensemble_statistics)
|
|
|
|
|