"""Train an ensemble of CNN3D models on ADNI MRI + clinical data.

Every model in the ensemble is trained on the same seeded train/val/test
split. Each run's weights, training history and test metrics are written
to ``config["output"]["path"]``.
"""

import json
import pathlib as pl

import pandas as pd

# Torch
import torch.nn as nn
import torch
import torch.optim as optim

# Config
from utils.config import config

# Custom modules
from model.cnn import CNN3D
from utils.training import train_model, test_model
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset,
    initalize_dataloaders,
)

# Load data
# Sort the glob results: glob order depends on the filesystem, and the seeded
# dataset split below is only reproducible if the input order is stable.
mri_files = sorted(pl.Path(config["data"]["mri_files_path"]).glob("*.nii"))
xls_file = pl.Path(config["data"]["xls_file_path"])


def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the clinical Excel DataFrame.

    Keeps only the "Image Data ID", "Sex" and "Age (current)" columns, strips
    surrounding whitespace from "Sex" and encodes it as 0 for "M" and 1 for
    "F". Any other "Sex" value is left unchanged. Customize this function to
    filter or modify the clinical data further.

    Args:
        df: DataFrame read from the clinical spreadsheet.

    Returns:
        A new DataFrame; *df* itself is not modified.
    """
    # .copy() so the column assignment below writes to an independent frame
    # instead of a slice of *df* (avoids pandas' SettingWithCopyWarning).
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    # Encode only the "Sex" column, so no other column can be rewritten.
    data["Sex"] = data["Sex"].str.strip().replace({"M": 0, "F": 1})  # type: ignore
    # NOTE(review): the previous version called data.set_index("Image Data ID")
    # without using the result, so the index was never set. That no-op was
    # dropped and "Image Data ID" stays a regular column, exactly as before.
    # If load_adni_data_from_file expects it as the index, use
    # `data = data.set_index("Image Data ID")` instead. Confirm against the loader.
    return data


def _append_test_result(results_path: pl.Path, record: dict) -> None:
    """
    Append *record* to the JSON list stored in *results_path*.

    The file is created if it does not exist. If it exists but is empty or not
    valid JSON, a fresh list is started and the old content is overwritten.

    Raises:
        ValueError: if the file holds valid JSON that is not a list.
    """
    try:
        results = json.loads(results_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError):
        # Missing, empty or corrupt file: start a new results list.
        print("No previous results found, initializing results list.")
        results = []
    if not isinstance(results, list):
        raise ValueError(
            f"Expected a JSON list in {results_path}, "
            f"found {type(results).__name__}"
        )
    results.append(record)
    # Rewrite the whole file: truncating avoids leftover bytes from older content.
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)


dataset = load_adni_data_from_file(
    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
)

# Divide the dataset into training, validation and test sets
if config["data"].get("seed") is None:
    print("Warning: No seed provided for dataset division, using default seed 0")
    config["data"]["seed"] = 0
datasets = divide_dataset(
    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
)

# Initialize the dataloaders
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)

# Save the config, including the seed actually used, next to the outputs.
output_dir = pl.Path(config["output"]["path"])
output_dir.mkdir(parents=True, exist_ok=True)
output_config_path = output_dir / "config.json"
with open(output_config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")

# Set up the ensemble training loop
ensemble_size = config["training"]["ensemble_size"]
test_results_save_path = output_dir / "results.json"
for run_num in range(ensemble_size):
    run_label = f"{run_num + 1}/{ensemble_size}"
    print(f"Starting run {run_label}")

    # Initialize a fresh model for this ensemble member
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    criterion = nn.BCELoss()

    # Train model
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        learning_rate=config["training"]["learning_rate"],
    )

    # Test model
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    print(
        f"Run {run_label} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )

    # Save the model
    model_save_path = output_dir / f"model_run_{run_num + 1}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save the training history
    history_save_path = output_dir / f"history_run_{run_num + 1}.nc"
    history.to_netcdf(history_save_path, mode="w")  # type: ignore
    print(f"Training history saved to {history_save_path}")

    # Append this run's test metrics to the shared results file. float() makes
    # them JSON-serialisable even if test_model returns 0-d tensors or NumPy
    # scalars (its return type is not visible here).
    _append_test_result(
        test_results_save_path,
        {
            "run": run_num + 1,
            "test_loss": float(test_loss),
            "test_accuracy": float(test_acc),
        },
    )

    print(f"Run {run_label} completed\n")

# Completion message
print(f"All runs completed. Models and results saved to {config['output']['path']}")