|
@@ -12,6 +12,10 @@ import utils.metrics as met
|
|
|
import itertools as it
|
|
|
import matplotlib.ticker as ticker
|
|
|
import glob
|
|
|
+import pickle as pk
|
|
|
+import warnings
|
|
|
+
|
|
|
+warnings.filterwarnings('error')
|
|
|
|
|
|
# CONFIGURATION
|
|
|
if os.getenv('ADL_CONFIG_PATH') is None:
|
|
@@ -23,7 +27,11 @@ else:
|
|
|
|
|
|
ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
|
|
|
|
|
|
-V2_PATH = ENSEMBLE_PATH + '/v2'
|
|
|
+V3_PATH = ENSEMBLE_PATH + '/v3'
|
|
|
+
|
|
|
+# Create the directory if it does not exist
|
|
|
+if not os.path.exists(V3_PATH):
|
|
|
+ os.makedirs(V3_PATH)
|
|
|
|
|
|
|
|
|
# Models is a dictionary with the model ids as keys and the model data as values
|
|
@@ -31,7 +39,7 @@ def get_model_predictions(models, data):
|
|
|
predictions = {}
|
|
|
for model_id, model in models.items():
|
|
|
model.eval()
|
|
|
- with torch.no_grad:
|
|
|
+ with torch.no_grad():
|
|
|
# Get the predictions
|
|
|
output = model(data)
|
|
|
predictions[model_id] = output.detach().cpu().numpy()
|
|
@@ -41,7 +49,7 @@ def get_model_predictions(models, data):
|
|
|
|
|
|
def load_models_v2(folder, device):
|
|
|
glob_path = os.path.join(folder, '*.pt')
|
|
|
- model_files = glob(glob_path)
|
|
|
+ model_files = glob.glob(glob_path)
|
|
|
model_dict = {}
|
|
|
|
|
|
for model_file in model_files:
|
|
@@ -49,6 +57,9 @@ def load_models_v2(folder, device):
|
|
|
model_id = os.path.basename(model_file).split('_')[0]
|
|
|
model_dict[model_id] = model
|
|
|
|
|
|
+ if len(model_dict) == 0:
|
|
|
+ raise FileNotFoundError('No models found in the specified directory: ' + folder)
|
|
|
+
|
|
|
return model_dict
|
|
|
|
|
|
|
|
@@ -80,7 +91,10 @@ def ensemble_dataset_predictions(models, dataset, device):
|
|
|
def select_individual_model(predictions, model_id):
|
|
|
selected_model_predictions = {}
|
|
|
for key, value in predictions.items():
|
|
|
- selected_model_predictions[key] = (value[0], {model_id: value[1][model_id]})
|
|
|
+ selected_model_predictions[key] = (
|
|
|
+ value[0],
|
|
|
+ {model_id: value[1][str(model_id)]},
|
|
|
+ )
|
|
|
return selected_model_predictions
|
|
|
|
|
|
|
|
@@ -95,5 +109,246 @@ def select_subset_models(predictions, model_ids):
|
|
|
return selected_model_predictions
|
|
|
|
|
|
|
|
|
-# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, accuracy, f1) for each result
|
|
|
+# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
|
|
|
+# Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
|
|
|
def calculate_statistics(predictions):
|
|
|
+ # Create DataFrame with columns for each statistic
|
|
|
+ stats_df = pd.DataFrame(
|
|
|
+ columns=[
|
|
|
+ 'mean',
|
|
|
+ 'stdev',
|
|
|
+ 'entropy',
|
|
|
+ 'confidence',
|
|
|
+ 'correct',
|
|
|
+ 'predicted',
|
|
|
+ 'actual',
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ # First, loop through each prediction
|
|
|
+ for key, value in predictions.items():
|
|
|
+ target = value[0]
|
|
|
+ model_predictions = list(value[1].values())
|
|
|
+
|
|
|
+ # Calculate the mean and stdev of predictions
|
|
|
+ mean = np.squeeze(np.mean(model_predictions, axis=0))
|
|
|
+ stdev = np.squeeze(np.std(model_predictions, axis=0))[1]
|
|
|
+
|
|
|
+ # Calculate the entropy of the predictions
|
|
|
+ entropy = met.entropy(mean)
|
|
|
+
|
|
|
+ # Calculate confidence
|
|
|
+ confidence = (np.max(mean) - 0.5) * 2
|
|
|
+
|
|
|
+ # Calculate predicted and actual
|
|
|
+ predicted = np.argmax(mean)
|
|
|
+ actual = np.argmax(target)
|
|
|
+
|
|
|
+ # Determine if the prediction is correct
|
|
|
+ correct = predicted == actual
|
|
|
+
|
|
|
+ # Add the statistics to the dataframe
|
|
|
+ stats_df.loc[key] = [
|
|
|
+ mean,
|
|
|
+ stdev,
|
|
|
+ entropy,
|
|
|
+ confidence,
|
|
|
+ correct,
|
|
|
+ predicted,
|
|
|
+ actual,
|
|
|
+ ]
|
|
|
+
|
|
|
+ return stats_df
|
|
|
+
|
|
|
+
|
|
|
+# Takes in a dataframe of the form {data_id: statistic, ...} and calculates the thresholds for the statistic
|
|
|
+# Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
|
|
|
+def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
|
|
|
+ # Gives a dataframe
|
|
|
+ percentile_df = statistics[statistic_name].quantile(
|
|
|
+ q=np.linspace(0.05, 0.95, num=18)
|
|
|
+ )
|
|
|
+
|
|
|
+ # Dictionary of form {threshold: {metric: value}}
|
|
|
+ thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
|
|
|
+ for percentile, value in percentile_df.items():
|
|
|
+ # Filter the statistics
|
|
|
+ if low_to_high:
|
|
|
+ filtered_statistics = statistics[statistics[statistic_name] < value]
|
|
|
+ else:
|
|
|
+ filtered_statistics = statistics[statistics[statistic_name] >= value]
|
|
|
+
|
|
|
+ # Calculate accuracy and f1 score
|
|
|
+ accuracy = filtered_statistics['correct'].mean()
|
|
|
+
|
|
|
+ # Calculate F1 score
|
|
|
+ predicted = filtered_statistics['predicted'].values
|
|
|
+ actual = filtered_statistics['actual'].values
|
|
|
+
|
|
|
+ f1 = metrics.f1_score(actual, predicted)
|
|
|
+
|
|
|
+ # Add the metrics to the dataframe
|
|
|
+ thresholds_pd.loc[percentile] = [accuracy, f1]
|
|
|
+
|
|
|
+ return thresholds_pd
|
|
|
+
|
|
|
+
|
|
|
+# Takes a dictionary of the form {threshold: {metric: value}} for a given statistic and plots the metric against the threshold.
|
|
|
+# Can plot an additional line if given (used for individual results)
|
|
|
+def plot_threshold_analysis(
|
|
|
+ thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
|
|
|
+):
|
|
|
+ # Initialize the plot
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+
|
|
|
+ # Get the thresholds and metrics
|
|
|
+ thresholds = list(thresholds_metric.index)
|
|
|
+ metric = list(thresholds_metric.values)
|
|
|
+
|
|
|
+ # Plot the metric against the threshold
|
|
|
+ plt.plot(thresholds, metric, 'bo-', label='Ensemble')
|
|
|
+
|
|
|
+ if additional_set is not None:
|
|
|
+ # Get the thresholds and metrics
|
|
|
+ thresholds = list(additional_set.index)
|
|
|
+ metric = list(additional_set.values)
|
|
|
+
|
|
|
+ # Plot the metric against the threshold
|
|
|
+ plt.plot(thresholds, metric, 'rx-', label='Individual')
|
|
|
+
|
|
|
+ if flip:
|
|
|
+ ax.invert_xaxis()
|
|
|
+
|
|
|
+ # Add labels
|
|
|
+ plt.title(title)
|
|
|
+ plt.xlabel(x_label)
|
|
|
+ plt.ylabel(y_label)
|
|
|
+ plt.legend()
|
|
|
+ ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
|
|
|
+
|
|
|
+ plt.savefig(path)
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+
|
|
|
+# Code from https://stackoverflow.com/questions/16458340
|
|
|
+# Returns the intersections of multiple dictionaries
|
|
|
+def common_entries(*dcts):
|
|
|
+ if not dcts:
|
|
|
+ return
|
|
|
+ for i in set(dcts[0]).intersection(*dcts[1:]):
|
|
|
+ yield (i,) + tuple(d[i] for d in dcts)
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ # Load the models
|
|
|
+ device = torch.device(config['training']['device'])
|
|
|
+ models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
|
|
|
+
|
|
|
+ # Load Dataset
|
|
|
+ dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
|
|
|
+ f'{ENSEMBLE_PATH}/val_dataset.pt'
|
|
|
+ )
|
|
|
+
|
|
|
+ if config['ensemble']['run_models']:
|
|
|
+ # Get thre predicitons of the ensemble
|
|
|
+ ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
|
|
|
+
|
|
|
+ # Save to file using pickle
|
|
|
+ with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
|
|
|
+ pk.dump(ensemble_predictions, f)
|
|
|
+ else:
|
|
|
+ # Load the predictions from file
|
|
|
+ with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
|
|
|
+ ensemble_predictions = pk.load(f)
|
|
|
+
|
|
|
+ # Get the statistics and thresholds of the ensemble
|
|
|
+ ensemble_statistics = calculate_statistics(ensemble_predictions)
|
|
|
+ stdev_thresholds = conduct_threshold_analysis(
|
|
|
+ ensemble_statistics, 'stdev', low_to_high=True
|
|
|
+ )
|
|
|
+ entropy_thresholds = conduct_threshold_analysis(
|
|
|
+ ensemble_statistics, 'entropy', low_to_high=True
|
|
|
+ )
|
|
|
+ confidence_thresholds = conduct_threshold_analysis(
|
|
|
+ ensemble_statistics, 'confidence', low_to_high=False
|
|
|
+ )
|
|
|
+
|
|
|
+ # Print overall ensemble statistics
|
|
|
+ print('Ensemble Statistics')
|
|
|
+ print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
|
|
|
+ print(
|
|
|
+ f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
|
|
|
+ )
|
|
|
+
|
|
|
+ # Get the predictions, statistics and thresholds an individual model
|
|
|
+ indv_id = config['ensemble']['individual_id']
|
|
|
+ indv_predictions = select_individual_model(ensemble_predictions, indv_id)
|
|
|
+ indv_statistics = calculate_statistics(indv_predictions)
|
|
|
+
|
|
|
+ # Calculate entropy and confidence thresholds for individual model
|
|
|
+ indv_entropy_thresholds = conduct_threshold_analysis(
|
|
|
+ indv_statistics, 'entropy', low_to_high=True
|
|
|
+ )
|
|
|
+ indv_confidence_thresholds = conduct_threshold_analysis(
|
|
|
+ indv_statistics, 'confidence', low_to_high=False
|
|
|
+ )
|
|
|
+
|
|
|
+ # Plot the threshold analysis for standard deviation
|
|
|
+ plot_threshold_analysis(
|
|
|
+ stdev_thresholds['accuracy'],
|
|
|
+ 'Stdev Threshold Analysis for Accuracy',
|
|
|
+ 'Stdev Threshold',
|
|
|
+ 'Accuracy',
|
|
|
+ f'{V3_PATH}/stdev_threshold_analysis.png',
|
|
|
+ flip=True,
|
|
|
+ )
|
|
|
+ plot_threshold_analysis(
|
|
|
+ stdev_thresholds['f1'],
|
|
|
+ 'Stdev Threshold Analysis for F1 Score',
|
|
|
+ 'Stdev Threshold',
|
|
|
+ 'F1 Score',
|
|
|
+ f'{V3_PATH}/stdev_threshold_analysis_f1.png',
|
|
|
+ flip=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Plot the threshold analysis for entropy
|
|
|
+ plot_threshold_analysis(
|
|
|
+ entropy_thresholds['accuracy'],
|
|
|
+ 'Entropy Threshold Analysis for Accuracy',
|
|
|
+ 'Entropy Threshold',
|
|
|
+ 'Accuracy',
|
|
|
+ f'{V3_PATH}/entropy_threshold_analysis.png',
|
|
|
+ indv_entropy_thresholds['accuracy'],
|
|
|
+ flip=True,
|
|
|
+ )
|
|
|
+ plot_threshold_analysis(
|
|
|
+ entropy_thresholds['f1'],
|
|
|
+ 'Entropy Threshold Analysis for F1 Score',
|
|
|
+ 'Entropy Threshold',
|
|
|
+ 'F1 Score',
|
|
|
+ f'{V3_PATH}/entropy_threshold_analysis_f1.png',
|
|
|
+ indv_entropy_thresholds['f1'],
|
|
|
+ flip=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Plot the threshold analysis for confidence
|
|
|
+ plot_threshold_analysis(
|
|
|
+ confidence_thresholds['accuracy'],
|
|
|
+ 'Confidence Threshold Analysis for Accuracy',
|
|
|
+ 'Confidence Threshold',
|
|
|
+ 'Accuracy',
|
|
|
+ f'{V3_PATH}/confidence_threshold_analysis.png',
|
|
|
+ indv_confidence_thresholds['accuracy'],
|
|
|
+ )
|
|
|
+ plot_threshold_analysis(
|
|
|
+ confidence_thresholds['f1'],
|
|
|
+ 'Confidence Threshold Analysis for F1 Score',
|
|
|
+ 'Confidence Threshold',
|
|
|
+ 'F1 Score',
|
|
|
+ f'{V3_PATH}/confidence_threshold_analysis_f1.png',
|
|
|
+ indv_confidence_thresholds['f1'],
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ main()
|