# train_ensemble.py
# Standard library
import json
import pathlib as pl
import sqlite3 as sql

# Third-party
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# Project modules
from data.dataset import (
    load_adni_data_from_file,
    divide_dataset_by_patient_id,
    initalize_dataloaders,
)
from model.cnn import CNN3D
from utils.config import config
from utils.training import train_model, test_model
  19. # Load data
  20. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  21. xls_file = pl.Path(config["data"]["xls_file_path"])
  22. # Load the data
  23. def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  24. """
  25. Preprocess the Excel DataFrame.
  26. This function can be customized to filter or modify the DataFrame as needed.
  27. """
  28. data = df[["Image Data ID", "Sex", "Age (current)"]]
  29. data["Sex"] = data["Sex"].str.strip() # type: ignore
  30. data = data.replace({"M": 0, "F": 1}) # type: ignore
  31. data.set_index("Image Data ID") # type: ignore
  32. return data
  33. dataset = load_adni_data_from_file(
  34. mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
  35. )
  36. # Divide the dataset into training/validation/test sets
  37. if config["data"]["seed"] is None:
  38. print("Warning: No seed provided for dataset division, using default seed 0")
  39. config["data"]["seed"] = 0
  40. ptid_df = pd.read_csv(xls_file)
  41. ptid_df.columns = ptid_df.columns.str.strip()
  42. ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna( # type: ignore
  43. subset=["Image Data ID", "PTID"]
  44. )
  45. ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
  46. ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
  47. ptid_df = ptid_df[ptid_df["PTID"] != ""]
  48. ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
  49. # Split is grouped by PTID to prevent patient-level leakage across partitions.
  50. datasets = divide_dataset_by_patient_id(
  51. dataset,
  52. ptids,
  53. config["data"]["data_splits"],
  54. seed=config["data"]["seed"],
  55. )
  56. # Initialize the dataloaders
  57. train_loader, val_loader, test_loader = initalize_dataloaders(
  58. datasets, batch_size=config["training"]["batch_size"]
  59. )
  60. ensemble_output_path = pl.Path(config["output"]["ensemble_path"])
  61. # Save seed to output config file
  62. output_config_path = ensemble_output_path / "config.json"
  63. if not output_config_path.parent.exists():
  64. output_config_path.parent.mkdir(parents=True, exist_ok=True)
  65. with open(output_config_path, "w") as f:
  66. # Save as JSON
  67. json.dump(config, f, indent=4)
  68. print(f"Configuration saved to {output_config_path}")
  69. # Set up the ensemble training loop
  70. for run_num in range(config["training"]["ensemble_size"]):
  71. print(f"Starting run {run_num + 1}/{config['training']['ensemble_size']}")
  72. # Initialize the model
  73. model = (
  74. CNN3D(
  75. image_channels=config["data"]["image_channels"],
  76. clin_data_channels=config["data"]["clin_data_channels"],
  77. num_classes=config["data"]["num_classes"],
  78. droprate=config["training"]["droprate"],
  79. )
  80. .float()
  81. .to(config["training"]["device"])
  82. )
  83. # Set up intermediate model directory
  84. intermediate_model_dir = ensemble_output_path / "intermediate_models"
  85. if not intermediate_model_dir.exists():
  86. intermediate_model_dir.mkdir(parents=True, exist_ok=True)
  87. print(f"Intermediate models will be saved to {intermediate_model_dir}")
  88. # Set up the optimizer and loss function
  89. optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
  90. criterion = nn.BCELoss()
  91. # Train model
  92. model, history = train_model(
  93. model=model,
  94. train_loader=train_loader,
  95. val_loader=val_loader,
  96. optimizer=optimizer,
  97. criterion=criterion,
  98. num_epochs=config["training"]["num_epochs"],
  99. output_path=ensemble_output_path,
  100. )
  101. # Test model
  102. test_loss, test_acc = test_model(
  103. model=model,
  104. test_loader=test_loader,
  105. criterion=criterion,
  106. )
  107. print(
  108. f"Run {run_num + 1}/{config['training']['ensemble_size']} - "
  109. f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
  110. )
  111. # Save the model
  112. model_save_path = ensemble_output_path / f"model_run_{run_num + 1}.pt"
  113. torch.save(model.state_dict(), model_save_path)
  114. print(f"Model saved to {model_save_path}")
  115. # Save test results and history by appending to the sql database
  116. results_save_path = ensemble_output_path / f"results.sqlite"
  117. with sql.connect(results_save_path) as conn:
  118. # Create results table if it doesn't exist
  119. conn.execute(
  120. """
  121. CREATE TABLE IF NOT EXISTS results (
  122. run INTEGER PRIMARY KEY,
  123. test_loss REAL,
  124. test_accuracy REAL
  125. )
  126. """
  127. )
  128. # Insert the results
  129. conn.execute(
  130. """
  131. INSERT INTO results (run, test_loss, test_accuracy)
  132. VALUES (?, ?, ?)
  133. """,
  134. (run_num + 1, test_loss, test_acc),
  135. )
  136. # Create a new table for the run history
  137. conn.execute(
  138. f"""
  139. CREATE TABLE IF NOT EXISTS history_run_{run_num + 1} (
  140. epoch INTEGER PRIMARY KEY,
  141. train_loss REAL,
  142. val_loss REAL,
  143. train_acc REAL,
  144. val_acc REAL
  145. )
  146. """
  147. )
  148. # Insert the history
  149. for epoch, row in history.iterrows():
  150. values = (
  151. epoch,
  152. float(row["train_loss"]),
  153. float(row["val_loss"]),
  154. float(row["train_acc"]),
  155. float(row["val_acc"]),
  156. )
  157. conn.execute(
  158. f"""
  159. INSERT INTO history_run_{run_num + 1} (epoch, train_loss, val_loss, train_acc, val_acc)
  160. VALUES (?, ?, ?, ?, ?)
  161. """,
  162. values, # type: ignore
  163. )
  164. conn.commit()
  165. print(f"Results and history saved to {results_save_path}")
  166. print(f"Run {run_num + 1}/{config['training']['ensemble_size']} completed\n")
  167. # Completion message
  168. print(
  169. f"All runs completed. Models and results saved to {config['output']['ensemble_path']}"
  170. )