# Torch import torch.nn as nn import torch import torch.optim as optim # Config from utils.config import config import pathlib as pl from result import Ok, Err import json # Custom modules from model.cnn import CNN3D from utils.training import train_model, test_model from data.dataset import ( load_adni_data_from_file, divide_dataset, initalize_dataloaders, ) # Load data mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii") xls_file = pl.Path(config["data"]["xls_file_path"]) # Load the data match load_adni_data_from_file( mri_files, xls_file, device=config["training"]["device"] ): case Ok(d): dataset = d print("Data loaded successfully") case Err(e): print(f"Error loading data: {e}") exit(-1) # Divide the dataset into training and validation sets if config["data"]["seed"] is None: print("Warning: No seed provided for dataset division, using default seed 0") config["data"]["seed"] = 0 match divide_dataset( dataset, config["data"]["train_val_split"], seed=config["data"]["seed"] ): case Ok(s): if len(s) != 3: print(f"Error: Expected 3 subsets (train, val, test), got {len(s)}") exit(-1) datasets = s print("Dataset divided successfully") case Err(e): print(f"Error dividing dataset: {e}") exit(-1) # Initialize the dataloaders train_loader, val_loader, test_loader = initalize_dataloaders( datasets, batch_size=config["training"]["batch_size"] ) # Save seed to output config file output_config_path = pl.Path(config["output"]["path"] / "config.json") if not output_config_path.parent.exists(): output_config_path.parent.mkdir(parents=True, exist_ok=True) with open(output_config_path, "w") as f: # Save as JSON json.dump(config, f, indent=4) print(f"Configuration saved to {output_config_path}") # Set up the ensemble training loop for run_num in range(config["training"]["ensemble_runs"]): print(f"Starting run {run_num + 1}/{config['training']['ensemble_runs']}") # Initialize the model model = ( CNN3D( image_channels=config["data"]["image_channels"], clin_data_channels=config["data"]["clin_data_channels"], num_classes=config["data"]["num_classes"], droprate=config["training"]["drop_rate"], ) .float() .to(config["training"]["device"]) ) # Set up the optimizer and loss function optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"]) criterion = nn.BCELoss() # Train model model, history = train_model( model=model, train_loader=train_loader, val_loader=val_loader, optimizer=optimizer, criterion=criterion, num_epochs=config["training"]["num_epochs"], learning_rate=config["training"]["learning_rate"], ) # Test model test_loss, test_acc = test_model( model=model, test_loader=test_loader, criterion=criterion, ) print( f"Run {run_num + 1}/{config['training']['ensemble_runs']} - " f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}" ) # Save the model model_save_path = pl.Path(config["output"]["path"] / f"model_run_{run_num + 1}.pt") torch.save(model.state_dict(), model_save_path) print(f"Model saved to {model_save_path}") # Save the training history history_save_path = pl.Path( config["output"]["path"] / f"history_run_{run_num + 1}.nc" ) history.to_netcdf(history_save_path, mode="w") # type: ignore print(f"Training history saved to {history_save_path}") # Save test results test_results_save_path = pl.Path( config["output"]["path"] / f"test_results_run_{run_num + 1}.json" ) with open(test_results_save_path, "w") as f: json.dump( { "test_loss": test_loss, "test_accuracy": test_acc, }, f, indent=4, ) print(f"Test results saved to {test_results_save_path}") print(f"Run {run_num + 1}/{config['training']['ensemble_runs']} completed\n") # Completion message print(f"All runs completed. Models and results saved to {config['output']['path']}")