Sfoglia il codice sorgente

Begin full threshold refactor

Nicholas Schense 4 mesi fa
parent
commit
8a70abc2e8
1 ha cambiato i file con 99 aggiunte e 0 eliminazioni
  1. 99 0
      threshold_refac.py

+ 99 - 0
threshold_refac.py

@@ -0,0 +1,99 @@
+import pandas as pd
+import numpy as np
+import os
+import tomli as toml
+from utils.data.datasets import prepare_datasets
+import utils.ensemble as ens
+import torch
+import matplotlib.pyplot as plt
+import sklearn.metrics as metrics
+from tqdm import tqdm
+import utils.metrics as met
+import itertools as it
+import matplotlib.ticker as ticker
+import glob
+
+# CONFIGURATION
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
+        config = toml.load(f)
+else:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+        config = toml.load(f)
+
+ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+
+V2_PATH = ENSEMBLE_PATH + '/v2'
+
+
+# Models is a dictionary with the model ids as keys and the model data as values
+def get_model_predictions(models, data):
+    predictions = {}
+    for model_id, model in models.items():
+        model.eval()
+        with torch.no_grad:
+            # Get the predictions
+            output = model(data)
+            predictions[model_id] = output.detach().cpu().numpy()
+
+    return predictions
+
+
+def load_models_v2(folder, device):
+    glob_path = os.path.join(folder, '*.pt')
+    model_files = glob(glob_path)
+    model_dict = {}
+
+    for model_file in model_files:
+        model = torch.load(model_file, map_location=device)
+        model_id = os.path.basename(model_file).split('_')[0]
+        model_dict[model_id] = model
+
+    return model_dict
+
+
+# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
+def preprocess_data(data, device):
+    mri, xls = data
+    mri = mri.unsqueeze(0).to(device)
+    xls = xls.unsqueeze(0).to(device)
+    return (mri, xls)
+
+
+def ensemble_dataset_predictions(models, dataset, device):
+    # For each datapoint, get the predictions of each model
+    predictions = {}
+    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
+        # Preprocess data
+        data = preprocess_data(data, device)
+        # Predictions is a dicionary of tuples, with the target as the first and the model predicions dictionary as the second
+        # The key is the id of the image
+        predictions[i] = (
+            target.detach().cpu().numpy(),
+            get_model_predictions(models, data),
+        )
+
+    return predictions
+
+
+# Given a dictionary of predictions, select one model and eliminate the rest
+def select_individual_model(predictions, model_id):
+    selected_model_predictions = {}
+    for key, value in predictions.items():
+        selected_model_predictions[key] = (value[0], {model_id: value[1][model_id]})
+    return selected_model_predictions
+
+
+# Given a dictionary of predictions, select a subset of models and eliminate the rest
+def select_subset_models(predictions, model_ids):
+    selected_model_predictions = {}
+    for key, value in predictions.items():
+        selected_model_predictions[key] = (
+            value[0],
+            {model_id: value[1][model_id] for model_id in model_ids},
+        )
+    return selected_model_predictions
+
+
+# Given a dictionary of predictions, calculate statistics (stdev, mean, entropy, accuracy, f1) for each result
+def calculate_statistics(predictions):