1
0

2 Commits 7775cf4b28 ... 257a539245

Autor SHA1 Nachricht Datum
  Nicholas Schense 257a539245 Finished threshold rewrite! vor 2 Monaten
  Nicholas Schense 8a70abc2e8 Begin full threshold refactor vor 2 Monaten
4 geänderte Dateien mit 376 neuen und 0 gelöschten Zeilen
  1. 16 0
      .vscode/launch.json
  2. 2 0
      config.toml
  3. 354 0
      threshold_refac.py
  4. 4 0
      utils/metrics.py

+ 16 - 0
.vscode/launch.json

@@ -0,0 +1,16 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        }
+    ]
+}

+ 2 - 0
config.toml

@@ -31,3 +31,5 @@ silent = false
 [ensemble]
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
+individual_id = 1     # The id of the individual model to be used for the ensemble
+run_models = true     # If true, the ensemble will run the models to generate the predictions, otherwise will load from file

+ 354 - 0
threshold_refac.py

@@ -0,0 +1,354 @@
+import pandas as pd
+import numpy as np
+import os
+import tomli as toml
+from utils.data.datasets import prepare_datasets
+import utils.ensemble as ens
+import torch
+import matplotlib.pyplot as plt
+import sklearn.metrics as metrics
+from tqdm import tqdm
+import utils.metrics as met
+import itertools as it
+import matplotlib.ticker as ticker
+import glob
+import pickle as pk
+import warnings
+
+warnings.filterwarnings('error')
+
+# CONFIGURATION
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
+        config = toml.load(f)
+else:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+        config = toml.load(f)
+
+ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+
+V3_PATH = ENSEMBLE_PATH + '/v3'
+
+# Create the directory if it does not exist
+if not os.path.exists(V3_PATH):
+    os.makedirs(V3_PATH)
+
+
+# Models is a dictionary with the model ids as keys and the model data as values
+def get_model_predictions(models, data):
+    predictions = {}
+    for model_id, model in models.items():
+        model.eval()
+        with torch.no_grad():
+            # Get the predictions
+            output = model(data)
+            predictions[model_id] = output.detach().cpu().numpy()
+
+    return predictions
+
+
+def load_models_v2(folder, device):
+    glob_path = os.path.join(folder, '*.pt')
+    model_files = glob.glob(glob_path)
+    model_dict = {}
+
+    for model_file in model_files:
+        model = torch.load(model_file, map_location=device)
+        model_id = os.path.basename(model_file).split('_')[0]
+        model_dict[model_id] = model
+
+    if len(model_dict) == 0:
+        raise FileNotFoundError('No models found in the specified directory: ' + folder)
+
+    return model_dict
+
+
+# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
+def preprocess_data(data, device):
+    mri, xls = data
+    mri = mri.unsqueeze(0).to(device)
+    xls = xls.unsqueeze(0).to(device)
+    return (mri, xls)
+
+
+def ensemble_dataset_predictions(models, dataset, device):
+    # For each datapoint, get the predictions of each model
+    predictions = {}
+    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
+        # Preprocess data
+        data = preprocess_data(data, device)
+        # Predictions is a dicionary of tuples, with the target as the first and the model predicions dictionary as the second
+        # The key is the id of the image
+        predictions[i] = (
+            target.detach().cpu().numpy(),
+            get_model_predictions(models, data),
+        )
+
+    return predictions
+
+
+# Given a dictionary of predictions, select one model and eliminate the rest
+def select_individual_model(predictions, model_id):
+    selected_model_predictions = {}
+    for key, value in predictions.items():
+        selected_model_predictions[key] = (
+            value[0],
+            {model_id: value[1][str(model_id)]},
+        )
+    return selected_model_predictions
+
+
+# Given a dictionary of predictions, select a subset of models and eliminate the rest
+def select_subset_models(predictions, model_ids):
+    selected_model_predictions = {}
+    for key, value in predictions.items():
+        selected_model_predictions[key] = (
+            value[0],
+            {model_id: value[1][model_id] for model_id in model_ids},
+        )
+    return selected_model_predictions
+
+
+# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, correctness) for each result
+# Returns a dataframe of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
+def calculate_statistics(predictions):
+    # Create DataFrame with columns for each statistic
+    stats_df = pd.DataFrame(
+        columns=[
+            'mean',
+            'stdev',
+            'entropy',
+            'confidence',
+            'correct',
+            'predicted',
+            'actual',
+        ]
+    )
+
+    # First, loop through each prediction
+    for key, value in predictions.items():
+        target = value[0]
+        model_predictions = list(value[1].values())
+
+        # Calculate the mean and stdev of predictions
+        mean = np.squeeze(np.mean(model_predictions, axis=0))
+        stdev = np.squeeze(np.std(model_predictions, axis=0))[1]
+
+        # Calculate the entropy of the predictions
+        entropy = met.entropy(mean)
+
+        # Calculate confidence
+        confidence = (np.max(mean) - 0.5) * 2
+
+        # Calculate predicted and actual
+        predicted = np.argmax(mean)
+        actual = np.argmax(target)
+
+        # Determine if the prediction is correct
+        correct = predicted == actual
+
+        # Add the statistics to the dataframe
+        stats_df.loc[key] = [
+            mean,
+            stdev,
+            entropy,
+            confidence,
+            correct,
+            predicted,
+            actual,
+        ]
+
+    return stats_df
+
+
+# Takes in a dataframe of the form {data_id: statistic, ...} and calculates the thresholds for the statistic
+# Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
+def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
+    # Gives a dataframe
+    percentile_df = statistics[statistic_name].quantile(
+        q=np.linspace(0.05, 0.95, num=18)
+    )
+
+    # Dictionary of form {threshold: {metric: value}}
+    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])
+    for percentile, value in percentile_df.items():
+        # Filter the statistics
+        if low_to_high:
+            filtered_statistics = statistics[statistics[statistic_name] < value]
+        else:
+            filtered_statistics = statistics[statistics[statistic_name] >= value]
+
+        # Calculate accuracy and f1 score
+        accuracy = filtered_statistics['correct'].mean()
+
+        # Calculate F1 score
+        predicted = filtered_statistics['predicted'].values
+        actual = filtered_statistics['actual'].values
+
+        f1 = metrics.f1_score(actual, predicted)
+
+        # Add the metrics to the dataframe
+        thresholds_pd.loc[percentile] = [accuracy, f1]
+
+    return thresholds_pd
+
+
+# Takes a dictionary of the form {threshold: {metric: value}} for a given statistic and plots the metric against the threshold.
+# Can plot an additional line if given (used for individual results)
+def plot_threshold_analysis(
+    thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
+):
+    # Initialize the plot
+    fig, ax = plt.subplots()
+
+    # Get the thresholds and metrics
+    thresholds = list(thresholds_metric.index)
+    metric = list(thresholds_metric.values)
+
+    # Plot the metric against the threshold
+    plt.plot(thresholds, metric, 'bo-', label='Ensemble')
+
+    if additional_set is not None:
+        # Get the thresholds and metrics
+        thresholds = list(additional_set.index)
+        metric = list(additional_set.values)
+
+        # Plot the metric against the threshold
+        plt.plot(thresholds, metric, 'rx-', label='Individual')
+
+    if flip:
+        ax.invert_xaxis()
+
+    # Add labels
+    plt.title(title)
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.legend()
+    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
+
+    plt.savefig(path)
+    plt.close()
+
+
+# Code from https://stackoverflow.com/questions/16458340
+# Returns the intersections of multiple dictionaries
+def common_entries(*dcts):
+    if not dcts:
+        return
+    for i in set(dcts[0]).intersection(*dcts[1:]):
+        yield (i,) + tuple(d[i] for d in dcts)
+
+
+def main():
+    # Load the models
+    device = torch.device(config['training']['device'])
+    models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+
+    # Load Dataset
+    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
+        f'{ENSEMBLE_PATH}/val_dataset.pt'
+    )
+
+    if config['ensemble']['run_models']:
+        # Get thre predicitons of the ensemble
+        ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)
+
+        # Save to file using pickle
+        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
+            pk.dump(ensemble_predictions, f)
+    else:
+        # Load the predictions from file
+        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
+            ensemble_predictions = pk.load(f)
+
+    # Get the statistics and thresholds of the ensemble
+    ensemble_statistics = calculate_statistics(ensemble_predictions)
+    stdev_thresholds = conduct_threshold_analysis(
+        ensemble_statistics, 'stdev', low_to_high=True
+    )
+    entropy_thresholds = conduct_threshold_analysis(
+        ensemble_statistics, 'entropy', low_to_high=True
+    )
+    confidence_thresholds = conduct_threshold_analysis(
+        ensemble_statistics, 'confidence', low_to_high=False
+    )
+
+    # Print overall ensemble statistics
+    print('Ensemble Statistics')
+    print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
+    print(
+        f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
+    )
+
+    # Get the predictions, statistics and thresholds an individual model
+    indv_id = config['ensemble']['individual_id']
+    indv_predictions = select_individual_model(ensemble_predictions, indv_id)
+    indv_statistics = calculate_statistics(indv_predictions)
+
+    # Calculate entropy and confidence thresholds for individual model
+    indv_entropy_thresholds = conduct_threshold_analysis(
+        indv_statistics, 'entropy', low_to_high=True
+    )
+    indv_confidence_thresholds = conduct_threshold_analysis(
+        indv_statistics, 'confidence', low_to_high=False
+    )
+
+    # Plot the threshold analysis for standard deviation
+    plot_threshold_analysis(
+        stdev_thresholds['accuracy'],
+        'Stdev Threshold Analysis for Accuracy',
+        'Stdev Threshold',
+        'Accuracy',
+        f'{V3_PATH}/stdev_threshold_analysis.png',
+        flip=True,
+    )
+    plot_threshold_analysis(
+        stdev_thresholds['f1'],
+        'Stdev Threshold Analysis for F1 Score',
+        'Stdev Threshold',
+        'F1 Score',
+        f'{V3_PATH}/stdev_threshold_analysis_f1.png',
+        flip=True,
+    )
+
+    # Plot the threshold analysis for entropy
+    plot_threshold_analysis(
+        entropy_thresholds['accuracy'],
+        'Entropy Threshold Analysis for Accuracy',
+        'Entropy Threshold',
+        'Accuracy',
+        f'{V3_PATH}/entropy_threshold_analysis.png',
+        indv_entropy_thresholds['accuracy'],
+        flip=True,
+    )
+    plot_threshold_analysis(
+        entropy_thresholds['f1'],
+        'Entropy Threshold Analysis for F1 Score',
+        'Entropy Threshold',
+        'F1 Score',
+        f'{V3_PATH}/entropy_threshold_analysis_f1.png',
+        indv_entropy_thresholds['f1'],
+        flip=True,
+    )
+
+    # Plot the threshold analysis for confidence
+    plot_threshold_analysis(
+        confidence_thresholds['accuracy'],
+        'Confidence Threshold Analysis for Accuracy',
+        'Confidence Threshold',
+        'Accuracy',
+        f'{V3_PATH}/confidence_threshold_analysis.png',
+        indv_confidence_thresholds['accuracy'],
+    )
+    plot_threshold_analysis(
+        confidence_thresholds['f1'],
+        'Confidence Threshold Analysis for F1 Score',
+        'Confidence Threshold',
+        'F1 Score',
+        f'{V3_PATH}/confidence_threshold_analysis_f1.png',
+        indv_confidence_thresholds['f1'],
+    )
+
+
+if __name__ == '__main__':
+    main()

+ 4 - 0
utils/metrics.py

@@ -71,3 +71,7 @@ def F1(predicted_labels, true_labels):
 def AUC(confidences, true_labels):
     fpr, tpr, _ = mt.roc_curve(true_labels, confidences)
     return mt.auc(fpr, tpr)
+
+
+def entropy(confidences):
+    return -np.sum(confidences * np.log(confidences))