
Finished model training and evaluation

Nicholas Schense committed 4 days ago
Commit db8827e3fe
6 changed files with 216 additions and 45 deletions
  1. config.toml (+8 -8)
  2. data/dataset.py (+5 -1)
  3. evaluate_models.py (+128 -0)
  4. generate_statistics.py (+2 -0)
  5. train_model.py (+61 -26)
  6. utils/training.py (+12 -10)

+ 8 - 8
config.toml

@@ -1,7 +1,7 @@
 [data]
-mri_files_path = "/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/"
+mri_files_path = "../data/PET_volumes_customtemplate_float32/"
 xls_file_path = "LP_ADNIMERGE.csv"
-seed = 1
+seed = 42
 data_splits = [0.7, 0.2, 0.1] # train, validation, test
 image_channels = 1
 clin_data_channels = 2
@@ -9,12 +9,12 @@ num_classes = 2 # AD, NL
 
 
 [training]
-device = "cuda:1" # "cpu", "cuda", "mps"
+device = "cuda:0" # "cpu", "cuda", "mps"
 batch_size = 32
-ensemble_size = 5
-droprate = 0.1
-learning_rate = 0.001
-num_epochs = 10
+ensemble_size = 50
+droprate = 0.05
+learning_rate = 0.0001
+num_epochs = 30
 
 [output]
-path = "../models/5x10/"
+path = "../models/Full_Ensemble(50x30)/"

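utils/config.py itself is not part of this commit, so how config.toml reaches the scripts is not shown. Below is a minimal hypothetical sketch of such a loader, assuming Python 3.11+ (for the stdlib tomllib) and a config.toml at the repository root; neither assumption is confirmed by the diff.

# utils/config.py -- hypothetical sketch, not the module in this repository
import pathlib as pl
import tomllib  # stdlib TOML parser, Python 3.11+

# Load config.toml from the repository root into a plain dict so callers
# can do config["training"]["batch_size"], config["output"]["path"], etc.
_config_path = pl.Path(__file__).resolve().parent.parent / "config.toml"
with _config_path.open("rb") as f:
    config = tomllib.load(f)
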
+ 5 - 1
data/dataset.py

@@ -104,8 +104,12 @@ def load_adni_data_from_file(
 
         if "AD" in filename:
             file_expected_class = torch.tensor([1.0, 0.0])
-        elif "CN" in filename:
+        elif "NL" in filename:
             file_expected_class = torch.tensor([0.0, 1.0])
+        else:
+            raise ValueError(
+                f"Filename {filename} does not contain a valid class identifier (AD or NL)."
+            )
 
         mri_data_unstacked.append(file_mri_data)
         expected_classes_unstacked.append(file_expected_class)

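The new else branch makes the label mapping strict, matching the AD/NL naming noted in config.toml (num_classes = 2 # AD, NL): previously, a filename containing neither tag would silently reuse file_expected_class from the preceding loop iteration, or raise a NameError on the first file. Restated as a standalone helper (label_from_filename is a hypothetical name; the logic mirrors the hunk above):

import torch

def label_from_filename(filename: str) -> torch.Tensor:
    """Map a scan filename to a one-hot label: AD -> [1, 0], NL -> [0, 1]."""
    if "AD" in filename:
        return torch.tensor([1.0, 0.0])
    if "NL" in filename:
        return torch.tensor([0.0, 1.0])
    raise ValueError(
        f"Filename {filename} does not contain a valid class identifier (AD or NL)."
    )
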
+ 128 - 0
evaluate_models.py

@@ -0,0 +1,128 @@
+# This program evaluates every model on the combined validation and test set, then saves the results to a netcdf file.
+
+import torch
+import xarray as xr
+from torch.utils.data import DataLoader
+import numpy as np
+
+
+# Config
+from model.cnn import CNN3D
+from utils.config import config
+import pathlib as pl
+import pandas as pd
+import json
+
+
+# Custom modules
+from data.dataset import (
+    load_adni_data_from_file,
+    divide_dataset,
+    initalize_dataloaders,
+    ADNIDataset,
+)
+
+mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
+xls_file = pl.Path(config["data"]["xls_file_path"])
+
+
+def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
+    """
+    Preprocess the Excel DataFrame.
+    This function can be customized to filter or modify the DataFrame as needed.
+    """
+
+    data = df[["Image Data ID", "Sex", "Age (current)"]]
+    data["Sex"] = data["Sex"].str.strip()  # type: ignore
+    data = data.replace({"M": 0, "F": 1})  # type: ignore
+    data = data.set_index("Image Data ID")  # type: ignore  # assign the result: set_index is not in-place
+
+    return data
+
+
+dataset = load_adni_data_from_file(
+    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
+)
+
+# Divide the dataset into training, validation, and test sets, using the same seed as in training
+with open(pl.Path(config["output"]["path"]) / "config.json") as f:
+    training_config = json.load(f)
+    try:
+        loaded_seed = int(training_config["data"]["seed"])
+    except (ValueError, KeyError) as e:
+        print(
+            f"Warning: No previous seed found for dataset division, using seed from config. Error: {e}"
+        )
+        loaded_seed = config["data"]["seed"]
+
+
+datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed)
+
+
+# Initialize the dataloaders
+train_loader, val_loader, test_loader = initalize_dataloaders(
+    datasets, batch_size=config["training"]["batch_size"]
+)
+
+
+# Combine validation and test sets for final evaluation
+combined_loader: DataLoader[ADNIDataset] = torch.utils.data.DataLoader(
+    torch.utils.data.ConcatDataset([val_loader.dataset, test_loader.dataset]),
+    batch_size=1,
+    shuffle=False,
+)
+
+
+# 50 models are too large to load into memory at once, so we will load and evaluate them one at a time
+model_dir = pl.Path(config["output"]["path"])
+model_files = sorted(model_dir.glob("model_run_*.pt"))
+
+placeholder = np.zeros(
+    (len(model_files), len(combined_loader), config["data"]["num_classes"]),
+    dtype=np.float32,
+)  # Placeholder for results
+
+placeholder[:] = np.nan  # Fill with NaNs for easier identification of missing data
+dimensions = ["model", "batch", "img_class"]
+coords = {
+    "model": [int(mf.stem.split("_")[2]) for mf in model_files],
+    "batch": list(range(len(combined_loader))),
+    "img_class": list(range(config["data"]["num_classes"])),
+}
+
+results = xr.DataArray(placeholder, coords=coords, dims=dimensions)
+
+for model_file in model_files:
+    model_num = int(model_file.stem.split("_")[2])
+    print(f"Evaluating model {model_num}...")
+
+    # Load the model state
+    model = (
+        CNN3D(
+            image_channels=config["data"]["image_channels"],
+            clin_data_channels=config["data"]["clin_data_channels"],
+            num_classes=config["data"]["num_classes"],
+            droprate=config["training"]["droprate"],
+        )
+        .float()
+        .to(config["training"]["device"])
+    )
+
+    model.load_state_dict(
+        torch.load(model_file, map_location=config["training"]["device"]), strict=False
+    )
+    model.eval()
+
+    with torch.no_grad():
+        for batch_idx, (mri_batch, xls_batch, labels_batch) in enumerate(
+            combined_loader
+        ):
+            outputs = model((mri_batch.float(), xls_batch.float()))
+            probabilities = outputs.cpu().numpy()[0, :]  # type: ignore
+
+            results.loc[model_num, batch_idx, :] = probabilities  # type: ignore
+
+# Save results to netcdf file
+output_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
+results.to_netcdf(output_path, mode="w")  # type: ignore
+print(f"Results saved to {output_path}")

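Because the results array is pre-filled with NaN, gaps left by a crashed or partially evaluated model remain detectable after the fact. A minimal sanity-check sketch, assuming the repo's utils.config and the file written above; the check itself is illustrative, not part of the commit:

import pathlib as pl
import xarray as xr
from utils.config import config

# Re-open the saved probabilities and flag any slots that were never written.
results = xr.open_dataarray(
    pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
)
missing = results.isnull().any(dim="img_class")  # True where a slot is still NaN
if missing.any():
    bad_models = results["model"].values[missing.any(dim="batch").values]
    print(f"Models with missing evaluations: {bad_models}")
else:
    print("All models fully evaluated.")
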
+ 2 - 0
generate_statistics.py

@@ -0,0 +1,2 @@
+import xarray as xr
+from utils.config import config

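generate_statistics.py is only an import stub so far. Purely as an illustration of where it could go, here is a sketch that ensemble-averages the probabilities saved by evaluate_models.py; the mean/std aggregation is an assumption about intent, not something this commit implements.

import pathlib as pl
import xarray as xr
from utils.config import config

# Aggregate per-model probabilities across the 50-member ensemble.
results = xr.open_dataarray(
    pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
)
mean_probs = results.mean(dim="model")          # average prediction per sample
std_probs = results.std(dim="model")            # disagreement across the ensemble
predicted = mean_probs.argmax(dim="img_class")  # 0 = AD, 1 = NL (per dataset.py)
print(predicted.values)
print(f"max ensemble std: {float(std_probs.max()):.4f}")
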
+ 61 - 26
train_model.py

@@ -3,11 +3,13 @@ import torch.nn as nn
 import torch
 import torch.optim as optim
 
+
 # Config
 from utils.config import config
 import pathlib as pl
 import pandas as pd
 import json
+import sqlite3 as sql
 
 
 # Custom modules
@@ -88,6 +90,12 @@ for run_num in range(config["training"]["ensemble_size"]):
         .to(config["training"]["device"])
     )
 
+    # Set up intermediate model directory
+    intermediate_model_dir = pl.Path(config["output"]["path"]) / "intermediate_models"
+    if not intermediate_model_dir.exists():
+        intermediate_model_dir.mkdir(parents=True, exist_ok=True)
+    print(f"Intermediate models will be saved to {intermediate_model_dir}")
+
     # Set up the optimizer and loss function
     optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
     criterion = nn.BCELoss()
@@ -100,7 +108,7 @@ for run_num in range(config["training"]["ensemble_size"]):
         optimizer=optimizer,
         criterion=criterion,
         num_epochs=config["training"]["num_epochs"],
-        learning_rate=config["training"]["learning_rate"],
+        output_path=pl.Path(config["output"]["path"]),
     )
 
     # Test model
@@ -120,33 +128,60 @@ for run_num in range(config["training"]["ensemble_size"]):
     torch.save(model.state_dict(), model_save_path)
     print(f"Model saved to {model_save_path}")
 
-    # Save the training history
-    history_save_path = (
-        pl.Path(config["output"]["path"]) / f"history_run_{run_num + 1}.nc"
-    )
+    # Save test results and history by appending to the sql database
+    results_save_path = pl.Path(config["output"]["path"]) / "results.sqlite"
+    with sql.connect(results_save_path) as conn:
+        # Create results table if it doesn't exist
+        conn.execute(
+            """
+            CREATE TABLE IF NOT EXISTS results (
+                run INTEGER PRIMARY KEY,
+                test_loss REAL,
+                test_accuracy REAL
+            )
+            """
+        )
+        # Insert the results
+        conn.execute(
+            """
+            INSERT INTO results (run, test_loss, test_accuracy)
+            VALUES (?, ?, ?)
+            """,
+            (run_num + 1, test_loss, test_acc),
+        )
 
-    history.to_netcdf(history_save_path, mode="w")  # type: ignore
-    print(f"Training history saved to {history_save_path}")
-
-    # Save test results by appending to the results file
-    test_results_save_path = pl.Path(config["output"]["path"]) / f"results.json"
-    with open(test_results_save_path, "r+") as f:
-        try:
-            results = json.load(f)
-        except json.JSONDecodeError:
-            # If the file is empty or not a valid JSON, initialize an empty list
-            print("No previous results found, initializing results list.")
-            results = []
-
-        results.append(  # type: ignore
-            {
-                "run": run_num + 1,
-                "test_loss": test_loss,
-                "test_accuracy": test_acc,
-            }
+        # Create a new table for the run history
+        conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS history_run_{run_num + 1} (
+                epoch INTEGER PRIMARY KEY,
+                train_loss REAL,
+                val_loss REAL,
+                train_acc REAL,
+                val_acc REAL
+            )
+            """
         )
-        f.seek(0)
-        json.dump(results, f, indent=4)
+        # Insert the history
+        for epoch, row in history.iterrows():
+            values = (
+                epoch,
+                float(row["train_loss"]),
+                float(row["val_loss"]),
+                float(row["train_acc"]),
+                float(row["val_acc"]),
+            )
+
+            conn.execute(
+                f"""
+                INSERT INTO history_run_{run_num + 1} (epoch, train_loss, val_loss, train_acc, val_acc)
+                VALUES (?, ?, ?, ?, ?)
+                """,
+                values,  # type: ignore
+            )
+
+        conn.commit()
+    print(f"Results and history saved to {results_save_path}")
     print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")
 
 # Completion message

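The schema above makes the metrics straightforward to query back out, and unlike the old read-modify-write JSON file, each INSERT is transactional and the run PRIMARY KEY rejects accidental duplicate run numbers. A minimal read-back sketch using pandas; table and column names come from the CREATE statements above, and the bare results.sqlite path is shortened for illustration (the script writes it under the configured output path):

import sqlite3 as sql
import pandas as pd

# Read the per-run test metrics and one run's epoch-by-epoch history.
with sql.connect("results.sqlite") as conn:
    runs = pd.read_sql("SELECT run, test_loss, test_accuracy FROM results", conn)
    history_1 = pd.read_sql("SELECT * FROM history_run_1 ORDER BY epoch", conn)

print(runs.sort_values("test_accuracy", ascending=False).head())
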
+ 12 - 10
utils/training.py

@@ -6,6 +6,8 @@ from data.dataset import ADNIDataset
 from typing import Tuple
 from tqdm import tqdm
 import numpy as np
+import pathlib as pl
+import pandas as pd
 
 type TrainMetrics = Tuple[
     float, float, float, float
@@ -122,8 +124,8 @@ def train_model(
     optimizer: torch.optim.Optimizer,
     criterion: nn.Module,
     num_epochs: int,
-    learning_rate: float,
-) -> Tuple[nn.Module, xr.DataArray]:
+    output_path: pl.Path,
+) -> Tuple[nn.Module, pd.DataFrame]:
     """
     Trains the model using the provided training and validation data loaders.
 
@@ -168,18 +170,18 @@ def train_model(
         # If we are at 25, 50, or 75% of the epochs, save the model
         if num_epochs > 4:
             if (epoch + 1) % (num_epochs // 4) == 0:
-                torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
+                model_save_path = (
+                    output_path / "intermediate_models" / f"model_epoch_{epoch + 1}.pt"
+                )
+                torch.save(model.state_dict(), model_save_path)
                 print(f"Model saved at epoch {epoch + 1}")
 
     # return the trained model and the training history
 
-    history = xr.DataArray(
-        data=nhist,
-        dims=["epoch", "metric"],
-        coords={
-            "epoch": range(num_epochs),
-            "metric": ["train_loss", "val_loss", "train_acc", "val_acc"],
-        },
+    history = pd.DataFrame(
+        data=nhist.astype(np.float32),
+        columns=["train_loss", "val_loss", "train_acc", "val_acc"],
+        index=np.arange(1, num_epochs + 1),
     )
 
     return model, history
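
train_model now returns a plain DataFrame indexed by epoch (1..num_epochs) instead of an xarray DataArray, which makes quick diagnostic plots trivial. A self-contained sketch with stand-in data, assuming matplotlib is installed (it is not imported anywhere in this commit):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Stand-in for the DataFrame train_model now returns (index = epoch 1..30).
rng = np.random.default_rng(42)
history = pd.DataFrame(
    rng.random((30, 4)).astype(np.float32),
    columns=["train_loss", "val_loss", "train_acc", "val_acc"],
    index=np.arange(1, 31),
)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
history[["train_loss", "val_loss"]].plot(ax=ax_loss, title="BCE loss")
history[["train_acc", "val_acc"]].plot(ax=ax_acc, title="Accuracy")
ax_loss.set_xlabel("epoch")
ax_acc.set_xlabel("epoch")
plt.tight_layout()
plt.show()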