# Torch
import torch.nn as nn
import torch
import torch.optim as optim

# Config
from utils.config import config

import pathlib as pl
import pandas as pd
import json
import sqlite3 as sql
from contextlib import closing

# Custom modules
from model.cnn import CNN3D
from utils.training import train_model, test_model
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset,
    initalize_dataloaders,
)

# Train an ensemble of CNN3D models on MRI scans plus clinical data read from
# an Excel sheet. For every run this script saves the model weights, the test
# metrics and the per-epoch training history under config["output"]["path"].

# Load data
# Sort the scan paths: Path.glob yields files in arbitrary, filesystem-dependent
# order, and the seeded dataset split below presumably depends on the order in
# which scans are loaded, so an unsorted glob would make it non-reproducible.
mri_files = sorted(pl.Path(config["data"]["mri_files_path"]).glob("*.nii"))
xls_file = pl.Path(config["data"]["xls_file_path"])


def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the Excel DataFrame.

    Keeps only the "Image Data ID", "Sex" and "Age (current)" columns, strips
    surrounding whitespace from "Sex" and encodes it as "M" -> 0, "F" -> 1.
    Any other "Sex" value is left as-is.

    This function can be customized to filter or modify the DataFrame as needed.

    Args:
        df: Raw DataFrame read from the Excel file; must contain the three
            columns listed above.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    # .copy() so the assignment below writes to an independent frame rather
    # than a slice of ``df`` (avoids chained assignment / SettingWithCopyWarning).
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    # Only the "Sex" column is encoded; other columns are left untouched.
    data["Sex"] = data["Sex"].str.strip().replace({"M": 0, "F": 1})  # type: ignore
    # NOTE(review): an earlier version called data.set_index("Image Data ID")
    # without assigning the result, so "Image Data ID" has always remained a
    # regular column. That no-op was removed; if an index is actually wanted,
    # check load_adni_data_from_file first, as it was written against the column.
    return data


device = config["training"]["device"]

dataset = load_adni_data_from_file(
    mri_files, xls_file, device=device, xls_preprocessor=xls_pre
)

# Divide the dataset into training and validation sets
if config["data"].get("seed") is None:
    print("Warning: No seed provided for dataset division, using default seed 0")
    # Written back into config so the seed actually used is recorded in config.json.
    config["data"]["seed"] = 0
datasets = divide_dataset(
    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
)

# Initialize the dataloaders
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)

# Save seed to output config file
output_dir = pl.Path(config["output"]["path"])
output_dir.mkdir(parents=True, exist_ok=True)
output_config_path = output_dir / "config.json"
with open(output_config_path, "w", encoding="utf-8") as f:
    # Save as JSON
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")

# Set up intermediate model directory (shared by all runs, so created once).
# NOTE(review): this path is never passed to train_model, which only receives
# output_path; presumably train_model writes to
# output_path / "intermediate_models" itself. Confirm this, otherwise the
# directory stays empty.
intermediate_model_dir = output_dir / "intermediate_models"
intermediate_model_dir.mkdir(parents=True, exist_ok=True)
print(f"Intermediate models will be saved to {intermediate_model_dir}")


def _save_run_to_db(
    db_path: pl.Path,
    run_id: int,
    test_loss: float,
    test_acc: float,
    history: pd.DataFrame,
) -> None:
    """
    Write one run's test metrics and per-epoch history to a SQLite database.

    The test metrics go into the shared ``results`` table (one row per run).
    The history goes into a per-run ``history_run_<run_id>`` table.

    Any rows already stored for ``run_id`` are replaced rather than appended.
    So re-running the script with the same output path overwrites stale
    results instead of failing on the ``run`` / ``epoch`` primary keys.
    This matches how the model and config files are overwritten.

    Args:
        db_path: SQLite file to create or update.
        run_id: 1-based run number.
        test_loss: Test-set loss for this run.
        test_acc: Test-set accuracy for this run.
        history: Per-epoch history; the index is used as the epoch number and
            must be integer-valued. It must have "train_loss", "val_loss",
            "train_acc" and "val_acc" columns.
    """
    # run_id is an int, so interpolating it into the table name is safe.
    history_table = f"history_run_{run_id}"
    # sqlite3's connection context manager only commits or rolls back the
    # transaction; closing() is what actually releases the connection.
    with closing(sql.connect(db_path)) as conn:
        with conn:  # commit on success, roll back on error
            # Create results table if it doesn't exist
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS results (
                    run INTEGER PRIMARY KEY,
                    test_loss REAL,
                    test_accuracy REAL
                )
                """
            )
            conn.execute(
                """
                INSERT OR REPLACE INTO results (run, test_loss, test_accuracy)
                VALUES (?, ?, ?)
                """,
                (run_id, test_loss, test_acc),
            )

            # Create a new table for the run history
            conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {history_table} (
                    epoch INTEGER PRIMARY KEY,
                    train_loss REAL,
                    val_loss REAL,
                    train_acc REAL,
                    val_acc REAL
                )
                """
            )
            # Remove epochs from an earlier invocation so a shorter re-run
            # leaves no stale rows behind.
            conn.execute(f"DELETE FROM {history_table}")
            conn.executemany(
                f"""
                INSERT INTO {history_table}
                (epoch, train_loss, val_loss, train_acc, val_acc)
                VALUES (?, ?, ?, ?, ?)
                """,
                [
                    (
                        int(epoch),  # type: ignore
                        float(row["train_loss"]),
                        float(row["val_loss"]),
                        float(row["train_acc"]),
                        float(row["val_acc"]),
                    )
                    for epoch, row in history.iterrows()
                ],
            )


# Set up the ensemble training loop
ensemble_size = config["training"]["ensemble_size"]
for run_num in range(ensemble_size):
    # Runs are numbered from 1 in logs, file names and the database.
    run_id = run_num + 1
    print(f"Starting run {run_id}/{ensemble_size}")

    # Initialize a fresh model for this ensemble member
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(device)
    )

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably CNN3D
    # ends in a sigmoid. Confirm this, otherwise use BCEWithLogitsLoss.
    criterion = nn.BCELoss()

    # Train model
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        output_path=output_dir,
    )

    # Test model
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    # float() so metrics returned as 0-d tensors / numpy scalars can still be
    # bound as SQLite REAL parameters.
    test_loss, test_acc = float(test_loss), float(test_acc)
    print(
        f"Run {run_id}/{ensemble_size} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )

    # Save the model
    model_save_path = output_dir / f"model_run_{run_id}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save test results and history to the sql database
    results_save_path = output_dir / "results.sqlite"
    _save_run_to_db(results_save_path, run_id, test_loss, test_acc, history)
    print(f"Results and history saved to {results_save_path}")

    print(f"Run {run_id}/{ensemble_size} completed\n")

# Completion message
print(f"All runs completed. Models and results saved to {config['output']['path']}")