training.py

from model import ModelCT  # Import your custom model architecture
import os  # For file and directory operations
import numpy as np  # For numerical operations and data shuffling
import torch  # For PyTorch operations
from torch.utils.data import DataLoader  # To load data in batches
from datareader import DataReader  # Custom data reader for loading processed data
from sklearn.metrics import roc_auc_score  # To compute the ROC AUC metric
import time  # For time tracking during training

############################################################################
#                            DATA PREPARATION                              #
############################################################################
# Define the main folder containing all processed data
main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"  # Change this path as needed

# Get file lists from the subdirectories 'mild' and 'severe'
files_mild = os.listdir(os.path.join(main_path_to_data, "mild"))
files_severe = os.listdir(os.path.join(main_path_to_data, "severe"))

# Extract unique patient IDs from filenames (assumes format: <patientID>_slice_<number>.npy)
patients_mild = {file.split("_slice_")[0] for file in files_mild}
patients_severe = {file.split("_slice_")[0] for file in files_severe}

# Build a list of tuples (file_path, label) for the central slice (assumed slice_4)
# Label 0 for 'mild' and 1 for 'severe'
all_central_slices = (
    [("mild/" + patient_ID + "_slice_4.npy", 0) for patient_ID in patients_mild] +
    [("severe/" + patient_ID + "_slice_4.npy", 1) for patient_ID in patients_severe]
)

# Set a fixed random seed for reproducibility and shuffle the list of samples
np.random.seed(42)
np.random.shuffle(all_central_slices)

# Split data into train (60%), validation (20%), and test (20%) sets
total_samples = len(all_central_slices)
train_info = all_central_slices[: int(0.6 * total_samples)]
valid_info = all_central_slices[int(0.6 * total_samples) : int(0.8 * total_samples)]
test_info = all_central_slices[int(0.8 * total_samples) :]
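
# (Optional sanity check, not in the original script) Printing the split sizes and the class
# balance of the training set can catch path or file-naming problems early; it only uses the
# variables defined above.
print(f"Samples - train: {len(train_info)}, valid: {len(valid_info)}, test: {len(test_info)}")
print(f"Fraction of 'severe' in train: {np.mean([label for _, label in train_info]):.2f}")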
  34. ############################################################################
  35. # MODEL & HYPERPARAMETERS #
  36. ############################################################################
  37. # Instantiate the model
  38. model = ModelCT()
  39. # Set the device for training (GPU if available, otherwise CPU)
  40. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  41. # Extract hyperparameters for convenience
  42. learning_rate = 0.2e-3 # Base learning rate for the optimizer
  43. weight_decay = 0.0001 # Regularization parameter
  44. total_epoch = 10 # Total number of training epochs

############################################################################
#                           DATA LOADER CREATION                           #
############################################################################
# Create dataset objects for training and validation using the DataReader
train_dataset = DataReader(main_path_to_data, train_info)
valid_dataset = DataReader(main_path_to_data, valid_info)

# Create DataLoader objects: batch size 16 for training (reshuffled every epoch),
# batch size 10 for validation (no shuffling needed during evaluation)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=10, shuffle=False)
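
# (Optional, not in the original script) If data loading becomes the bottleneck, DataLoader's
# standard num_workers and pin_memory arguments can help; the right worker count depends on the
# machine, so this is only a sketch:
# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)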

############################################################################
#                            MODEL PREPARATION                             #
############################################################################
# Move the model to the selected device (GPU/CPU)
model.to(device)

# Define the loss function (BCEWithLogitsLoss combines sigmoid activation with binary cross-entropy loss)
criterion = torch.nn.BCEWithLogitsLoss()

# Define the optimizer (Adam) with the specified learning rate and weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
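
# Note (assumption, not in the original script): BCEWithLogitsLoss expects float targets with the
# same shape as the logits. If DataReader yields integer labels of shape [batch] while the model
# outputs one logit per sample of shape [batch, 1], the targets would need a conversion such as
#     y = y.float().unsqueeze(1)
# inside the training and validation loops before the loss is computed.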

############################################################################
#                              TRAINING LOOP                               #
############################################################################
if __name__ == '__main__':  # Guard needed when DataLoader worker processes are used (avoids BrokenPipeError, e.g. in torch 1.9)
    # Initialize the variable tracking the best validation AUC seen so far
    best_auc = -np.inf

    # Loop over epochs
    for epoch in range(total_epoch):
        start_time = time.time()  # Start time for the current epoch
        print(f"Epoch: {epoch + 1}/{total_epoch}")
        running_loss = 0.0  # Running sum of batch losses for this epoch

        # Set the model to training mode
        model.train()
        # Iterate over the training data batches
        for x, y in train_loader:
            # Move data to the correct device
            x, y = x.to(device), y.to(device)
            # Zero the gradients from the previous iteration
            optimizer.zero_grad()
            # Forward pass: compute predictions (model(x) invokes forward via __call__)
            y_hat = model(x)
            loss = criterion(y_hat, y)  # Compute the loss
            # Backward pass: compute gradients of the loss with respect to model parameters
            loss.backward()
            # Update model weights
            optimizer.step()
            # Accumulate loss for tracking
            running_loss += loss.item()

        # Compute the average training loss for the epoch
        avg_loss = running_loss / len(train_loader)

        ############################################################################
        #                            VALIDATION LOOP                               #
        ############################################################################
        # Prepare to store predictions and ground-truth labels
        predictions = []
        trues = []

        # Set the model to evaluation mode (disables dropout, uses running batch-norm statistics, etc.)
        model.eval()
        with torch.no_grad():
            # Iterate over the validation data batches
            for x, y in valid_loader:
                # Move data to the appropriate device
                x, y = x.to(device), y.to(device)
                # Forward pass to compute predictions
                y_hat = model(x)
                # Apply sigmoid to convert logits to probabilities
                probs = torch.sigmoid(y_hat)
                # Flatten and collect predictions and true labels
                predictions.extend(probs.cpu().numpy().flatten().tolist())
                trues.extend(y.cpu().numpy().flatten().tolist())

        # Compute ROC AUC on the validation set
        auc = roc_auc_score(trues, predictions)

        # Print training metrics for the epoch
        elapsed_time = time.time() - start_time
        print(f"AUC: {auc:.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f} sec")

        # Save the best model based on validation AUC
        if auc > best_auc:
            os.makedirs("trained_models", exist_ok=True)  # Make sure the output directory exists
            best_model_path = os.path.join("trained_models", 'trained_model_weights.pth')
            torch.save(model.state_dict(), best_model_path)
            best_auc = auc
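
    ############################################################################
    #         OPTIONAL: TEST-SET EVALUATION (not in the original script)       #
    ############################################################################
    # A minimal sketch of evaluating the best saved weights on the held-out test split
    # (test_info is created above but otherwise unused). It assumes DataReader and ModelCT
    # behave exactly as in the validation loop; adapt as needed.
    test_dataset = DataReader(main_path_to_data, test_info)
    test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.eval()
    test_predictions, test_trues = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            probs = torch.sigmoid(model(x))
            test_predictions.extend(probs.cpu().numpy().flatten().tolist())
            test_trues.extend(y.cpu().numpy().flatten().tolist())
    print(f"Test AUC: {roc_auc_score(test_trues, test_predictions):.4f}")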