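"""Evaluate the trained ensemble on the combined test and validation sets.

For every sample, each model's softmax output is recorded and reduced to an
ensemble confidence and standard deviation. The script then plots coverage
curves (accuracy and F1 at increasing certainty thresholds) against a single
model's performance and writes the plots and CSVs to the ensemble's v2 folder.
"""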

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.ensemble as ens
import utils.metrics as met

# Set RUN to True to regenerate predictions with the saved models; when False,
# the previously saved CSVs in V2_PATH are loaded instead.
RUN = False

# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)

ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
V2_PATH = ENSEMBLE_PATH + '/v2'


# The model output is a 1x2 tensor holding the softmax over the two classes.
# Convert it to a predicted class and a confidence score in [0, 1].
def output_to_confidence(result):
    predicted_class = torch.argmax(result).item()
    confidence = (torch.max(result).item() - 0.5) * 2

    return torch.Tensor([predicted_class, confidence])
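
# Illustrative example (not part of the original script): for a softmax output
# of torch.tensor([[0.2, 0.8]]), argmax selects class 1 and the confidence is
# (0.8 - 0.5) * 2 = 0.6, so output_to_confidence returns tensor([1.0, 0.6]).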


# Run every model in the ensemble over the combined test and validation sets and
# return the raw outputs, the per-sample ensemble summaries, and the first
# model's individual predictions.
def get_predictions(config):
    models, model_descs = ens.load_models(
        f'{ENSEMBLE_PATH}/models/',
        config['training']['device'],
    )
    models = [model.to(config['training']['device']) for model in models]
    test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    # [([model results], labels)]
    results = []

    # [(class_1, class_2, true_label)]
    indv_results = []

    for i, (data, target) in tqdm(
        enumerate(test_set),
        total=len(test_set),
        desc='Getting predictions',
        unit='sample',
    ):
        mri, xls = data

        # Add a batch dimension and move both input modalities to the device
        mri = mri.unsqueeze(0).to(config['training']['device'])
        xls = xls.unsqueeze(0).to(config['training']['device'])
        data = (mri, xls)

        res = []
        for j, model in enumerate(models):
            model.eval()
            with torch.no_grad():
                output = model(data)
                output = output.tolist()

                # Keep the first model's raw predictions for the individual baseline
                if j == 0:
                    indv_results.append(
                        (output[0][0], output[0][1], target[1].item())
                    )

                res.append(output)

        results.append((res, target.tolist()))

    # `results` is a list of (list of model outputs, true label) tuples. Convert
    # it into two lists of per-sample summaries: one keyed on the ensemble
    # confidence and one keyed on the ensemble standard deviation.

    # [(ensemble predicted class, ensemble confidence, true label, class_1, class_2)]
    confidences = []

    # [(ensemble predicted class, ensemble standard deviation, true label, class_1, class_2)]
    stdevs = []

    for result in results:
        model_results, true_label = result

        # Ensemble mean and variance across the models (computed with numpy,
        # as the per-model outputs are plain lists)
        mean = np.mean(model_results, axis=0)
        variance = np.var(model_results, axis=0)

        # Calculate confidence and standard deviation
        confidence = (np.max(mean) - 0.5) * 2
        stdev = np.sqrt(variance)

        # Get the predicted class
        predicted_class = np.argmax(mean)

        # Get the standard deviation of the predicted class
        pc_stdev = np.squeeze(stdev)[predicted_class]

        # Get the mean softmax output of each class
        class_1 = mean[0][0]
        class_2 = mean[0][1]

        # Get the true label
        true_label = true_label[1]

        confidences.append((predicted_class, confidence, true_label, class_1, class_2))
        stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))

    return results, confidences, stdevs, indv_results


if RUN:
    results, confs, stdevs, indv_results = get_predictions(config)

    # Convert the summaries to pandas dataframes
    confs_df = pd.DataFrame(
        confs,
        columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
    )
    stdevs_df = pd.DataFrame(
        stdevs,
        columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2'],
    )
    indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])

    if not os.path.exists(V2_PATH):
        os.makedirs(V2_PATH)

    confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
else:
    # Load the cached predictions from a previous run
    confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')

# Plot confidence vs standard deviation, coloring points by whether the
# ensemble prediction is correct
correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]

correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]

plt.scatter(correct_conf['confidence'], correct_stdev['stdev'], color='green')
plt.scatter(incorrect_conf['confidence'], incorrect_stdev['stdev'], color='red')
plt.xlabel('Confidence')
plt.ylabel('Standard Deviation')
plt.title('Confidence vs Standard Deviation')
plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
plt.close()

# Calculate individual model accuracy
# Determine predicted class
indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
indv_df['predicted_class'] = indv_df['predicted_class'].apply(
    lambda x: 0 if x == 'class_1' else 1
)
indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']

accuracy_indv = indv_df['correct'].mean()
f1_indv = met.F1(
    indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
)
auc_indv = metrics.roc_auc_score(
    indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
)

# Calculate decile thresholds (0th-100th percentile) for confidence and
# standard deviation
quantiles_conf = confs_df['confidence'].quantile(
    np.linspace(0, 1, 11), interpolation='lower'
)
quantiles_stdev = stdevs_df['stdev'].quantile(
    np.linspace(0, 1, 11), interpolation='lower'
)
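
# Coverage analysis: at each percentile threshold keep only the samples whose
# ensemble confidence is at or above it (or, for standard deviation, at or
# below it) and recompute accuracy and F1 on the retained subset. The
# threshold that retains every sample (the 0th percentile for confidence, the
# 100th for standard deviation) reproduces the overall metrics printed at the
# end of the script.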
accuracies_conf = []

# Use the confidence thresholds to calculate coverage
for percentile, threshold in quantiles_conf.items():
    filt = confs_df[confs_df['confidence'] >= threshold]
    accuracy = (
        filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
    )
    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())

    accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})

accuracies_df = pd.DataFrame(accuracies_conf)

# Plot coverage (accuracy vs confidence percentile threshold)
plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
plt.plot(
    accuracies_df['percentile'],
    [accuracy_indv] * len(accuracies_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage (Confidence)')
plt.legend()
plt.savefig(f'{V2_PATH}/coverage_conf.png')
plt.close()

# Plot coverage vs F1 for confidence
plt.plot(accuracies_df['percentile'], accuracies_df['f1'], label='Ensemble')
plt.plot(
    accuracies_df['percentile'],
    [f1_indv] * len(accuracies_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('F1')
plt.title('Coverage F1 (Confidence)')
plt.legend()
plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
plt.close()

# Repeat the coverage analysis for standard deviation (lower stdev = more certain)
accuracies_stdev = []
for percentile, threshold in quantiles_stdev.items():
    filt = stdevs_df[stdevs_df['stdev'] <= threshold]
    accuracy = (
        filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
    )
    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())

    accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})

accuracies_stdev_df = pd.DataFrame(accuracies_stdev)

# Plot coverage (accuracy vs standard deviation percentile threshold)
plt.plot(
    accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
)
plt.plot(
    accuracies_stdev_df['percentile'],
    [accuracy_indv] * len(accuracies_stdev_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage (Standard Deviation)')
plt.legend()
plt.gca().invert_xaxis()
plt.savefig(f'{V2_PATH}/coverage_stdev.png')
plt.close()

# Plot coverage vs F1 for standard deviation
plt.plot(accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], label='Ensemble')
plt.plot(
    accuracies_stdev_df['percentile'],
    [f1_indv] * len(accuracies_stdev_df['percentile']),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('F1')
plt.title('Coverage F1 (Standard Deviation)')
plt.legend()
plt.gca().invert_xaxis()
plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
plt.close()

# Print the overall ensemble accuracy and F1
overall_accuracy = (
    confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
    / confs_df.shape[0]
)
overall_f1 = met.F1(
    confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
)

print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')
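
# Optional (not part of the original script): the individual-model metrics
# computed above, including the otherwise unused AUC, could be reported the
# same way for comparison.
# print(f'Individual accuracy: {accuracy_indv}, F1: {f1_indv}, AUC: {auc_indv}')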