import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.ensemble as ens
from utils.data.datasets import prepare_datasets

RUN = True

# CONFIGURATION
# Read the TOML config from $ADL_CONFIG_PATH if it is set, otherwise from ./config.toml.
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
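# The only keys read below are config['paths']['model_output'], config['ensemble']['name']
# and config['training']['device'], so the config file is assumed to look roughly like the
# sketch here; the values are illustrative placeholders, not the project's real settings.
#
#   [paths]
#   model_output = "out/"
#
#   [ensemble]
#   name = "ensemble_v1"
#
#   [training]
#   device = "cuda"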
# Sweep uncertainty-quantile thresholds over the ensemble predictions and return
# (per-threshold ensemble metrics as a DataFrame, metrics for a single ensemble member).
def threshold(config):
    # First, get the model data
    test_set = torch.load(
        config['paths']['model_output'] + config['ensemble']['name'] + '/test_dataset.pt'
    )
    vs = torch.load(
        config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
    )
    test_set = test_set + vs

    models, _ = ens.load_models(
        config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
        config['training']['device'],
    )
    indv_model = models[0]

    predictions = []
    indv_predictions = []

    # Evaluate ensemble prediction and uncertainty on the combined test set
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)

        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()

        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                'Prediction': prediction,
                'Actual': target.item(),
                'Stdev': stdev.item(),
                'Correct': correct,
            }
        )

        # Single-member baseline: positive-class probability from the first model
        i_mean = indv_model(mdata)[:, 1].item()
        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
            i_mean >= 0.5 and int(target.item()) == 1
        )

        indv_predictions.append(
            {
                'Prediction': i_mean,
                'Actual': target.item(),
                'Stdev': 0,
                'Correct': i_correct,
            }
        )

    # Sort the predictions by uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')

    # Calculate the metrics for the individual model
    indv_predictions = pd.DataFrame(indv_predictions)
    indv_correct = indv_predictions['Correct'].sum()
    indv_accuracy = indv_correct / len(indv_predictions)
    indv_false_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
        ]
    )
    indv_false_neg = len(
        indv_predictions[
            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    indv_f1 = 2 * indv_correct / (2 * indv_correct + indv_false_pos + indv_false_neg)
    indv_auc = metrics.roc_auc_score(
        indv_predictions['Actual'], indv_predictions['Prediction']
    )
    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}

    # Uncertainty thresholds at the 10%..90% quantiles of the ensemble stdev
    quantiles = np.arange(0.1, 1, 0.1)
    thresholds = [predictions['Stdev'].quantile(q) for q in quantiles]

    # Calculate the metrics of the ensemble at each coverage threshold
    accuracies = []
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions['Stdev'] <= thresh]
        correct = filtered['Correct'].sum()
        total = len(filtered)
        accuracy = correct / total
        false_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
        )
        false_negatives = len(
            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
        )
        f1 = 2 * correct / (2 * correct + false_positives + false_negatives)
        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
        accuracies.append(
            {
                'Threshold': thresh,
                'Accuracy': accuracy,
                'Quantile': quantile,
                'F1': f1,
                'AUC': auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )
    indv_predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
    )

    return pd.DataFrame(accuracies), indv_metrics


if RUN:
    result, indv = threshold(config)
    result.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
    )
    indv = pd.DataFrame([indv])
    indv.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
    )

# Reload the cached results so the plots below also work when RUN is False
result = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
)
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)
indv = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
)

print(indv)

# Accuracy vs coverage (coverage shrinks as the uncertainty threshold tightens)
plt.figure()
plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
plt.plot(
    result['Quantile'],
    [indv['Accuracy']] * len(result['Quantile']),
    label='Individual Accuracy',
    linestyle='--',
)
plt.legend()
plt.title('Accuracy vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('Accuracy')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

# F1 vs coverage
plt.figure()
plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
plt.plot(
    result['Quantile'],
    [indv['F1']] * len(result['Quantile']),
    label='Individual F1',
    linestyle='--',
)
plt.legend()
plt.title('F1 vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('F1')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

# AUC vs coverage
plt.figure()
plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
plt.plot(
    result['Quantile'],
    [indv['AUC']] * len(result['Quantile']),
    label='Individual AUC',
    linestyle='--',
)
plt.legend()
plt.title('AUC vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('AUC')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Histogram of the incorrect predictions vs the uncertainty
plt.figure()
plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
plt.xlabel('Uncertainty')
plt.ylabel('Number of incorrect predictions')
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)
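# ---------------------------------------------------------------------------
# Illustrative sketch (not called above): one plausible shape for
# ens.ensemble_predict, assuming each member model maps an (mri, xls) tuple to
# softmax class probabilities of shape (1, 2). utils/ensemble.py is not shown
# in this file, so treat this as an assumption, not the actual library code.
# ---------------------------------------------------------------------------
def _sketch_ensemble_predict(models, mdata):
    with torch.no_grad():
        # Stack the positive-class probability from every ensemble member.
        pos_probs = torch.stack([model(mdata)[:, 1] for model in models])
    # Mean across members is the ensemble prediction; variance is its uncertainty.
    return pos_probs.mean(dim=0), pos_probs.var(dim=0)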