import pandas as pd
import numpy as np
import os
import tomli as toml
from utils.data.datasets import prepare_datasets
import utils.ensemble as ens
import torch
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from tqdm import tqdm

RUN = True

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
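
# The config lookups in this script assume a TOML layout roughly like the
# sketch below. Section and key names are taken from the accesses further
# down; the values shown are illustrative placeholders only:
#
#   [paths]
#   mri_data = "..."
#   xls_data = "..."
#   model_output = "..."
#
#   [dataset]
#   validation_split = 0.2
#
#   [training]
#   device = "cuda"
#
#   [ensemble]
#   name = "..."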


# Returns a DataFrame of ensemble accuracy/F1/AUC at each uncertainty-quantile
# threshold, together with a dict of the same metrics for a single
# (individual) model as a baseline
def threshold(config):
    # First, get the model data
    ts, vs, test_set = prepare_datasets(
        config['paths']['mri_data'],
        config['paths']['xls_data'],
        config['dataset']['validation_split'],
        944,
        config['training']['device'],
    )
    # Evaluate on the test set plus the validation split
    test_set = test_set + vs

    models, _ = ens.load_models(
        config['paths']['model_output'] + config['ensemble']['name'] + '/',
        config['training']['device'],
    )
    indv_model = models[0]

    predictions = []
    indv_predictions = []

    # Evaluate the ensemble (mean prediction and uncertainty) on each test sample
    for mdata, target in tqdm(test_set, total=len(test_set)):
        mri, xls = mdata
        mri = mri.unsqueeze(0)
        xls = xls.unsqueeze(0)
        mdata = (mri, xls)
        mean, variance = ens.ensemble_predict(models, mdata)
        stdev = torch.sqrt(variance)
        prediction = mean.item()

        target = target[1]

        # Check if the prediction is correct
        correct = (prediction < 0.5 and int(target.item()) == 0) or (
            prediction >= 0.5 and int(target.item()) == 1
        )

        predictions.append(
            {
                'Prediction': prediction,
                'Actual': target.item(),
                'Stdev': stdev.item(),
                'Correct': correct,
            }
        )

        # Repeat for the single-model baseline (positive-class output)
        i_mean = indv_model(mdata)[:, 1].item()
        i_correct = (i_mean < 0.5 and int(target.item()) == 0) or (
            i_mean >= 0.5 and int(target.item()) == 1
        )

        indv_predictions.append(
            {
                'Prediction': i_mean,
                'Actual': target.item(),
                'Stdev': 0,
                'Correct': i_correct,
            }
        )

    # Sort the predictions by the uncertainty
    predictions = pd.DataFrame(predictions).sort_values(by='Stdev')

    # Calculate the metrics for the individual model
    indv_predictions = pd.DataFrame(indv_predictions)
    indv_correct = indv_predictions['Correct'].sum()
    indv_accuracy = indv_correct / len(indv_predictions)
    indv_true_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    indv_false_pos = len(
        indv_predictions[
            (indv_predictions['Prediction'] >= 0.5) & (indv_predictions['Actual'] == 0)
        ]
    )
    indv_false_neg = len(
        indv_predictions[
            (indv_predictions['Prediction'] < 0.5) & (indv_predictions['Actual'] == 1)
        ]
    )
    # F1 is computed from true positives; the overall correct count would also
    # include true negatives
    indv_f1 = 2 * indv_true_pos / (2 * indv_true_pos + indv_false_pos + indv_false_neg)
    indv_auc = metrics.roc_auc_score(
        indv_predictions['Actual'], indv_predictions['Prediction']
    )
    indv_metrics = {'Accuracy': indv_accuracy, 'F1': indv_f1, 'AUC': indv_auc}
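    # Sanity check: the hand-rolled F1 above should agree with
    # metrics.f1_score(indv_predictions['Actual'],
    #                  (indv_predictions['Prediction'] >= 0.5).astype(int)),
    # assuming class 1 is the positive class throughout.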

    # Get the uncertainty threshold corresponding to each quantile
    thresholds = []
    quantiles = np.arange(0.1, 1, 0.1)
    for quantile in quantiles:
        thresholds.append(predictions['Stdev'].quantile(quantile))

    # Calculate the metrics of the ensemble for each threshold
    accuracies = []
    for thresh, quantile in zip(thresholds, quantiles):
        filtered = predictions[predictions['Stdev'] <= thresh]
        correct = filtered['Correct'].sum()
        total = len(filtered)
        accuracy = correct / total
        true_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 1)]
        )
        false_positives = len(
            filtered[(filtered['Prediction'] >= 0.5) & (filtered['Actual'] == 0)]
        )
        false_negatives = len(
            filtered[(filtered['Prediction'] < 0.5) & (filtered['Actual'] == 1)]
        )
        f1 = 2 * true_positives / (
            2 * true_positives + false_positives + false_negatives
        )
        # roc_auc_score requires both classes to be present in the filtered subset
        auc = metrics.roc_auc_score(filtered['Actual'], filtered['Prediction'])
        accuracies.append(
            {
                'Threshold': thresh,
                'Accuracy': accuracy,
                'Quantile': quantile,
                'F1': f1,
                'AUC': auc,
            }
        )

    predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
    )
    indv_predictions.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_predictions.csv"
    )

    return pd.DataFrame(accuracies), indv_metrics
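

# For reference, a minimal sketch of what ens.ensemble_predict is assumed to do
# here (the real implementation lives in utils.ensemble; this illustrative
# version averages the members' positive-class outputs and reports the variance
# across members as the uncertainty):
#
#   def ensemble_predict(models, mdata):
#       preds = torch.stack([model(mdata)[:, 1] for model in models])
#       return preds.mean(dim=0), preds.var(dim=0)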


if RUN:
    result, indv = threshold(config)

    result.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
    )
    indv = pd.DataFrame([indv])
    indv.to_csv(
        f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
    )

# Reload the saved results so the plotting code below also works when RUN is False
result = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.csv"
)
predictions = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/predictions.csv"
)
indv = pd.read_csv(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/indv_metrics.csv"
)

print(indv)
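
# Coverage is the fraction of the test set retained at each uncertainty
# quantile: a quantile of 0.9 keeps the 90% of predictions with the lowest
# ensemble standard deviation. The x-axes below are inverted so that coverage
# decreases from left to right.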

plt.figure()
plt.plot(result['Quantile'], result['Accuracy'], label='Ensemble Accuracy')
plt.plot(
    result['Quantile'],
    [indv['Accuracy'].iloc[0]] * len(result['Quantile']),
    label='Individual Accuracy',
    linestyle='--',
)
plt.legend()
plt.title('Accuracy vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('Accuracy')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage.png"
)

plt.figure()
plt.plot(result['Quantile'], result['F1'], label='Ensemble F1')
plt.plot(
    result['Quantile'],
    [indv['F1'].iloc[0]] * len(result['Quantile']),
    label='Individual F1',
    linestyle='--',
)
plt.legend()
plt.title('F1 vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('F1')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_f1.png"
)

plt.figure()
plt.plot(result['Quantile'], result['AUC'], label='Ensemble AUC')
plt.plot(
    result['Quantile'],
    [indv['AUC'].iloc[0]] * len(result['Quantile']),
    label='Individual AUC',
    linestyle='--',
)
plt.legend()
plt.title('AUC vs Coverage')
plt.xlabel('Coverage')
plt.ylabel('AUC')
plt.gca().invert_xaxis()
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/coverage_auc.png"
)

# Histogram of the incorrect predictions vs. their uncertainty
plt.figure()
plt.hist(predictions[~predictions['Correct']]['Stdev'], bins=10)
plt.xlabel('Uncertainty')
plt.ylabel('Number of incorrect predictions')
plt.savefig(
    f"{config['paths']['model_output']}{config['ensemble']['name']}/incorrect_predictions.png"
)