# train_bayesian.py
  1. # Torch
  2. import torch.nn as nn
  3. import torch
  4. import torch.optim as optim
  5. # Config
  6. from utils.config import config
  7. import pathlib as pl
  8. import pandas as pd
  9. import json
  10. import sqlite3 as sql
  11. from typing import Callable, cast
  12. import os
  13. # Bayesian torch
  14. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss # type: ignore[import-untyped]
  15. # Custom modules
  16. from model.cnn import CNN3D
  17. from utils.training import train_model_bayesian, test_model_bayesian
  18. from data.dataset import (
  19. load_adni_data_from_file,
  20. divide_dataset_by_patient_id,
  21. initalize_dataloaders,
  22. )
  23. # Load data
  24. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  25. xls_file = pl.Path(config["data"]["xls_file_path"])
  26. # If current directory is MedPhys_Research, change to alnn_rewrite for relative imports to work
  27. try:
  28. os.chdir("alnn_rewrite")
  29. except FileNotFoundError:
  30. pass
  31. def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  32. """
  33. Preprocess the Excel DataFrame.
  34. This function can be customized to filter or modify the DataFrame as needed.
  35. """
  36. data = df[["Image Data ID", "Sex", "Age (current)"]]
  37. data["Sex"] = data["Sex"].str.strip() # type: ignore
  38. data = data.replace({"M": 0, "F": 1}) # type: ignore
  39. data.set_index("Image Data ID") # type: ignore
  40. return data
  41. dataset = load_adni_data_from_file(
  42. mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
  43. )
  44. if config["data"]["seed"] is None:
  45. print("Warning: No seed provided for dataset division, using default seed 0")
  46. config["data"]["seed"] = 0
  47. ptid_df = pd.read_csv(xls_file)
  48. ptid_df.columns = ptid_df.columns.str.strip()
  49. ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna( # type: ignore
  50. subset=["Image Data ID", "PTID"]
  51. )
  52. ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
  53. ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
  54. ptid_df = ptid_df[ptid_df["PTID"] != ""]
  55. ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
  56. # Split is grouped by PTID to prevent patient-level leakage across partitions.
  57. datasets = divide_dataset_by_patient_id(
  58. dataset,
  59. ptids,
  60. config["data"]["data_splits"],
  61. seed=config["data"]["seed"],
  62. )
  63. # Initialize the dataloaders
  64. train_loader, val_loader, test_loader = initalize_dataloaders(
  65. datasets, batch_size=config["training"]["batch_size"]
  66. )
  67. bayesian_output_path = pl.Path(config["output"]["bayesian_path"])
  68. # Save seed to output config file
  69. output_config_path = bayesian_output_path / "config.json"
  70. if not output_config_path.parent.exists():
  71. output_config_path.parent.mkdir(parents=True, exist_ok=True)
  72. with open(output_config_path, "w") as f:
  73. json.dump(config, f, indent=4)
  74. print(f"Configuration saved to {output_config_path}")
  75. # Initialize the model
  76. model = (
  77. CNN3D(
  78. image_channels=config["data"]["image_channels"],
  79. clin_data_channels=config["data"]["clin_data_channels"],
  80. num_classes=config["data"]["num_classes"],
  81. droprate=config["training"]["droprate"],
  82. )
  83. .float()
  84. .to(config["training"]["device"])
  85. )
  86. const_bnn_prior_parameters: dict[str, float | bool | str] = {
  87. "prior_mu": 0.0,
  88. "prior_sigma": 1.0,
  89. "posterior_mu_init": 0.0,
  90. "posterior_rho_init": -3.0,
  91. "type": "Reparameterization",
  92. "moped_enable": False,
  93. "moped_delta": 0.5,
  94. }
  95. dnn_to_bnn(model, const_bnn_prior_parameters)
  96. model.to(config["training"]["device"])
  97. # Set up intermediate model directory
  98. intermediate_model_dir = bayesian_output_path / "intermediate_models"
  99. if not intermediate_model_dir.exists():
  100. intermediate_model_dir.mkdir(parents=True, exist_ok=True)
  101. print(f"Intermediate models will be saved to {intermediate_model_dir}")
  102. # Set up the optimizer and loss function
  103. optimizer = optim.Adam(model.parameters(), lr=config["training"]["learning_rate"])
  104. criterion = nn.BCELoss()
  105. # Train model
  106. model, history = train_model_bayesian(
  107. model=model,
  108. train_loader=train_loader,
  109. val_loader=val_loader,
  110. optimizer=optimizer,
  111. criterion=criterion,
  112. num_epochs=config["training"]["num_epochs"],
  113. output_path=bayesian_output_path,
  114. get_kl_loss=get_kl_loss,
  115. )
  116. # Test model
  117. test_loss, test_acc = test_model_bayesian(
  118. model=model,
  119. test_loader=test_loader,
  120. criterion=criterion,
  121. get_kl_loss=get_kl_loss,
  122. )
  123. print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
  124. # Save the model
  125. model_save_path = bayesian_output_path / "model_bayesian.pt"
  126. torch.save(model.state_dict(), model_save_path)
  127. print(f"Model saved to {model_save_path}")
  128. # Save test results and history by appending to the sql database
  129. results_save_path = bayesian_output_path / "results.sqlite"
  130. with sql.connect(results_save_path) as conn:
  131. conn.execute(
  132. """
  133. CREATE TABLE IF NOT EXISTS results (
  134. run INTEGER PRIMARY KEY,
  135. test_loss REAL,
  136. test_accuracy REAL
  137. )
  138. """
  139. )
  140. conn.execute(
  141. """
  142. INSERT OR REPLACE INTO results (run, test_loss, test_accuracy)
  143. VALUES (?, ?, ?)
  144. """,
  145. (1, test_loss, test_acc),
  146. )
  147. conn.execute(
  148. """
  149. CREATE TABLE IF NOT EXISTS history_run_1 (
  150. epoch INTEGER PRIMARY KEY,
  151. train_loss REAL,
  152. val_loss REAL,
  153. train_acc REAL,
  154. val_acc REAL
  155. )
  156. """
  157. )
  158. conn.execute("DELETE FROM history_run_1")
  159. for epoch, row in history.iterrows():
  160. values = (
  161. epoch,
  162. float(row["train_loss"]),
  163. float(row["val_loss"]),
  164. float(row["train_acc"]),
  165. float(row["val_acc"]),
  166. )
  167. conn.execute(
  168. """
  169. INSERT INTO history_run_1 (epoch, train_loss, val_loss, train_acc, val_acc)
  170. VALUES (?, ?, ?, ?, ?)
  171. """,
  172. values,
  173. )
  174. conn.commit()
  175. print(f"Results and history saved to {results_save_path}")
  176. print(f"Bayesian training completed. Model and results saved to {bayesian_output_path}")