Browse Source

Rename Conv* classes and attributes (and train_model.py → train_ensemble.py) to avoid name-matching issues with dnn_to_bnn()

Nicholas Schense 6 days ago
parent
commit
bd7ecee3c1
9 changed files with 737 additions and 46 deletions
  1. 2 1
      config.toml
  2. 102 3
      data/dataset.py
  3. 6 6
      model/cnn.py
  4. 23 23
      model/layers.py
  5. 1 0
      requirements.txt
  6. 195 0
      run_overnight_training.py
  7. 214 0
      train_bayesian.py
  8. 30 10
      train_ensemble.py
  9. 164 3
      utils/training.py

+ 2 - 1
config.toml

@@ -17,4 +17,5 @@ learning_rate = 0.0001
 num_epochs = 30
 
 [output]
-path = "/home/nschense/Medphys_Research/models/Full_Ensemble(50x30)_PTSPLIT/"
+ensemble_path = "/home/nschense/Medphys_Research/models/Full_Ensemble(50x30)_PTSPLIT/"
+bayesian_path = "/home/nschense/Medphys_Research/models/Bayesian_PTSPLIT/"

+ 102 - 3
data/dataset.py

@@ -5,10 +5,11 @@ import pathlib as pl
 import pandas as pd
 from torch.utils.data import Subset, DataLoader
 import re
+import random
 
 
 from jaxtyping import Float
-from typing import Tuple, Iterator, Callable, List
+from typing import Tuple, Iterator, Callable, List, Dict
 
 
 class ADNIDataset(data.Dataset):  # type: ignore
@@ -194,7 +195,7 @@ def initalize_dataloaders(
 
 def divide_dataset_by_patient_id(
     dataset: ADNIDataset,
-    ptids: List[Tuple[int, int]],
+    ptids: List[Tuple[int, str]],
     ratios: Tuple[float, float, float],
     seed: int,
 ) -> List[data.Subset[ADNIDataset]]:
@@ -204,10 +205,108 @@ def divide_dataset_by_patient_id(
 
     Args:
         dataset (ADNIDataset): The dataset to divide.
-        ptids (List[Tuple[int, int]]): A list of tuples containing image file ids and their corresponding patient IDs.
+        ptids (List[Tuple[int, str]]): A list of tuples containing image file IDs and their corresponding patient IDs.
         ratios (Tuple[float, float, float]): The ratios for training, validation, and test sets.
         seed (int): The random seed for reproducibility.
+
     Returns:
         List[data.Subset[ADNIDataset]]: A list of subsets for training, validation, and test sets.
 
+    Notes:
+        This split is grouped by PTID, so all images from the same patient are assigned
+        to exactly one partition to avoid patient-level leakage across train/val/test.
     """
+    if sum(ratios) != 1.0:
+        raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
+
+    if not ptids:
+        raise ValueError("ptids list cannot be empty.")
+
+    image_to_patient: Dict[int, str] = {}
+    for image_id, patient_id in ptids:
+        image_id_int = int(image_id)
+        patient_id_str = str(patient_id).strip()
+        if not patient_id_str or patient_id_str.lower() == "nan":
+            raise ValueError(f"Invalid PTID for Image Data ID {image_id_int}.")
+
+        if (
+            image_id_int in image_to_patient
+            and image_to_patient[image_id_int] != patient_id_str
+        ):
+            raise ValueError(
+                f"Conflicting PTIDs for Image Data ID {image_id_int}: "
+                f"{image_to_patient[image_id_int]} vs {patient_id_str}."
+            )
+
+        image_to_patient[image_id_int] = patient_id_str
+
+    patient_to_indices: Dict[str, List[int]] = {}
+    for idx, image_id in enumerate(dataset.filename_ids):
+        if image_id not in image_to_patient:
+            raise ValueError(
+                f"Missing PTID mapping for dataset Image Data ID {image_id}."
+            )
+
+        patient_id = image_to_patient[image_id]
+        if patient_id not in patient_to_indices:
+            patient_to_indices[patient_id] = []
+        patient_to_indices[patient_id].append(idx)
+
+    shuffled_patients = list(patient_to_indices.keys())
+    random.Random(seed).shuffle(shuffled_patients)
+
+    train_cutoff = int(len(shuffled_patients) * ratios[0])
+    val_cutoff = train_cutoff + int(len(shuffled_patients) * ratios[1])
+
+    train_patients = shuffled_patients[:train_cutoff]
+    val_patients = shuffled_patients[train_cutoff:val_cutoff]
+    test_patients = shuffled_patients[val_cutoff:]
+
+    train_patient_set = set(train_patients)
+    val_patient_set = set(val_patients)
+    test_patient_set = set(test_patients)
+
+    if (
+        train_patient_set & val_patient_set
+        or train_patient_set & test_patient_set
+        or val_patient_set & test_patient_set
+    ):
+        raise ValueError("Patient separation violated across train/val/test splits.")
+
+    all_patients = set(patient_to_indices.keys())
+    if train_patient_set | val_patient_set | test_patient_set != all_patients:
+        raise ValueError("Not all patients were assigned to a split.")
+
+    train_indices = [
+        idx for patient_id in train_patients for idx in patient_to_indices[patient_id]
+    ]
+    val_indices = [
+        idx for patient_id in val_patients for idx in patient_to_indices[patient_id]
+    ]
+    test_indices = [
+        idx for patient_id in test_patients for idx in patient_to_indices[patient_id]
+    ]
+
+    train_index_set = set(train_indices)
+    val_index_set = set(val_indices)
+    test_index_set = set(test_indices)
+
+    if (
+        train_index_set & val_index_set
+        or train_index_set & test_index_set
+        or val_index_set & test_index_set
+    ):
+        raise ValueError("Sample index overlap detected across train/val/test splits.")
+
+    all_split_indices = train_index_set | val_index_set | test_index_set
+    expected_indices = set(range(len(dataset)))
+    if all_split_indices != expected_indices:
+        raise ValueError(
+            "Split coverage check failed: not all dataset samples are assigned exactly once."
+        )
+
+    return [
+        Subset(dataset, train_indices),
+        Subset(dataset, val_indices),
+        Subset(dataset, test_indices),
+    ]

+ 6 - 6
model/cnn.py

@@ -9,7 +9,7 @@ class CNN_Image_Section(nn.Module):
     def __init__(self, image_channels: int, droprate: float = 0.0):
         super().__init__()
         # Initial Convolutional Blocks
-        self.conv1 = ly.ConvBlock(
+        self.cnv1 = ly.CNVBlock(
             image_channels,
             192,
             (11, 13, 11),
@@ -17,22 +17,22 @@ class CNN_Image_Section(nn.Module):
             droprate=droprate,
             pool=False,
         )
-        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
+        self.cnv2 = ly.CNVBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
 
         # Midflow Block
         self.midflow = ly.MidFlowBlock(384, droprate)
 
         # Split Convolutional Block
-        self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
+        self.splitcnv = ly.SplitCNVBlock(384, 192, 96, 1, droprate)
 
         # Fully Connected Block
         self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)
 
     def forward(self, x: Float[torch.Tensor, "N C D H W"]):
-        x = self.conv1(x)
-        x = self.conv2(x)
+        x = self.cnv1(x)
+        x = self.cnv2(x)
         x = self.midflow(x)
-        x = self.splitconv(x)
+        x = self.splitcnv(x)
         x = torch.flatten(x, 1)
         x = self.fc_image(x)
 

+ 23 - 23
model/layers.py

@@ -4,7 +4,7 @@ import torch
 from typing import Tuple
 
 
-class SepConv3d(nn.Module):
+class SepCNV3d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -14,7 +14,7 @@ class SepConv3d(nn.Module):
         padding: int | str = 0,
         bias: bool = False,
     ):
-        super(SepConv3d, self).__init__()
+        super(SepCNV3d, self).__init__()
         self.depthwise = nn.Conv3d(
             in_channels,
             out_channels,
@@ -30,7 +30,7 @@ class SepConv3d(nn.Module):
         return x
 
 
-class SplitConvBlock(nn.Module):
+class SplitCNVBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -39,29 +39,29 @@ class SplitConvBlock(nn.Module):
         split_dim: int,
         drop_rate: float,
     ):
-        super(SplitConvBlock, self).__init__()
+        super(SplitCNVBlock, self).__init__()
 
         self.split_dim = split_dim
 
-        self.leftconv_1 = SepConvBlock(
+        self.leftcnv_1 = SepCNVBlock(
             in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
         )
-        self.rightconv_1 = SepConvBlock(
+        self.rightcnv_1 = SepCNVBlock(
             in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
         )
 
-        self.leftconv_2 = SepConvBlock(
+        self.leftcnv_2 = SepCNVBlock(
             mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
         )
-        self.rightconv_2 = SepConvBlock(
+        self.rightcnv_2 = SepCNVBlock(
             mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
         )
 
     def forward(self, x: Float[torch.Tensor, "N C D H W"]):
         (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
 
-        self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
-        self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
+        self.leftblock = nn.Sequential(self.leftcnv_1, self.leftcnv_2)
+        self.rightblock = nn.Sequential(self.rightcnv_1, self.rightcnv_2)
 
         left = self.leftblock(left)
         right = self.rightblock(right)
@@ -73,25 +73,25 @@ class MidFlowBlock(nn.Module):
     def __init__(self, channels: int, drop_rate: float):
         super(MidFlowBlock, self).__init__()
 
-        self.conv1 = ConvBlock(
+        self.cnv1 = CNVBlock(
             channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
-        self.conv2 = ConvBlock(
+        self.cnv2 = CNVBlock(
             channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
-        self.conv3 = ConvBlock(
+        self.cnv3 = CNVBlock(
             channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
 
-        # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
-        self.block = self.conv1
+        # self.block = nn.Sequential(self.cnv1, self.cnv2, self.cnv3)
+        self.block = self.cnv1
 
     def forward(self, x: Float[torch.Tensor, "N C D H W"]):
         a = nn.ELU()(self.block(x) + x)
         return a
 
 
-class ConvBlock(nn.Module):
+class CNVBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -102,8 +102,8 @@ class ConvBlock(nn.Module):
         droprate: float = 0.0,
         pool: bool = False,
     ):
-        super(ConvBlock, self).__init__()
-        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
+        super(CNVBlock, self).__init__()
+        self.cnv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
         self.dropout = nn.Dropout(droprate)
@@ -114,7 +114,7 @@ class ConvBlock(nn.Module):
             self.maxpool = None
 
     def forward(self, x: Float[torch.Tensor, "N C D H W"]):
-        a = self.conv(x)
+        a = self.cnv(x)
         a = self.norm(a)
         a = self.elu(a)
 
@@ -142,7 +142,7 @@ class FullConnBlock(nn.Module):
         return x
 
 
-class SepConvBlock(nn.Module):
+class SepCNVBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -153,8 +153,8 @@ class SepConvBlock(nn.Module):
         droprate: float = 0.0,
         pool: bool = False,
     ):
-        super(SepConvBlock, self).__init__()
-        self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
+        super(SepCNVBlock, self).__init__()
+        self.cnv = SepCNV3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
         self.dropout = nn.Dropout(droprate)
@@ -165,7 +165,7 @@ class SepConvBlock(nn.Module):
             self.maxpool = None
 
     def forward(self, x: Float[torch.Tensor, "N C D H W"]):
-        x = self.conv(x)
+        x = self.cnv(x)
         x = self.norm(x)
         x = self.elu(x)
 

+ 1 - 0
requirements.txt

@@ -46,3 +46,4 @@ typing_extensions==4.12.2
 tzdata==2025.2
 wadler_lindig==0.1.7
 xarray==2025.9.0
+bayesian-torch

+ 195 - 0
run_overnight_training.py

@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import pathlib as pl
+import subprocess
+import sys
+import time
+from datetime import datetime, timezone
+
+
+def utc_now_iso() -> str:
+    return datetime.now(timezone.utc).isoformat()
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description=(
+            "Run ensemble training followed by bayesian training using existing scripts."
+        )
+    )
+    parser.add_argument(
+        "--python",
+        default=sys.executable,
+        help="Python executable to use for child scripts (default: current interpreter).",
+    )
+    parser.add_argument(
+        "--workdir",
+        default=str(pl.Path(__file__).resolve().parent),
+        help="Working directory for training scripts (default: this script directory).",
+    )
+    parser.add_argument(
+        "--log-dir",
+        default=None,
+        help="Directory for log and summary files (default: <workdir>/logs).",
+    )
+    parser.add_argument(
+        "--continue-on-error",
+        action="store_true",
+        help="Continue to the next stage even if a stage fails.",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Print resolved commands and paths without running training.",
+    )
+    return parser.parse_args()
+
+
+def run_stage(
+    stage_name: str,
+    command: list[str],
+    workdir: pl.Path,
+    log_file: pl.Path,
+) -> dict[str, object]:
+    started = utc_now_iso()
+    start_time = time.monotonic()
+
+    log_file.parent.mkdir(parents=True, exist_ok=True)
+    with open(log_file, "w", encoding="utf-8") as log:
+        log.write(f"[{started}] Starting stage: {stage_name}\n")
+        log.write(f"Command: {' '.join(command)}\n")
+        log.write(f"Working directory: {workdir}\n\n")
+
+        process = subprocess.Popen(
+            command,
+            cwd=str(workdir),
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            bufsize=1,
+            universal_newlines=True,
+            env=os.environ.copy(),
+        )
+
+        if process.stdout is not None:
+            for line in process.stdout:
+                print(line, end="")
+                log.write(line)
+
+        return_code = process.wait()
+
+        finished = utc_now_iso()
+        duration_seconds = time.monotonic() - start_time
+        log.write(
+            (
+                "\n"
+                f"[{finished}] Finished stage: {stage_name}\n"
+                f"Exit code: {return_code}\n"
+                f"Duration seconds: {duration_seconds:.2f}\n"
+            )
+        )
+
+    return {
+        "stage": stage_name,
+        "command": command,
+        "started_at_utc": started,
+        "finished_at_utc": finished,
+        "duration_seconds": duration_seconds,
+        "exit_code": return_code,
+        "status": "success" if return_code == 0 else "failed",
+        "log_file": str(log_file),
+    }
+
+
+def main() -> int:
+    args = parse_args()
+
+    workdir = pl.Path(args.workdir).resolve()
+    log_dir = (
+        pl.Path(args.log_dir).resolve()
+        if args.log_dir is not None
+        else workdir / "logs"
+    )
+    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    run_log_dir = log_dir / f"overnight_{run_stamp}"
+
+    ensemble_script = workdir / "train_ensemble.py"
+    bayesian_script = workdir / "train_bayesian.py"
+
+    for script in (ensemble_script, bayesian_script):
+        if not script.exists():
+            print(f"Error: required script not found: {script}")
+            return 2
+
+    stages = [
+        ("ensemble", [args.python, str(ensemble_script)]),
+        ("bayesian", [args.python, str(bayesian_script)]),
+    ]
+
+    if args.dry_run:
+        print("Dry run: no training scripts will be executed.")
+        print(f"Working directory: {workdir}")
+        print(f"Run log directory: {run_log_dir}")
+        for stage_name, command in stages:
+            print(f"Stage {stage_name}: {' '.join(command)}")
+        return 0
+
+    run_log_dir.mkdir(parents=True, exist_ok=True)
+    summary_path = run_log_dir / "run_summary.json"
+
+    run_started = utc_now_iso()
+    run_start_time = time.monotonic()
+    stage_results: list[dict[str, object]] = []
+
+    final_exit_code = 0
+    for stage_name, command in stages:
+        print(f"\n=== Starting {stage_name} training ===")
+        log_file = run_log_dir / f"{stage_name}.log"
+        result = run_stage(stage_name, command, workdir, log_file)
+        stage_results.append(result)
+
+        if int(result["exit_code"]) != 0 and not args.continue_on_error:
+            final_exit_code = int(result["exit_code"])
+            print(
+                (
+                    f"Stage '{stage_name}' failed with exit code {result['exit_code']}. "
+                    "Stopping because --continue-on-error was not set."
+                )
+            )
+            break
+
+    if final_exit_code == 0:
+        failed = [r for r in stage_results if int(r["exit_code"]) != 0]
+        if failed:
+            final_exit_code = int(failed[-1]["exit_code"])
+
+    run_finished = utc_now_iso()
+    total_duration = time.monotonic() - run_start_time
+
+    summary = {
+        "run_started_at_utc": run_started,
+        "run_finished_at_utc": run_finished,
+        "total_duration_seconds": total_duration,
+        "workdir": str(workdir),
+        "python_executable": args.python,
+        "continue_on_error": args.continue_on_error,
+        "final_exit_code": final_exit_code,
+        "overall_status": "success" if final_exit_code == 0 else "failed",
+        "stages": stage_results,
+    }
+
+    with open(summary_path, "w", encoding="utf-8") as f:
+        json.dump(summary, f, indent=2)
+
+    print("\n=== Overnight run complete ===")
+    print(f"Summary: {summary_path}")
+    print(f"Logs directory: {run_log_dir}")
+
+    return final_exit_code
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())

+ 214 - 0
train_bayesian.py

@@ -0,0 +1,214 @@
+# Torch
+import torch.nn as nn
+import torch
+import torch.optim as optim
+
+# Config
+from utils.config import config
+import pathlib as pl
+import pandas as pd
+import json
+import sqlite3 as sql
+from typing import Callable, cast
+import os
+
+# Bayesian torch
+from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss  # type: ignore[import-untyped]
+
+# Custom modules
+from model.cnn import CNN3D
+from utils.training import train_model_bayesian, test_model_bayesian
+from data.dataset import (
+    load_adni_data_from_file,
+    divide_dataset_by_patient_id,
+    initalize_dataloaders,
+)
+
+# Load data
+mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
+xls_file = pl.Path(config["data"]["xls_file_path"])
+
+# If current directory is MedPhys_Research, change to alnn_rewrite for relative imports to work
+try:
+    os.chdir("alnn_rewrite")
+except FileNotFoundError:
+    pass
+
+
+def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocess the Excel DataFrame.
+    This function can be customized to filter or modify the DataFrame as needed.
+    """
+
+    data = df[["Image Data ID", "Sex", "Age (current)"]]
+    data["Sex"] = data["Sex"].str.strip()  # type: ignore
+    data = data.replace({"M": 0, "F": 1})  # type: ignore
+    data.set_index("Image Data ID")  # type: ignore
+
+    return data
+
+
+dataset = load_adni_data_from_file(
+    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
+)
+
+if config["data"]["seed"] is None:
+    print("Warning: No seed provided for dataset division, using default seed 0")
+    config["data"]["seed"] = 0
+
+ptid_df = pd.read_csv(xls_file)
+ptid_df.columns = ptid_df.columns.str.strip()
+
+ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(  # type: ignore
+    subset=["Image Data ID", "PTID"]
+)
+ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
+ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
+ptid_df = ptid_df[ptid_df["PTID"] != ""]
+
+ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
+
+# Split is grouped by PTID to prevent patient-level leakage across partitions.
+datasets = divide_dataset_by_patient_id(
+    dataset,
+    ptids,
+    config["data"]["data_splits"],
+    seed=config["data"]["seed"],
+)
+
+# Initialize the dataloaders
+train_loader, val_loader, test_loader = initalize_dataloaders(
+    datasets, batch_size=config["training"]["batch_size"]
+)
+
+bayesian_output_path = pl.Path(config["output"]["bayesian_path"])
+
+# Save seed to output config file
+output_config_path = bayesian_output_path / "config.json"
+if not output_config_path.parent.exists():
+    output_config_path.parent.mkdir(parents=True, exist_ok=True)
+
+with open(output_config_path, "w") as f:
+    json.dump(config, f, indent=4)
+print(f"Configuration saved to {output_config_path}")
+
+# Initialize the model
+model = (
+    CNN3D(
+        image_channels=config["data"]["image_channels"],
+        clin_data_channels=config["data"]["clin_data_channels"],
+        num_classes=config["data"]["num_classes"],
+        droprate=config["training"]["droprate"],
+    )
+    .float()
+    .to(config["training"]["device"])
+)
+
+const_bnn_prior_parameters: dict[str, float | bool | str] = {
+    "prior_mu": 0.0,
+    "prior_sigma": 1.0,
+    "posterior_mu_init": 0.0,
+    "posterior_rho_init": -3.0,
+    "type": "Reparameterization",
+    "moped_enable": False,
+    "moped_delta": 0.5,
+}
+
+
+dnn_to_bnn(model, const_bnn_prior_parameters)
+
+model.to(config["training"]["device"])
+
+# Set up intermediate model directory
+intermediate_model_dir = bayesian_output_path / "intermediate_models"
+if not intermediate_model_dir.exists():
+    intermediate_model_dir.mkdir(parents=True, exist_ok=True)
+print(f"Intermediate models will be saved to {intermediate_model_dir}")
+
+# Set up the optimizer and loss function
+optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
+criterion = nn.BCELoss()
+
+# Train model
+model, history = train_model_bayesian(
+    model=model,
+    train_loader=train_loader,
+    val_loader=val_loader,
+    optimizer=optimizer,
+    criterion=criterion,
+    num_epochs=config["training"]["num_epochs"],
+    output_path=bayesian_output_path,
+    get_kl_loss=get_kl_loss,
+)
+
+# Test model
+test_loss, test_acc = test_model_bayesian(
+    model=model,
+    test_loader=test_loader,
+    criterion=criterion,
+    get_kl_loss=get_kl_loss,
+)
+
+print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
+
+# Save the model
+model_save_path = bayesian_output_path / "model_bayesian.pt"
+torch.save(model.state_dict(), model_save_path)
+print(f"Model saved to {model_save_path}")
+
+# Save test results and history by appending to the sql database
+results_save_path = bayesian_output_path / "results.sqlite"
+with sql.connect(results_save_path) as conn:
+    conn.execute(
+        """
+        CREATE TABLE IF NOT EXISTS results (
+            run INTEGER PRIMARY KEY,
+            test_loss REAL,
+            test_accuracy REAL
+        )
+        """
+    )
+
+    conn.execute(
+        """
+        INSERT OR REPLACE INTO results (run, test_loss, test_accuracy)
+        VALUES (?, ?, ?)
+        """,
+        (1, test_loss, test_acc),
+    )
+
+    conn.execute(
+        """
+        CREATE TABLE IF NOT EXISTS history_run_1 (
+            epoch INTEGER PRIMARY KEY,
+            train_loss REAL,
+            val_loss REAL,
+            train_acc REAL,
+            val_acc REAL
+        )
+        """
+    )
+
+    conn.execute("DELETE FROM history_run_1")
+    for epoch, row in history.iterrows():
+        values = (
+            epoch,
+            float(row["train_loss"]),
+            float(row["val_loss"]),
+            float(row["train_acc"]),
+            float(row["val_acc"]),
+        )
+
+        conn.execute(
+            """
+            INSERT INTO history_run_1 (epoch, train_loss, val_loss, train_acc, val_acc)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            values,
+        )
+
+    conn.commit()
+
+print(f"Results and history saved to {results_save_path}")
+print(f"Bayesian training completed. Model and results saved to {bayesian_output_path}")

+ 30 - 10
train_model.py → train_ensemble.py

@@ -17,7 +17,7 @@ from model.cnn import CNN3D
 from utils.training import train_model, test_model
 from data.dataset import (
     load_adni_data_from_file,
-    divide_dataset,
+    divide_dataset_by_patient_id,
     initalize_dataloaders,
 )
 
@@ -47,13 +47,29 @@ dataset = load_adni_data_from_file(
     mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
 )
 
-# Divide the dataset into training and validation sets
+# Divide the dataset into training/validation/test sets
 if config["data"]["seed"] is None:
     print("Warning: No seed provided for dataset division, using default seed 0")
     config["data"]["seed"] = 0
 
-datasets = divide_dataset(
-    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
+ptid_df = pd.read_csv(xls_file)
+ptid_df.columns = ptid_df.columns.str.strip()
+
+ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(  # type: ignore
+    subset=["Image Data ID", "PTID"]
+)
+ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
+ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
+ptid_df = ptid_df[ptid_df["PTID"] != ""]
+
+ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
+
+# Split is grouped by PTID to prevent patient-level leakage across partitions.
+datasets = divide_dataset_by_patient_id(
+    dataset,
+    ptids,
+    config["data"]["data_splits"],
+    seed=config["data"]["seed"],
 )
 
 
@@ -62,8 +78,10 @@ train_loader, val_loader, test_loader = initalize_dataloaders(
     datasets, batch_size=config["training"]["batch_size"]
 )
 
+ensemble_output_path = pl.Path(config["output"]["ensemble_path"])
+
 # Save seed to output config file
-output_config_path = pl.Path(config["output"]["path"]) / "config.json"
+output_config_path = ensemble_output_path / "config.json"
 if not output_config_path.parent.exists():
     output_config_path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -91,7 +109,7 @@ for run_num in range(config["training"]["ensemble_size"]):
     )
 
     # Set up intermediate model directory
-    intermediate_model_dir = pl.Path(config["output"]["path"]) / "intermediate_models"
+    intermediate_model_dir = ensemble_output_path / "intermediate_models"
     if not intermediate_model_dir.exists():
         intermediate_model_dir.mkdir(parents=True, exist_ok=True)
     print(f"Intermediate models will be saved to {intermediate_model_dir}")
@@ -108,7 +126,7 @@ for run_num in range(config["training"]["ensemble_size"]):
         optimizer=optimizer,
         criterion=criterion,
         num_epochs=config["training"]["num_epochs"],
-        output_path=pl.Path(config["output"]["path"]),
+        output_path=ensemble_output_path,
     )
 
     # Test model
@@ -124,12 +142,12 @@ for run_num in range(config["training"]["ensemble_size"]):
     )
 
     # Save the model
-    model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
+    model_save_path = ensemble_output_path / f"model_run_{run_num + 1}.pt"
     torch.save(model.state_dict(), model_save_path)
     print(f"Model saved to {model_save_path}")
 
     # Save test results and history by appending to the sql database
-    results_save_path = pl.Path(config["output"]["path"]) / f"results.sqlite"
+    results_save_path = ensemble_output_path / f"results.sqlite"
     with sql.connect(results_save_path) as conn:
         # Create results table if it doesn't exist
         conn.execute(
@@ -185,4 +203,6 @@ for run_num in range(config["training"]["ensemble_size"]):
     print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")
 
 # Completion message
-print(f"All runs completed. Models and results saved to {config['output']['path']}")
+print(
+    f"All runs completed. Models and results saved to {config['output']['ensemble_path']}"
+)

+ 164 - 3
utils/training.py

@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-import xarray as xr
 from data.dataset import ADNIDataset
-from typing import Tuple
+from typing import Callable, Tuple, cast
 from tqdm import tqdm
 import numpy as np
 import pathlib as pl
@@ -14,6 +13,7 @@ type TrainMetrics = Tuple[
 ]  # (train_loss, val_loss, train_acc, val_acc)
 
 type TestMetrics = Tuple[float, float]  # (test_loss, test_acc)
+type KLLossFn = Callable[[nn.Module], torch.Tensor | None]
 
 
 def test_model(
@@ -85,7 +85,7 @@ def train_epoch(
         optimizer.zero_grad()
         outputs = model((mri, xls))
         loss = criterion(outputs, targets)
-        loss.backward()
+        loss.backward()  # type: ignore[reportUnknownMemberType]
         optimizer.step()
         train_loss += loss.item() * (mri.size(0) + xls.size(0))
     train_loss /= len(train_loader)
@@ -185,3 +185,164 @@ def train_model(
     )
 
     return model, history
+
+
+def test_model_bayesian(
+    model: nn.Module,
+    test_loader: DataLoader[ADNIDataset],
+    criterion: nn.Module,
+    get_kl_loss: KLLossFn,
+) -> TestMetrics:
+    """
+    Tests a Bayesian model on the test dataset with KL-augmented loss.
+    """
+    model.eval()
+    test_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for _, (mri, xls, targets, _) in tqdm(
+            enumerate(test_loader), desc="Testing", total=len(test_loader), unit="batch"
+        ):
+            outputs = model((mri, xls))
+            data_loss = cast(torch.Tensor, criterion(outputs, targets))
+            batch_size = mri.size(0)
+            kl_term = get_kl_loss(model)
+            kl_loss = (
+                kl_term / batch_size
+                if kl_term is not None
+                else torch.tensor(0.0, device=outputs.device)
+            )
+            loss: torch.Tensor = data_loss + kl_loss
+
+            test_loss += loss.item() * (mri.size(0) + xls.size(0))
+
+            predicted = (outputs > 0.5).float()
+            correct += (predicted == targets).sum().item()
+            total += targets.numel()
+
+    test_loss /= len(test_loader)
+    test_acc = correct / total if total > 0 else 0.0
+    return test_loss, test_acc
+
+
+def train_epoch_bayesian(
+    model: nn.Module,
+    train_loader: DataLoader[ADNIDataset],
+    val_loader: DataLoader[ADNIDataset],
+    optimizer: torch.optim.Optimizer,
+    criterion: nn.Module,
+    get_kl_loss: KLLossFn,
+) -> TrainMetrics:
+    """
+    Trains a Bayesian model for one epoch and evaluates on validation data.
+    """
+    model.train()
+    train_loss = 0.0
+
+    for _, (mri, xls, targets, _) in tqdm(
+        enumerate(train_loader), desc="Training", total=len(train_loader), unit="batch"
+    ):
+        optimizer.zero_grad()
+        outputs = model((mri, xls))
+        data_loss = cast(torch.Tensor, criterion(outputs, targets))
+        batch_size = mri.size(0)
+        kl_term = get_kl_loss(model)
+        kl_loss = (
+            kl_term / batch_size
+            if kl_term is not None
+            else torch.tensor(0.0, device=outputs.device)
+        )
+        loss: torch.Tensor = data_loss + kl_loss
+        loss.backward()  # type: ignore[reportUnknownMemberType]
+        optimizer.step()
+        train_loss += loss.item() * (mri.size(0) + xls.size(0))
+    train_loss /= len(train_loader)
+
+    model.eval()
+    val_loss = 0.0
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for _, (mri, xls, targets, _) in tqdm(
+            enumerate(val_loader),
+            desc="Validation",
+            total=len(val_loader),
+            unit="batch",
+        ):
+            outputs = model((mri, xls))
+            data_loss = cast(torch.Tensor, criterion(outputs, targets))
+            batch_size = mri.size(0)
+            kl_term = get_kl_loss(model)
+            kl_loss = (
+                kl_term / batch_size
+                if kl_term is not None
+                else torch.tensor(0.0, device=outputs.device)
+            )
+            loss: torch.Tensor = data_loss + kl_loss
+            val_loss += loss.item() * (mri.size(0) + xls.size(0))
+
+            predicted = (outputs > 0.5).float()
+            correct += (predicted == targets).sum().item()
+            total += targets.numel()
+
+    val_loss /= len(val_loader)
+    val_acc = correct / total if total > 0 else 0.0
+    train_acc = correct / total if total > 0 else 0.0
+
+    return train_loss, val_loss, train_acc, val_acc
+
+
+def train_model_bayesian(
+    model: nn.Module,
+    train_loader: DataLoader[ADNIDataset],
+    val_loader: DataLoader[ADNIDataset],
+    optimizer: torch.optim.Optimizer,
+    criterion: nn.Module,
+    num_epochs: int,
+    output_path: pl.Path,
+    get_kl_loss: KLLossFn,
+) -> Tuple[nn.Module, pd.DataFrame]:
+    """
+    Trains a Bayesian model with KL-augmented objective.
+    """
+    nhist = np.zeros((num_epochs, 4), dtype=np.float32)
+
+    for epoch in range(num_epochs):
+        train_loss, val_loss, train_acc, val_acc = train_epoch_bayesian(
+            model,
+            train_loader,
+            val_loader,
+            optimizer,
+            criterion,
+            get_kl_loss,
+        )
+
+        nhist[epoch, 0] = train_loss
+        nhist[epoch, 1] = val_loss
+        nhist[epoch, 2] = train_acc
+        nhist[epoch, 3] = val_acc
+
+        print(
+            f"Epoch [{epoch + 1}/{num_epochs}], "
+            f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
+            f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}"
+        )
+
+        if num_epochs > 4:
+            if (epoch + 1) % (num_epochs // 4) == 0:
+                model_save_path = (
+                    output_path / "intermediate_models" / f"model_epoch_{epoch + 1}.pt"
+                )
+                torch.save(model.state_dict(), model_save_path)
+                print(f"Model saved at epoch {epoch + 1}")
+
+    history = pd.DataFrame(
+        data=nhist.astype(np.float32),
+        columns=["train_loss", "val_loss", "train_acc", "val_acc"],
+        index=np.arange(1, num_epochs + 1),
+    )
+
+    return model, history