| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
# Standard library
import json
import pathlib as pl
import sys

# Third-party
import torch
import torch.nn as nn
import torch.optim as optim
from result import Ok, Err

# Project modules (kept in their original import order)
from utils.config import config
from model.cnn import CNN3D
from utils.training import train_model, test_model
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset,
    initalize_dataloaders,
)
- # Load data
- mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
- xls_file = pl.Path(config["data"]["xls_file_path"])
- # Load the data
- match load_adni_data_from_file(
- mri_files, xls_file, device=config["training"]["device"]
- ):
- case Ok(d):
- dataset = d
- print("Data loaded successfully")
- case Err(e):
- print(f"Error loading data: {e}")
- exit(-1)
- # Divide the dataset into training and validation sets
- if config["data"]["seed"] is None:
- print("Warning: No seed provided for dataset division, using default seed 0")
- config["data"]["seed"] = 0
- match divide_dataset(
- dataset, config["data"]["train_val_split"], seed=config["data"]["seed"]
- ):
- case Ok(s):
- if len(s) != 3:
- print(f"Error: Expected 3 subsets (train, val, test), got {len(s)}")
- exit(-1)
- datasets = s
- print("Dataset divided successfully")
- case Err(e):
- print(f"Error dividing dataset: {e}")
- exit(-1)
- # Initialize the dataloaders
- train_loader, val_loader, test_loader = initalize_dataloaders(
- datasets, batch_size=config["training"]["batch_size"]
- )
# Save the resolved configuration (including the dataset seed chosen above)
# to the output directory so the experiment can be reproduced.
# Wrap the configured directory in Path *before* joining the file name, so a
# plain-string output path works (str / str raises TypeError).
output_dir = pl.Path(config["output"]["path"])
output_config_path = output_dir / "config.json"
# exist_ok=True makes a separate exists() check unnecessary.
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_config_path, "w", encoding="utf-8") as f:
    # default=str deliberately serializes values json cannot encode natively
    # (e.g. a pathlib.Path, if the config loader produced one) as their string
    # form, so the snapshot is written instead of aborting the run.
    json.dump(config, f, indent=4, default=str)
print(f"Configuration saved to {output_config_path}")
# Ensemble training loop. Each run builds a fresh CNN3D and Adam optimizer,
# trains on the shared train/val loaders and evaluates on the test loader.
# It then writes its weights, training history and test metrics to the output
# directory, with a 1-based run suffix on each file name.
# Wrap the configured directory in Path *before* joining file names, so a
# plain-string output path works (str / str raises TypeError).
output_dir = pl.Path(config["output"]["path"])
ensemble_runs = config["training"]["ensemble_runs"]
for run_num in range(ensemble_runs):
    run_id = run_num + 1
    print(f"Starting run {run_id}/{ensemble_runs}")
    # Initialize the model in float32 on the configured device.
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["drop_rate"],
        )
        .float()
        .to(config["training"]["device"])
    )
    # Set up the optimizer and loss function.
    # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably CNN3D
    # ends in a sigmoid - confirm.
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    criterion = nn.BCELoss()
    # Train model; train_model hands back the trained model and its history.
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        learning_rate=config["training"]["learning_rate"],
    )
    # Test model
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    # json cannot encode tensors or numpy float32 scalars. float() normalizes the
    # metrics to plain Python floats; it is a no-op if test_model already returns
    # floats (its return types are not visible here).
    test_loss, test_acc = float(test_loss), float(test_acc)
    print(
        f"Run {run_id}/{ensemble_runs} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )
    # Save the model weights (state_dict only, not the pickled module).
    model_save_path = output_dir / f"model_run_{run_id}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
    # Save the training history. It exposes to_netcdf, so it is presumably an
    # xarray Dataset - confirm against train_model.
    history_save_path = output_dir / f"history_run_{run_id}.nc"
    history.to_netcdf(history_save_path, mode="w")  # type: ignore
    print(f"Training history saved to {history_save_path}")
    # Save test results
    test_results_save_path = output_dir / f"test_results_run_{run_id}.json"
    with open(test_results_save_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "test_loss": test_loss,
                "test_accuracy": test_acc,
            },
            f,
            indent=4,
        )
    print(f"Test results saved to {test_results_save_path}")
    print(f"Run {run_id}/{ensemble_runs} completed\n")
# Final summary, printed once every ensemble run has been trained, tested and saved.
print(
    "All runs completed. Models and results saved to {}".format(
        config["output"]["path"]
    )
)
|