train_model.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Torch
  2. import torch.nn as nn
  3. import torch
  4. import torch.optim as optim
  5. # Config
  6. from utils.config import config
  7. import pathlib as pl
  8. from result import Ok, Err
  9. import json
  10. # Custom modules
  11. from model.cnn import CNN3D
  12. from utils.training import train_model, test_model
  13. from data.dataset import (
  14. load_adni_data_from_file,
  15. divide_dataset,
  16. initalize_dataloaders,
  17. )
  18. # Load data
  19. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  20. xls_file = pl.Path(config["data"]["xls_file_path"])
  21. # Load the data
  22. match load_adni_data_from_file(
  23. mri_files, xls_file, device=config["training"]["device"]
  24. ):
  25. case Ok(d):
  26. dataset = d
  27. print("Data loaded successfully")
  28. case Err(e):
  29. print(f"Error loading data: {e}")
  30. exit(-1)
  31. # Divide the dataset into training and validation sets
  32. if config["data"]["seed"] is None:
  33. print("Warning: No seed provided for dataset division, using default seed 0")
  34. config["data"]["seed"] = 0
  35. match divide_dataset(
  36. dataset, config["data"]["train_val_split"], seed=config["data"]["seed"]
  37. ):
  38. case Ok(s):
  39. if len(s) != 3:
  40. print(f"Error: Expected 3 subsets (train, val, test), got {len(s)}")
  41. exit(-1)
  42. datasets = s
  43. print("Dataset divided successfully")
  44. case Err(e):
  45. print(f"Error dividing dataset: {e}")
  46. exit(-1)
  47. # Initialize the dataloaders
  48. train_loader, val_loader, test_loader = initalize_dataloaders(
  49. datasets, batch_size=config["training"]["batch_size"]
  50. )
  51. # Save seed to output config file
  52. output_config_path = pl.Path(config["output"]["path"] / "config.json")
  53. if not output_config_path.parent.exists():
  54. output_config_path.parent.mkdir(parents=True, exist_ok=True)
  55. with open(output_config_path, "w") as f:
  56. # Save as JSON
  57. json.dump(config, f, indent=4)
  58. print(f"Configuration saved to {output_config_path}")
  59. # Set up the ensemble training loop
  60. for run_num in range(config["training"]["ensemble_runs"]):
  61. print(f"Starting run {run_num + 1}/{config['training']['ensemble_runs']}")
  62. # Initialize the model
  63. model = (
  64. CNN3D(
  65. image_channels=config["data"]["image_channels"],
  66. clin_data_channels=config["data"]["clin_data_channels"],
  67. num_classes=config["data"]["num_classes"],
  68. droprate=config["training"]["drop_rate"],
  69. )
  70. .float()
  71. .to(config["training"]["device"])
  72. )
  73. # Set up the optimizer and loss function
  74. optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
  75. criterion = nn.BCELoss()
  76. # Train model
  77. model, history = train_model(
  78. model=model,
  79. train_loader=train_loader,
  80. val_loader=val_loader,
  81. optimizer=optimizer,
  82. criterion=criterion,
  83. num_epochs=config["training"]["num_epochs"],
  84. learning_rate=config["training"]["learning_rate"],
  85. )
  86. # Test model
  87. test_loss, test_acc = test_model(
  88. model=model,
  89. test_loader=test_loader,
  90. criterion=criterion,
  91. )
  92. print(
  93. f"Run {run_num + 1}/{config['training']['ensemble_runs']} - "
  94. f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
  95. )
  96. # Save the model
  97. model_save_path = pl.Path(config["output"]["path"] / f"model_run_{run_num + 1}.pt")
  98. torch.save(model.state_dict(), model_save_path)
  99. print(f"Model saved to {model_save_path}")
  100. # Save the training history
  101. history_save_path = pl.Path(
  102. config["output"]["path"] / f"history_run_{run_num + 1}.nc"
  103. )
  104. history.to_netcdf(history_save_path, mode="w") # type: ignore
  105. print(f"Training history saved to {history_save_path}")
  106. # Save test results
  107. test_results_save_path = pl.Path(
  108. config["output"]["path"] / f"test_results_run_{run_num + 1}.json"
  109. )
  110. with open(test_results_save_path, "w") as f:
  111. json.dump(
  112. {
  113. "test_loss": test_loss,
  114. "test_accuracy": test_acc,
  115. },
  116. f,
  117. indent=4,
  118. )
  119. print(f"Test results saved to {test_results_save_path}")
  120. print(f"Run {run_num + 1}/{config['training']['ensemble_runs']} completed\n")
  121. # Completion message
  122. print(f"All runs completed. Models and results saved to {config['output']['path']}")