Nicholas Schense 1 неделя назад
Родитель
Commit
56a6a93e65
5 измененных файлов с 267 добавлено и 18 удалено
  1. 1 1
      config.toml
  2. 21 0
      data/dataset.py
  3. 73 0
      evaluation/confidence_percentile.py
  4. 134 0
      evaluation/sanity_check.py
  5. 38 17
      requirements.txt

+ 1 - 1
config.toml

@@ -17,4 +17,4 @@ learning_rate = 0.0001
 num_epochs = 30
 
 [output]
-path = "../models/Full_Ensemble(50x30)/"
+path = "/home/nschense/Medphys_Research/models/Full_Ensemble(50x30)_PTSPLIT/"

+ 21 - 0
data/dataset.py

@@ -190,3 +190,24 @@ def initalize_dataloaders(
     return [
         DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
     ]
+
+
def divide_dataset_by_patient_id(
    dataset: "ADNIDataset",
    ptids: List[Tuple[int, int]],
    ratios: Tuple[float, float, float],
    seed: int,
) -> List[data.Subset["ADNIDataset"]]:
    """
    Divides the dataset into training, validation, and test sets based on patient IDs.
    Ensures that all samples from the same patient are in the same set.

    NOTE(review): assumes ptids[i] is the (image file id, patient id) pair for
    dataset index i — confirm against the caller.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ptids (List[Tuple[int, int]]): A list of tuples containing image file ids and their corresponding patient IDs.
        ratios (Tuple[float, float, float]): The ratios for training, validation, and test sets.
            Splits are cut at cumulative sample counts ratios[0] and ratios[0] + ratios[1];
            the test set receives the remainder.
        seed (int): The random seed for reproducibility.
    Returns:
        List[data.Subset[ADNIDataset]]: A list of subsets for training, validation, and test sets.

    """
    import random  # local import keeps this fix self-contained

    # Group dataset indices by patient id so a patient can never straddle two splits.
    patient_to_indices = {}
    for idx, (_, patient_id) in enumerate(ptids):
        patient_to_indices.setdefault(patient_id, []).append(idx)

    # Shuffle patients (not samples) with a dedicated RNG so the global
    # random state is untouched and the split is reproducible per seed.
    patients = sorted(patient_to_indices)
    random.Random(seed).shuffle(patients)

    total = len(ptids)
    # Cumulative sample-count targets marking the train/val and val/test boundaries.
    train_target = ratios[0] * total
    val_target = (ratios[0] + ratios[1]) * total

    splits: List[List[int]] = [[], [], []]
    assigned = 0
    for patient in patients:
        indices = patient_to_indices[patient]
        # Assign the whole patient to whichever split is still below its target.
        if assigned < train_target:
            splits[0].extend(indices)
        elif assigned < val_target:
            splits[1].extend(indices)
        else:
            splits[2].extend(indices)
        assigned += len(indices)

    return [data.Subset(dataset, idxs) for idxs in splits]

+ 73 - 0
evaluation/confidence_percentile.py

@@ -0,0 +1,73 @@
import xarray as xr
import numpy as np
import sys
import os

import matplotlib.pyplot as plt


sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
)  # to allow imports from parent directory
from utils.config import config
import pathlib as pl

# This script calculates and plots accuracy vs minimum confidence percentile
# threshold: how accurate is the ensemble when only its most confident
# predictions are kept?

model_dataset_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
array = xr.open_dataset(model_dataset_path)  # type: ignore


predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]

# Make plots directory if it doesn't exist
plots_dir = pl.Path(config["output"]["path"]) / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

# Ensemble prediction: average predictions across models
avg_predictions = predictions.mean(dim="model")
# Confidence scores for the positive class and the matching ground truth
confidence_scores = avg_predictions.sel(img_class=1).values
true_labels = labels.sel(label=1).values

# Absolute confidence: distance from the 0.5 decision boundary, rescaled to
# [0, 1]. Loop-invariant, so computed once (was recomputed every iteration).
absolute_confidences = 2 * np.abs(confidence_scores - 0.5)

# Calculate accuracy at different confidence percentiles
percentiles = np.linspace(0, 100, num=21)
accuracies = []
for p in percentiles:
    threshold = np.percentile(absolute_confidences, p)

    # Filter the predictions such that only those with absolute confidence above the threshold are considered
    selected_indices = np.where(absolute_confidences >= threshold)[0]
    if len(selected_indices) == 0:
        accuracies.append(0.0)
        continue
    selected_confidences = confidence_scores[selected_indices]
    selected_true_labels = true_labels[selected_indices]

    predicted_positive = selected_confidences >= 0.5
    true_positive = selected_true_labels == 1

    correct_predictions = (predicted_positive == true_positive).sum().item()
    total_predictions = len(selected_confidences)
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    accuracies.append(accuracy)

# Plot accuracy vs confidence percentile threshold
plt.figure(figsize=(10, 6))
plt.plot(percentiles, accuracies, marker="o")
plt.title("Accuracy vs Confidence Percentile Threshold")
plt.xlabel("Confidence Percentile Threshold")
plt.ylabel("Accuracy")
plt.grid()
plt.xticks(percentiles)

# Reuse plots_dir instead of rebuilding the path from config
plt.savefig(plots_dir / "accuracy_vs_confidence_percentile.png")
plt.close()  # release the figure so repeated runs don't accumulate memory

+ 134 - 0
evaluation/sanity_check.py

@@ -0,0 +1,134 @@
import xarray as xr
import numpy as np
import sys
import os


sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
)  # to allow imports from parent directory
from utils.config import config
import pathlib as pl

import colorama as clr

# Sanity-check report for the trained ensemble: per-image predictions,
# overall accuracy, per-model accuracies, top-K ensemble accuracies, and the
# most confidently incorrect predictions.

model_dataset_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
array = xr.open_dataset(model_dataset_path)  # type: ignore


predictions: xr.DataArray = array["predictions"]
labels: xr.DataArray = array["labels"]


def _threshold_accuracy(positive_confidences, true_positive) -> float:
    """Accuracy of thresholding positive-class confidences at 0.5 against boolean ground truth."""
    predicted_positive = positive_confidences >= 0.5
    correct = (predicted_positive == true_positive).sum().item()
    total = len(positive_confidences)
    return correct / total if total > 0 else 0.0


# Average predictions across models
avg_predictions = predictions.mean(dim="model")


# Sort from highest to lowest confidence for the positive class (img_class=1)
sorted_indices = np.argsort(-avg_predictions.sel(img_class=1).values)
sorted_avg_predictions = avg_predictions.isel(img_id=sorted_indices)
sorted_labels = labels.isel(img_id=sorted_indices)

# Print out all predictions with their labels
top_n = sorted_avg_predictions.sizes[
    "img_id"
]  # Change this value to print more or fewer
print(
    clr.Fore.CYAN
    + f"Top {top_n} Predictions (Confidence for Positive Class):"
    + clr.Style.RESET_ALL
)
for i in range(top_n):
    confidence = sorted_avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = sorted_labels.isel(img_id=i, label=1).values

    correctness = (
        "CORRECT"
        if (confidence >= 0.5 and label == 1) or (confidence < 0.5 and label == 0)
        else "INCORRECT"
    )
    color = clr.Fore.GREEN if correctness == "CORRECT" else clr.Fore.RED
    print(
        f"Image ID: {sorted_avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Label: {label:<3}, " + color + f"{correctness:<9}" + clr.Style.RESET_ALL
    )


# Boolean ground truth for the positive class; reused by every accuracy below.
true_positive = labels.sel(label=1) == 1
total_predictions = len(avg_predictions.img_id)

# Calculate overall ensemble accuracy
overall_accuracy = _threshold_accuracy(avg_predictions.sel(img_class=1), true_positive)
print(
    clr.Fore.MAGENTA
    + f"\nOverall Accuracy (Threshold 0.5): {overall_accuracy:.4f}"
    + clr.Style.RESET_ALL
)


# Then go through all individual models and print out their accuracies for comparison, sorted from highest to lowest
model_accuracies = [
    (model_idx, _threshold_accuracy(predictions.sel(model=model_idx, img_class=1), true_positive))
    for model_idx in predictions.coords["model"].values
]

# Sort by accuracy
model_accuracies.sort(key=lambda x: x[1], reverse=True)
print(
    clr.Fore.CYAN
    + "\nIndividual Model Accuracies (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for model_idx, accuracy in model_accuracies:
    print(f"Model {int(model_idx):<3}: Accuracy: {accuracy:.4f}")


# Then calculate the average accuracy if we were to ensemble the top K models, for K=1 to total number of models
total_models = len(predictions.coords["model"].values)
ensemble_accuracies = []
for k in range(1, total_models + 1):
    top_k_models = [ma[0] for ma in model_accuracies[:k]]
    ensemble_preds = predictions.sel(model=top_k_models).mean(dim="model")
    ensemble_accuracies.append(
        (k, _threshold_accuracy(ensemble_preds.sel(img_class=1), true_positive))
    )
print(
    clr.Fore.CYAN
    + "\nEnsemble Accuracies for Top K Models (Threshold 0.5):"
    + clr.Style.RESET_ALL
)
for k, accuracy in ensemble_accuracies:
    print(f"Top {k:<3} Models: Ensemble Accuracy: {accuracy:.4f}")


# Finally, identify the top 5 most confidently incorrect predictions
incorrect_predictions = []
for i in range(len(avg_predictions.img_id)):
    confidence = avg_predictions.sel(img_class=1).isel(img_id=i).item()
    label = labels.isel(img_id=i, label=1).values
    predicted_label = 1 if confidence >= 0.5 else 0
    if predicted_label != label:
        incorrect_predictions.append((i, confidence, label))
# Sort by distance from the 0.5 decision boundary (most confident mistakes first)
incorrect_predictions.sort(key=lambda x: -abs(x[1] - 0.5))
top_incorrect = incorrect_predictions[:5]
print(
    clr.Fore.YELLOW
    + "\nTop 5 Most Confident Incorrect Predictions:"
    + clr.Style.RESET_ALL
)
for i, confidence, label in top_incorrect:
    predicted_label = 1 if confidence >= 0.5 else 0
    print(
        f"Image ID: {avg_predictions.img_id.isel(img_id=i).item():<8}, "
        f"Confidence: {confidence:.4f}, "
        f"Predicted Label: {predicted_label:<3}, "
        f"True Label: {label:<3}"
    )

+ 38 - 17
requirements.txt

@@ -1,27 +1,48 @@
-filelock==3.18.0
-fsspec==2025.3.2
+colorama==0.4.6
+contourpy==1.3.3
+cycler==0.12.1
+filelock==3.13.1
+fonttools==4.60.1
+fsspec==2024.6.1
 jaxtyping==0.3.2
-Jinja2==3.1.6
-MarkupSafe==3.0.2
+Jinja2==3.1.4
+kiwisolver==1.4.9
+MarkupSafe==2.1.5
+matplotlib==3.10.7
 mpmath==1.3.0
-networkx==3.4.2
+networkx==3.3
 nibabel==5.3.2
-numpy==2.2.6
+numpy==2.1.2
+nvidia-cublas-cu12==12.9.1.4
+nvidia-cuda-cupti-cu12==12.9.79
+nvidia-cuda-nvrtc-cu12==12.9.86
+nvidia-cuda-runtime-cu12==12.9.79
+nvidia-cudnn-cu12==9.10.2.21
+nvidia-cufft-cu12==11.4.1.4
+nvidia-cufile-cu12==1.14.1.1
+nvidia-curand-cu12==10.3.10.19
+nvidia-cusolver-cu12==11.7.5.82
+nvidia-cusparse-cu12==12.5.10.65
+nvidia-cusparselt-cu12==0.7.1
+nvidia-nccl-cu12==2.27.3
+nvidia-nvjitlink-cu12==12.9.86
+nvidia-nvtx-cu12==12.9.79
 packaging==25.0
-pandas==2.2.3
-pillow==11.2.1
+pandas==2.3.2
+pillow==11.0.0
+pyparsing==3.2.5
 python-dateutil==2.9.0.post0
 pytz==2025.2
 result==0.17.0
-scipy==1.15.3
-setuptools==80.8.0
+scipy==1.16.2
+setuptools==70.2.0
 six==1.17.0
-sympy==1.13.1
-torch==2.6.0
-torchaudio==2.6.0
-torchvision==0.21.0
+sympy==1.13.3
+torch==2.8.0+cu129
+torchvision==0.23.0+cu129
 tqdm==4.67.1
-typing_extensions==4.13.2
+triton==3.4.0
+typing_extensions==4.12.2
 tzdata==2025.2
-wadler_lindig==0.1.6
-xarray==2025.4.0
+wadler_lindig==0.1.7
+xarray==2025.9.0