|
- import pandas as pd
- import numpy as np
- import os
- import tomli as toml
- from utils.data.datasets import prepare_datasets
- import utils.ensemble as ens
- import torch
- import matplotlib.pyplot as plt
- import sklearn.metrics as metrics
- from tqdm import tqdm
- import utils.metrics as met
# Set to True to regenerate predictions from the saved models; when False the
# script reuses the CSVs written to V2_PATH by a previous run.
RUN = False

# CONFIGURATION
# ADL_CONFIG_PATH, when set, overrides the default ./config.toml.
_config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
with open(_config_path, 'rb') as f:
    config = toml.load(f)

# Directory holding this ensemble's saved models and datasets.
# NOTE(review): the two parts are joined with no separator, so model_output is
# presumably expected to end with '/' — confirm against the config.
ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
# Output directory for the CSVs and plots produced below.
V2_PATH = f'{ENSEMBLE_PATH}/v2'
def output_to_confidence(result):
    """Turn a two-class softmax output into ``[predicted_class, confidence]``.

    ``result`` is expected to be a 1x2 tensor of softmax probabilities. The
    predicted class is the index of its largest entry (taken over the
    flattened tensor; the first index wins on ties). The confidence linearly
    rescales that largest probability from ``[0.5, 1]`` to ``[0, 1]`` using
    ``(p_max - 0.5) * 2``. The 0.5 offset assumes exactly two classes.

    Returns:
        A 1-D float tensor ``[predicted_class, confidence]``.
    """
    flat = result.flatten()
    winner = torch.argmax(flat).item()
    top_prob = torch.max(flat).item()
    rescaled = (top_prob - 0.5) * 2
    return torch.Tensor([winner, rescaled])
def get_predictions(config):
    """Run every ensemble member over the held-out data and summarise the results.

    The held-out data is the saved test set concatenated with the saved
    validation set, both loaded from ``ENSEMBLE_PATH``. Each sample is passed
    to every model with a leading batch dimension of 1.

    Args:
        config: Parsed configuration. ``config['training']['device']`` selects
            where the models and inputs are placed.

    Returns:
        A tuple ``(results, confidences, stdevs, indv_results)``:

        * ``results``: ``[(per-model outputs as lists, target as list)]``,
          one entry per sample.
        * ``confidences``: ``[(predicted class, confidence, true label)]`` from
          the ensemble mean, with confidence ``(max(mean) - 0.5) * 2``.
        * ``stdevs``: ``[(predicted class, std of that class, true label)]``,
          the spread across ensemble members.
        * ``indv_results``: ``[(class_1 output, class_2 output, true label)]``
          for the first model only.
    """
    device = config['training']['device']
    models, _model_descs = ens.load_models(f'{ENSEMBLE_PATH}/models/', device)
    models = [model.to(device) for model in models]
    # Switch to inference mode once; nothing below puts the models back into
    # training mode, so re-calling eval() per sample was redundant.
    for model in models:
        model.eval()

    # NOTE(review): torch.load unpickles whole Dataset objects, so only load
    # trusted files. Newer torch releases default to weights_only=True, which
    # may reject these objects — confirm against the installed torch version.
    test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    results, indv_results = _collect_outputs(models, test_set, device)
    confidences, stdevs = _summarise_ensemble(results)
    return results, confidences, stdevs, indv_results


def _collect_outputs(models, dataset, device):
    """Run each model on every sample; return (per-sample outputs, first-model rows)."""
    # [([model results], labels)]
    results = []
    # [(class_1, class_2, true_label)], taken from the first model only
    indv_results = []
    with torch.no_grad():
        for data, target in tqdm(
            dataset,
            total=len(dataset),
            desc='Getting predictions',
            unit='sample',
        ):
            mri, xls = data
            batch = (mri.unsqueeze(0).to(device), xls.unsqueeze(0).to(device))
            outputs = [model(batch).tolist() for model in models]
            if outputs:
                first = outputs[0]
                # target[1] is used as the label; presumably the target is
                # one-hot, so this is 1 for the second class — confirm.
                indv_results.append((first[0][0], first[0][1], target[1].item()))
            results.append((outputs, target.tolist()))
    return results, indv_results


def _summarise_ensemble(results):
    """Reduce per-model outputs to (class, confidence, label) and (class, std, label) rows."""
    # [(ensemble predicted class, ensemble confidence, true label)]
    confidences = []
    # [(ensemble predicted class, ensemble standard deviation, true label)]
    stdevs = []
    for model_results, true_label in results:
        # Statistics are taken across models (axis 0); assuming each output
        # is 1xC, mean and stdev have shape (1, C).
        mean = np.mean(model_results, axis=0)
        stdev = np.sqrt(np.var(model_results, axis=0))
        predicted_class = np.argmax(mean)
        confidence = (np.max(mean) - 0.5) * 2
        # Spread of the ensemble's probability for the class it predicted.
        pc_stdev = np.squeeze(stdev)[predicted_class]
        label = true_label[1]
        confidences.append((predicted_class, confidence, label))
        stdevs.append((predicted_class, pc_stdev, label))
    return confidences, stdevs
if RUN:
    results, confs, stdevs, indv_results = get_predictions(config)
    # Convert to pandas dataframes; the default RangeIndex is the sample order.
    confs_df = pd.DataFrame(
        confs, columns=['predicted_class', 'confidence', 'true_label']
    )
    stdevs_df = pd.DataFrame(stdevs, columns=['predicted_class', 'stdev', 'true_label'])
    indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
    os.makedirs(V2_PATH, exist_ok=True)
    confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
    stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
    indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
else:
    # to_csv writes the index as the first column. Read it back as the index
    # (index_col=0) rather than as an extra 'Unnamed: 0' column, which would
    # otherwise pile up on every re-save. The ensemble CSVs are later
    # overwritten sorted by confidence / stdev respectively, so restore sample
    # order to keep the two frames' rows paired by sample.
    confs_df = pd.read_csv(
        f'{V2_PATH}/ensemble_confidences.csv', index_col=0
    ).sort_index()
    stdevs_df = pd.read_csv(
        f'{V2_PATH}/ensemble_stdevs.csv', index_col=0
    ).sort_index()
    indv_df = pd.read_csv(
        f'{V2_PATH}/individual_results.csv', index_col=0
    ).sort_index()
# Plot confidence vs standard deviation.
# plt.scatter pairs its inputs purely by position. Join on the index so each
# point combines the confidence and stdev of the same sample, even if either
# frame has been reordered.
paired = confs_df[['confidence']].join(stdevs_df[['stdev']])
plt.scatter(paired['confidence'], paired['stdev'])
plt.xlabel('Confidence')
plt.ylabel('Standard Deviation')
plt.title('Confidence vs Standard Deviation')
plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
plt.close()
# Calculate Binning for Coverage
# Order each frame by its own uncertainty score and write it back over the CSV
# it was saved to; to_csv keeps each row's index label as the first column.
confs_df = confs_df.sort_values('confidence')
confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
stdevs_df = stdevs_df.sort_values('stdev')
stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
# Calculate individual model accuracy
# Baseline accuracy of the single model recorded in indv_df. idxmax names the
# column with the larger output (first column on ties). Any result other than
# 'class_1' counts as class 1, matching the ensemble's 0/1 class indices.
winning_column = indv_df[['class_1', 'class_2']].idxmax(axis=1)
indv_df['predicted_class'] = (winning_column != 'class_1').astype('int64')
indv_df['correct'] = indv_df['predicted_class'].eq(indv_df['true_label'])
accuracy_indv = indv_df['correct'].mean()
# Calculate percentiles (deciles 0.0, 0.1, ..., 1.0) for confidence and
# standard deviation. Only the needed column is quantiled. DataFrame.quantile
# over the whole frame also processes the label/index columns, and under
# pandas >= 2 (numeric_only=False by default) it raises on any non-numeric one.
_quantile_levels = np.linspace(0, 1, 11)
quantiles_conf = confs_df['confidence'].quantile(_quantile_levels)
quantiles_stdev = stdevs_df['stdev'].quantile(_quantile_levels)

accuracies_conf = []
# Use the quantiles to calculate the coverage. At level p only samples whose
# confidence is at or above the p-quantile are kept. p=0 keeps everything, and
# higher p keeps progressively fewer, more confident samples.
for percentile, threshold in quantiles_conf.items():
    kept = confs_df[confs_df['confidence'] >= threshold]
    n_correct = int((kept['predicted_class'] == kept['true_label']).sum())
    accuracy = n_correct / len(kept)
    accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy})
accuracies_df = pd.DataFrame(accuracies_conf)

# Plot the coverage against the flat single-model baseline.
plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
plt.plot(
    accuracies_df['percentile'],
    [accuracy_indv] * len(accuracies_df),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage conf')
plt.legend()
plt.savefig(f'{V2_PATH}/coverage.png')
plt.close()
# Repeat for standard deviation. At level p only samples whose stdev (for the
# predicted class) is at or below the p-quantile are kept, so p=1 keeps every
# sample and lower p keeps fewer, lower-spread samples.
accuracies_stdev = []
for level, cutoff in quantiles_stdev.items():
    retained = stdevs_df.loc[stdevs_df['stdev'] <= cutoff]
    hits = retained.loc[retained['predicted_class'] == retained['true_label']]
    accuracies_stdev.append(
        {'percentile': level, 'accuracy': len(hits) / len(retained)}
    )
accuracies_stdev_df = pd.DataFrame(accuracies_stdev)

# Plot the coverage. The x axis is inverted so the left edge (p=1) keeps every
# sample and moving right keeps progressively fewer samples.
stdev_levels = accuracies_stdev_df['percentile']
plt.plot(stdev_levels, accuracies_stdev_df['accuracy'], label='Ensemble')
plt.plot(
    stdev_levels,
    [accuracy_indv] * len(stdev_levels),
    label='Individual',
    linestyle='--',
)
plt.xlabel('Percentile')
plt.ylabel('Accuracy')
plt.title('Coverage Stdev')
plt.legend()
plt.gca().invert_xaxis()
plt.savefig(f'{V2_PATH}/coverage_stdev.png')
plt.close()
|