train_model.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Torch
  2. import torch.nn as nn
  3. import torch
  4. import torch.optim as optim
  5. # Config
  6. from utils.config import config
  7. import pathlib as pl
  8. import pandas as pd
  9. import json
  10. # Custom modules
  11. from model.cnn import CNN3D
  12. from utils.training import train_model, test_model
  13. from data.dataset import (
  14. load_adni_data_from_file,
  15. divide_dataset,
  16. initalize_dataloaders,
  17. )
  18. # Load data
  19. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  20. xls_file = pl.Path(config["data"]["xls_file_path"])
  21. # Load the data
  22. def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  23. """
  24. Preprocess the Excel DataFrame.
  25. This function can be customized to filter or modify the DataFrame as needed.
  26. """
  27. data = df[["Image Data ID", "Sex", "Age (current)"]]
  28. data["Sex"] = data["Sex"].str.strip() # type: ignore
  29. data = data.replace({"M": 0, "F": 1}) # type: ignore
  30. data.set_index("Image Data ID") # type: ignore
  31. return data
  32. dataset = load_adni_data_from_file(
  33. mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
  34. )
  35. # Divide the dataset into training and validation sets
  36. if config["data"]["seed"] is None:
  37. print("Warning: No seed provided for dataset division, using default seed 0")
  38. config["data"]["seed"] = 0
  39. datasets = divide_dataset(
  40. dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
  41. )
  42. # Initialize the dataloaders
  43. train_loader, val_loader, test_loader = initalize_dataloaders(
  44. datasets, batch_size=config["training"]["batch_size"]
  45. )
  46. # Save seed to output config file
  47. output_config_path = pl.Path(config["output"]["path"]) / "config.json"
  48. if not output_config_path.parent.exists():
  49. output_config_path.parent.mkdir(parents=True, exist_ok=True)
  50. with open(output_config_path, "w") as f:
  51. # Save as JSON
  52. json.dump(config, f, indent=4)
  53. print(f"Configuration saved to {output_config_path}")
  54. # Set up the ensemble training loop
  55. for run_num in range(config["training"]["ensemble_size"]):
  56. print(f"Starting run {run_num + 1}/{config['training']['ensemble_size']}")
  57. # Initialize the model
  58. model = (
  59. CNN3D(
  60. image_channels=config["data"]["image_channels"],
  61. clin_data_channels=config["data"]["clin_data_channels"],
  62. num_classes=config["data"]["num_classes"],
  63. droprate=config["training"]["droprate"],
  64. )
  65. .float()
  66. .to(config["training"]["device"])
  67. )
  68. # Set up the optimizer and loss function
  69. optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
  70. criterion = nn.BCELoss()
  71. # Train model
  72. model, history = train_model(
  73. model=model,
  74. train_loader=train_loader,
  75. val_loader=val_loader,
  76. optimizer=optimizer,
  77. criterion=criterion,
  78. num_epochs=config["training"]["num_epochs"],
  79. learning_rate=config["training"]["learning_rate"],
  80. )
  81. # Test model
  82. test_loss, test_acc = test_model(
  83. model=model,
  84. test_loader=test_loader,
  85. criterion=criterion,
  86. )
  87. print(
  88. f"Run {run_num + 1}/{config['training']['ensemble_size']} - "
  89. f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
  90. )
  91. # Save the model
  92. model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
  93. torch.save(model.state_dict(), model_save_path)
  94. print(f"Model saved to {model_save_path}")
  95. # Save the training history
  96. history_save_path = (
  97. pl.Path(config["output"]["path"]) / f"history_run_{run_num + 1}.nc"
  98. )
  99. history.to_netcdf(history_save_path, mode="w") # type: ignore
  100. print(f"Training history saved to {history_save_path}")
  101. # Save test results by appending to the results file
  102. test_results_save_path = pl.Path(config["output"]["path"]) / f"results.json"
  103. with open(test_results_save_path, "wr+") as f:
  104. try:
  105. results = json.load(f)
  106. except json.JSONDecodeError:
  107. # If the file is empty or not a valid JSON, initialize an empty list
  108. print("No previous results found, initializing results list.")
  109. results = []
  110. results.append( # type: ignore
  111. {
  112. "run": run_num + 1,
  113. "test_loss": test_loss,
  114. "test_accuracy": test_acc,
  115. }
  116. )
  117. f.seek(0)
  118. json.dump(results, f, indent=4)
  119. print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")
  120. # Completion message
  121. print(f"All runs completed. Models and results saved to {config['output']['path']}")