train_model.py

# Torch
import torch
import torch.nn as nn
import torch.optim as optim

# Config
from utils.config import config

# Standard library / third-party
import pathlib as pl
import pandas as pd
import json
import sqlite3 as sql

# Custom modules
from model.cnn import CNN3D
from utils.training import train_model, test_model
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset,
    initalize_dataloaders,
)
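
# The config structure assumed by this script (mirroring the lookups below);
# the actual utils.config schema may differ, so treat this as a sketch:
#
#   {
#     "data": {"mri_files_path": ..., "xls_file_path": ..., "seed": ...,
#              "data_splits": ..., "image_channels": ...,
#              "clin_data_channels": ..., "num_classes": ...},
#     "training": {"device": ..., "batch_size": ..., "droprate": ...,
#                  "learning_rate": ..., "num_epochs": ..., "ensemble_size": ...},
#     "output": {"path": ...}
#   }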

# Load data
mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
xls_file = pl.Path(config["data"]["xls_file_path"])


def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the Excel DataFrame.

    Customize this function to filter or modify the DataFrame as needed.
    """
    # Keep only the columns the model consumes; copy to avoid mutating a view
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()
    # Encode sex numerically (M -> 0, F -> 1)
    data = data.replace({"M": 0, "F": 1})
    # set_index returns a new DataFrame, so the result must be assigned back
    data = data.set_index("Image Data ID")
    return data
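
# For illustration: a raw row like ("I12345", "M ", 71.0) would come out as an
# "Image Data ID"-indexed row with Sex encoded as 0 (the ID shown is made up).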

dataset = load_adni_data_from_file(
    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
)

# Divide the dataset into training, validation, and test sets
if config["data"]["seed"] is None:
    print("Warning: no seed provided for dataset division, using default seed 0")
    config["data"]["seed"] = 0

datasets = divide_dataset(
    dataset, config["data"]["data_splits"], seed=config["data"]["seed"]
)
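# Note: data_splits is assumed to hold the train/val/test proportions
# consumed by divide_dataset (e.g., something like [0.7, 0.15, 0.15]).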

# Initialize the dataloaders
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)

# Save the resolved configuration (including the seed) next to the outputs
output_config_path = pl.Path(config["output"]["path"]) / "config.json"
output_config_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_config_path, "w") as f:
    json.dump(config, f, indent=4)
print(f"Configuration saved to {output_config_path}")

# Ensemble training loop
for run_num in range(config["training"]["ensemble_size"]):
    print(f"Starting run {run_num + 1}/{config['training']['ensemble_size']}")

    # Initialize the model
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )

    # Set up the intermediate model directory
    intermediate_model_dir = pl.Path(config["output"]["path"]) / "intermediate_models"
    intermediate_model_dir.mkdir(parents=True, exist_ok=True)
    print(f"Intermediate models will be saved to {intermediate_model_dir}")

    # Set up the optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
    criterion = nn.BCELoss()
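    # nn.BCELoss expects probabilities in [0, 1], so this assumes CNN3D ends in
    # a sigmoid; if it emits raw logits, nn.BCEWithLogitsLoss would be the fit.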

    # Train the model
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_epochs=config["training"]["num_epochs"],
        output_path=pl.Path(config["output"]["path"]),
    )

    # Evaluate on the held-out test set
    test_loss, test_acc = test_model(
        model=model,
        test_loader=test_loader,
        criterion=criterion,
    )
    print(
        f"Run {run_num + 1}/{config['training']['ensemble_size']} - "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    )

    # Save the model weights
    model_save_path = pl.Path(config["output"]["path"]) / f"model_run_{run_num + 1}.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
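
    # To reuse this run's weights later (e.g., for ensemble inference), one
    # would construct a fresh CNN3D with the same arguments and then, roughly:
    #   model.load_state_dict(torch.load(model_save_path))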

    # Append test results and per-epoch history to the SQLite database
    results_save_path = pl.Path(config["output"]["path"]) / "results.sqlite"
    with sql.connect(results_save_path) as conn:
        # Create the results table if it doesn't exist
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS results (
                run INTEGER PRIMARY KEY,
                test_loss REAL,
                test_accuracy REAL
            )
            """
        )
        # Insert this run's test results
        conn.execute(
            """
            INSERT INTO results (run, test_loss, test_accuracy)
            VALUES (?, ?, ?)
            """,
            (run_num + 1, test_loss, test_acc),
        )
        # Create a per-run table for the epoch-by-epoch training history
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS history_run_{run_num + 1} (
                epoch INTEGER PRIMARY KEY,
                train_loss REAL,
                val_loss REAL,
                train_acc REAL,
                val_acc REAL
            )
            """
        )
        # Insert the history rows
        for epoch, row in history.iterrows():
            values = (
                epoch,
                float(row["train_loss"]),
                float(row["val_loss"]),
                float(row["train_acc"]),
                float(row["val_acc"]),
            )
            conn.execute(
                f"""
                INSERT INTO history_run_{run_num + 1}
                    (epoch, train_loss, val_loss, train_acc, val_acc)
                VALUES (?, ?, ?, ?, ?)
                """,
                values,
            )
        conn.commit()

    print(f"Results and history saved to {results_save_path}")
    print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")

# Completion message
print(f"All runs completed. Models and results saved to {config['output']['path']}")
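
# As a quick sanity check (a sketch, not part of the original pipeline), the
# ensemble's aggregate test performance can be read back from the database;
# this assumes at least one run completed, so results_save_path is defined.
with sql.connect(results_save_path) as conn:
    summary = pd.read_sql_query(
        "SELECT COUNT(*) AS runs, AVG(test_loss) AS mean_loss, "
        "AVG(test_accuracy) AS mean_accuracy FROM results",
        conn,
    )
print(summary.to_string(index=False))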