123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- # Torch
- import torch.nn as nn
- import torch
- import torch.optim as optim
- # Config
- from utils.config import config
- import pathlib as pl
- import pandas as pd
- import json
- # Custom modules
- from model.cnn import CNN3D
- from utils.training import train_model, test_model
- from data.dataset import (
- load_adni_data_from_file,
- divide_dataset,
- initalize_dataloaders,
- )
# Locate the input data.
# Path.glob yields entries in arbitrary, filesystem-dependent order, so the MRI
# file list is sorted to make the dataset order — and presumably the seeded
# split performed by divide_dataset below — reproducible across machines.
# NOTE(review): this is now a list rather than a one-shot generator; confirm
# load_adni_data_from_file only iterates it (it would break on next()).
mri_files = sorted(pl.Path(config["data"]["mri_files_path"]).glob("*.nii"))
xls_file = pl.Path(config["data"]["xls_file_path"])
# Preprocessor for the clinical spreadsheet, passed to the dataset loader.
def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the clinical Excel DataFrame.

    Keeps only the "Image Data ID", "Sex" and "Age (current)" columns, strips
    surrounding whitespace from "Sex" and encodes it numerically (M -> 0,
    F -> 1). Values other than "M"/"F" are left as they are. The input
    DataFrame is not modified.

    Args:
        df: DataFrame read from the Excel file; must contain the three
            columns listed above.

    Returns:
        A new DataFrame with exactly those three columns. "Image Data ID"
        remains a regular column and the index is the one ``df`` already had.

    Raises:
        KeyError: If any of the required columns is missing from ``df``.
    """
    # Explicit copy: assigning into a bare column selection can trigger
    # pandas' SettingWithCopyWarning (chained assignment).
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()
    # Recode only the Sex column; a frame-wide replace would also rewrite any
    # cell in other columns (e.g. an ID) that happens to equal "M" or "F".
    data["Sex"] = data["Sex"].replace({"M": 0, "F": 1})
    # NOTE(review): the original called data.set_index("Image Data ID")
    # without assigning the result, which is a no-op, so the loader has always
    # received the ID as a column. If an index was intended, assign it here
    # and confirm load_adni_data_from_file looks rows up via .loc.
    return data
# Load the MRI scans together with the preprocessed clinical data.
dataset = load_adni_data_from_file(
    mri_files,
    xls_file,
    device=config["training"]["device"],
    xls_preprocessor=xls_pre,
)

# Divide the dataset into training, validation and test sets.
# A missing or null seed falls back to 0. The fallback is written back into
# config so the seed actually used ends up in the config.json saved below.
if config["data"].get("seed") is None:
    print("Warning: No seed provided for dataset division, using default seed 0")
    config["data"]["seed"] = 0
datasets = divide_dataset(
    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
)

# Initialize the dataloaders.
# Assumes initalize_dataloaders returns them in train/val/test order,
# matching this unpacking.
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)
# Save the complete run configuration (including the resolved data seed) to
# the output directory so the settings used are recorded next to the results.
output_config_path = pl.Path(config["output"]["path"]) / "config.json"
# mkdir(exist_ok=True) is already a no-op for an existing directory, so no
# separate (racy) existence check is needed.
output_config_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")
def _append_test_result(results_path: pl.Path, record: dict) -> None:
    """Append ``record`` to the JSON list stored at ``results_path``.

    A missing, empty or unparsable file starts a fresh list (with a message).
    The whole list is rewritten on every call, so the file never keeps stale
    trailing bytes from a previous, longer write.

    Args:
        results_path: Location of the JSON results file.
        record: JSON-serializable dict describing one run.
    """
    try:
        results = json.loads(results_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError):
        print("No previous results found, initializing results list.")
        results = []
    results.append(record)
    results_path.write_text(json.dumps(results, indent=4), encoding="utf-8")


# Ensemble training: each run builds a freshly initialised model, trains it on
# the same data loaders, evaluates it on the test set and saves its artefacts.
ensemble_size = config["training"]["ensemble_size"]
output_dir = pl.Path(config["output"]["path"])
# NOTE(review): results.json is never cleared, so re-running the script with
# the same output path appends entries with duplicate run numbers.
results_path = output_dir / "results.json"

for run_num in range(ensemble_size):
    run_label = f"{run_num + 1}/{ensemble_size}"
    print(f"Starting run {run_label}")

    # Initialize the model.
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )

    # Set up the optimizer and loss function.
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    # NOTE(review): BCELoss expects probabilities in [0, 1]; if CNN3D returns
    # raw logits, BCEWithLogitsLoss is the numerically stable choice — confirm
    # against model/cnn.py.
    criterion = nn.BCELoss()

    # Train model.
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        learning_rate=config["training"]["learning_rate"],
    )

    # Test model.
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    print(
        f"Run {run_label} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )

    # Save the model weights.
    model_save_path = output_dir / f"model_run_{run_num + 1}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save the training history (presumably an xarray object, given
    # to_netcdf — confirm against utils.training).
    history_save_path = output_dir / f"history_run_{run_num + 1}.nc"
    history.to_netcdf(history_save_path, mode="w")  # type: ignore
    print(f"Training history saved to {history_save_path}")

    # Record this run's test metrics. float() turns numpy scalars or 0-dim
    # tensors into plain floats, which the json module can serialize.
    _append_test_result(
        results_path,
        {
            "run": run_num + 1,
            "test_loss": float(test_loss),
            "test_accuracy": float(test_acc),
        },
    )
    print(f"Run {run_label} completed\n")

# Completion message.
print(f"All runs completed. Models and results saved to {config['output']['path']}")
|