|
|
@@ -0,0 +1,214 @@
|
|
|
+# Torch
|
|
|
+import torch.nn as nn
|
|
|
+import torch
|
|
|
+import torch.optim as optim
|
|
|
+
|
|
|
+# Config
|
|
|
+from utils.config import config
|
|
|
+import pathlib as pl
|
|
|
+import pandas as pd
|
|
|
+import json
|
|
|
+import sqlite3 as sql
|
|
|
+from typing import Callable, cast
|
|
|
+import os
|
|
|
+
|
|
|
+# Bayesian torch
|
|
|
+from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss # type: ignore[import-untyped]
|
|
|
+
|
|
|
+# Custom modules
|
|
|
+from model.cnn import CNN3D
|
|
|
+from utils.training import train_model_bayesian, test_model_bayesian
|
|
|
+from data.dataset import (
|
|
|
+ load_adni_data_from_file,
|
|
|
+ divide_dataset_by_patient_id,
|
|
|
+ initalize_dataloaders,
|
|
|
+)
|
|
|
+
|
|
|
# If current directory is MedPhys_Research, change to alnn_rewrite for relative imports to work
# (and, more importantly here, so relative data paths in config resolve correctly).
# A FileNotFoundError just means we are already in the right directory.
try:
    os.chdir("alnn_rewrite")
except FileNotFoundError:
    pass

# Load data.
# BUG FIX: the paths were previously constructed BEFORE the chdir above; that
# only worked because Path.glob is lazy (it resolves when iterated, after the
# chdir). Constructing them after the chdir makes the intent explicit.
# sorted() pins a deterministic file order across platforms/filesystems so the
# seeded dataset split is reproducible.
mri_files = sorted(pl.Path(config["data"]["mri_files_path"]).glob("*.nii"))
xls_file = pl.Path(config["data"]["xls_file_path"])
|
|
|
+
|
|
|
+
|
|
|
def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the clinical spreadsheet DataFrame for the ADNI data loader.

    Keeps only the image ID, sex, and age columns, normalizes the sex
    column to 0 (M) / 1 (F), and indexes the result by "Image Data ID".

    Args:
        df: Raw DataFrame read from the ADNI spreadsheet. Must contain the
            columns "Image Data ID", "Sex", and "Age (current)".

    Returns:
        A new DataFrame indexed by "Image Data ID" with the columns
        "Sex" (0/1 int-coded) and "Age (current)".
    """
    # .copy() so we mutate an independent frame, not a view of df
    # (avoids SettingWithCopyWarning; the old `# type: ignore`s hid this).
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()
    # Scope the replace to the Sex column; "M"/"F" only occur there.
    data["Sex"] = data["Sex"].replace({"M": 0, "F": 1})
    # BUG FIX: set_index returns a new DataFrame; the original call discarded
    # its result, so the index was never actually set.
    data = data.set_index("Image Data ID")

    return data
|
|
|
+
|
|
|
+
|
|
|
# Build the dataset from the MRI volumes plus the preprocessed spreadsheet.
dataset = load_adni_data_from_file(
    mri_files,
    xls_file,
    device=config["training"]["device"],
    xls_preprocessor=xls_pre,
)

# Fall back to a fixed seed so the dataset division stays reproducible.
seed = config["data"]["seed"]
if seed is None:
    print("Warning: No seed provided for dataset division, using default seed 0")
    config["data"]["seed"] = 0
|
|
|
+
|
|
|
# Build (Image Data ID, PTID) pairs from the spreadsheet for the grouped split.
ptid_df = pd.read_csv(xls_file)
ptid_df = ptid_df.rename(columns=str.strip)

ptid_df = (
    ptid_df.loc[:, ["Image Data ID", "PTID"]]  # type: ignore
    .dropna(subset=["Image Data ID", "PTID"])
    .assign(**{
        "Image Data ID": lambda d: d["Image Data ID"].astype(int),
        "PTID": lambda d: d["PTID"].astype(str).str.strip(),
    })
)
ptid_df = ptid_df.loc[ptid_df["PTID"] != ""]

# .tolist() converts to native Python scalars before pairing.
ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
|
|
|
+
|
|
|
# Split is grouped by PTID to prevent patient-level leakage across partitions.
datasets = divide_dataset_by_patient_id(
    dataset,
    ptids,
    config["data"]["data_splits"],
    seed=config["data"]["seed"],
)

# Initialize the dataloaders
batch_size = config["training"]["batch_size"]
loaders = initalize_dataloaders(datasets, batch_size=batch_size)
train_loader, val_loader, test_loader = loaders
|
|
|
+
|
|
|
bayesian_output_path = pl.Path(config["output"]["bayesian_path"])

# Save the full config (including the resolved seed) so the run is reproducible.
output_config_path = bayesian_output_path / "config.json"
# mkdir(parents=True, exist_ok=True) is idempotent; the previous exists()
# check was redundant and race-prone (TOCTOU).
output_config_path.parent.mkdir(parents=True, exist_ok=True)

with open(output_config_path, "w") as f:
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")
|
|
|
+
|
|
|
# Initialize the deterministic CNN backbone first, then convert it to a BNN.
model = CNN3D(
    image_channels=config["data"]["image_channels"],
    clin_data_channels=config["data"]["clin_data_channels"],
    num_classes=config["data"]["num_classes"],
    droprate=config["training"]["droprate"],
)
model = model.float()
model = model.to(config["training"]["device"])

# Prior / posterior settings handed to bayesian_torch's dnn_to_bnn conversion.
const_bnn_prior_parameters: dict[str, float | bool | str] = {
    "prior_mu": 0.0,
    "prior_sigma": 1.0,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",
    "moped_enable": False,
    "moped_delta": 0.5,
}

# Converts the model's layers in place.
dnn_to_bnn(model, const_bnn_prior_parameters)

# Re-apply device placement after the conversion swapped layers out.
model.to(config["training"]["device"])
|
|
|
+
|
|
|
# Set up intermediate model directory.
intermediate_model_dir = bayesian_output_path / "intermediate_models"
# mkdir(parents=True, exist_ok=True) already tolerates an existing directory;
# the previous exists() guard was redundant.
intermediate_model_dir.mkdir(parents=True, exist_ok=True)
print(f"Intermediate models will be saved to {intermediate_model_dir}")

# Set up the optimizer and loss function.
optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
# NOTE(review): BCELoss expects probabilities in [0, 1] — confirm CNN3D's
# output head applies a sigmoid (otherwise BCEWithLogitsLoss is needed).
criterion = nn.BCELoss()
|
|
|
+
|
|
|
# Train model
training_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=config["training"]["num_epochs"],
    output_path=bayesian_output_path,
    get_kl_loss=get_kl_loss,
)
model, history = train_model_bayesian(**training_kwargs)

# Test model
test_loss, test_acc = test_model_bayesian(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    get_kl_loss=get_kl_loss,
)

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# Save the model
model_save_path = bayesian_output_path / "model_bayesian.pt"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")
|
|
|
+
|
|
|
# Save test results and history by appending to the sql database.
# BUG FIX: sqlite3's `with conn:` only manages the TRANSACTION (commit on
# success, rollback on error) — it does NOT close the connection, so the
# original leaked the connection. Close it explicitly in a finally block.
# The explicit conn.commit() was also redundant inside the with-block.
results_save_path = bayesian_output_path / "results.sqlite"
conn = sql.connect(results_save_path)
try:
    with conn:  # single transaction for all writes below
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                run INTEGER PRIMARY KEY,
                test_loss REAL,
                test_accuracy REAL
            )
            """
        )

        conn.execute(
            """
            INSERT OR REPLACE INTO results (run, test_loss, test_accuracy)
            VALUES (?, ?, ?)
            """,
            (1, test_loss, test_acc),
        )

        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS history_run_1 (
                epoch INTEGER PRIMARY KEY,
                train_loss REAL,
                val_loss REAL,
                train_acc REAL,
                val_acc REAL
            )
            """
        )

        # Drop any history from a previous run before re-inserting.
        conn.execute("DELETE FROM history_run_1")
        history_rows = [
            (
                epoch,
                float(row["train_loss"]),
                float(row["val_loss"]),
                float(row["train_acc"]),
                float(row["val_acc"]),
            )
            for epoch, row in history.iterrows()
        ]
        # executemany batches the inserts in one call instead of a per-row loop.
        conn.executemany(
            """
            INSERT INTO history_run_1 (epoch, train_loss, val_loss, train_acc, val_acc)
            VALUES (?, ?, ?, ?, ?)
            """,
            history_rows,
        )
finally:
    conn.close()

print(f"Results and history saved to {results_save_path}")
print(f"Bayesian training completed. Model and results saved to {bayesian_output_path}")
|