# Standard library
import json
import pathlib as pl
import sqlite3 as sql

# Third party
import pandas as pd

# Torch
import torch
import torch.nn as nn
import torch.optim as optim

# Config
from utils.config import config

# Custom modules
from model.cnn import CNN3D
from utils.training import train_model, test_model
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset,
    initalize_dataloaders,
)

# Locate the input files; sort the glob so file order (and thus the seeded
# split below) is reproducible across platforms
mri_files = sorted(pl.Path(config["data"]["mri_files_path"]).glob("*.nii"))
xls_file = pl.Path(config["data"]["xls_file_path"])

def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the Excel DataFrame.

    Keeps the image ID, sex, and age columns, encodes sex as 0 (M) / 1 (F),
    and indexes the frame by image ID. Customize as needed.
    """
    # Work on a copy so the caller's DataFrame is not mutated
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()
    data = data.replace({"M": 0, "F": 1})
    # set_index returns a new DataFrame, so the result must be assigned back
    data = data.set_index("Image Data ID")
    return data
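
# A quick usage sketch (hypothetical row, for illustration only):
#   xls_pre(pd.DataFrame({"Image Data ID": ["I10001"], "Sex": ["M "], "Age (current)": [72]}))
# returns a one-row frame indexed by "I10001" with Sex encoded as 0.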

dataset = load_adni_data_from_file(
    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
)

# Divide the dataset into training, validation, and test sets
if config["data"]["seed"] is None:
    print("Warning: no seed provided for dataset division; using default seed 0")
    config["data"]["seed"] = 0

datasets = divide_dataset(
    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
)

# Initialize the dataloaders
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)
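# Note: `initalize_dataloaders` is assumed to return the loaders in
# (train, val, test) order, matching config["data"]["data_splits"].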

# Save the config (including the resolved seed) so the run is reproducible
output_config_path = pl.Path(config["output"]["path"]) / "config.json"
output_config_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_config_path, "w") as f:
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")

# Set up the directory for intermediate checkpoints (shared by all runs)
intermediate_model_dir = pl.Path(config["output"]["path"]) / "intermediate_models"
intermediate_model_dir.mkdir(parents=True, exist_ok=True)
print(f"Intermediate models will be saved to {intermediate_model_dir}")

# Ensemble training loop: train `ensemble_size` models independently
for run_num in range(config["training"]["ensemble_size"]):
    print(f"Starting run {run_num + 1}/{config['training']['ensemble_size']}")

    # Initialize a fresh model for this run
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    criterion = nn.BCELoss()
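    # Note: nn.BCELoss expects probabilities in [0, 1], so this assumes CNN3D
    # ends in a sigmoid. If the network returns raw logits, swap in
    # nn.BCEWithLogitsLoss(), which is the numerically stabler choice.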

    # Train the model
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        output_path=pl.Path(config["output"]["path"]),
    )

    # Evaluate on the held-out test set
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    print(
        f"Run {run_num + 1}/{config['training']['ensemble_size']} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )

    # Save the trained ensemble member
    model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Append test results and training history to the SQLite database
    results_save_path = pl.Path(config["output"]["path"]) / "results.sqlite"
    with sql.connect(results_save_path) as conn:
        # Create the results table if it doesn't exist
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                run INTEGER PRIMARY KEY,
                test_loss REAL,
                test_accuracy REAL
            )
            """
        )
        # Insert this run's results
        conn.execute(
            """
            INSERT INTO results (run, test_loss, test_accuracy)
            VALUES (?, ?, ?)
            """,
            (run_num + 1, test_loss, test_acc),
        )
        # Create a table for this run's history. Table names cannot be bound
        # as SQL parameters, so the (trusted, integer) run number is
        # interpolated directly.
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS history_run_{run_num + 1} (
                epoch INTEGER PRIMARY KEY,
                train_loss REAL,
                val_loss REAL,
                train_acc REAL,
                val_acc REAL
            )
            """
        )
        # Insert the history row by row, casting to built-in types since
        # sqlite3 cannot bind numpy scalars
        for epoch, row in history.iterrows():
            values = (
                int(epoch),
                float(row["train_loss"]),
                float(row["val_loss"]),
                float(row["train_acc"]),
                float(row["val_acc"]),
            )
            conn.execute(
                f"""
                INSERT INTO history_run_{run_num + 1} (epoch, train_loss, val_loss, train_acc, val_acc)
                VALUES (?, ?, ?, ?, ?)
                """,
                values,
            )
        conn.commit()
- print(f"Results and history saved to {results_save_path}")
- print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")

# Completion message
print(f"All runs completed. Models and results saved to {config['output']['path']}")