import os  # File and directory operations
import time  # Per-epoch timing

import numpy as np  # Numerical operations and data shuffling
import torch  # PyTorch core
from sklearn.metrics import roc_auc_score  # ROC AUC metric
from torch.utils.data import DataLoader  # Batched data loading

from datareader import DataReader  # Custom reader for the processed data
from model import ModelCT  # Custom model architecture

############################################################################
#                           DATA PREPARATION                               #
############################################################################

# Main folder containing all processed data (change this path as needed).
main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"

# Index of the slice used as each patient's "central" slice.
# NOTE(review): assumes every patient has a slice with this index — confirm.
central_slice_index = 4

# Raw directory listings of the 'mild' and 'severe' subfolders.
files_mild = os.listdir(os.path.join(main_path_to_data, "mild"))
files_severe = os.listdir(os.path.join(main_path_to_data, "severe"))


def _patient_ids(file_names):
    """Return the set of patient IDs found in ``file_names``.

    Files are expected to be named ``<patient_ID>_slice_<index>.npy``; the
    patient ID is the part before ``_slice_``. Any other directory entry
    (hidden files, non-``.npy`` files) is ignored rather than turned into a
    bogus patient ID.
    """
    return {
        name.split("_slice_")[0]
        for name in file_names
        if name.endswith(".npy") and "_slice_" in name
    }


patients_mild = _patient_ids(files_mild)
patients_severe = _patient_ids(files_severe)

# (relative file path, label) for every patient's central slice;
# label 0 = 'mild', label 1 = 'severe'.
# IDs are sorted first: the iteration order of a set of strings depends on
# Python's per-process hash seed, so without sorting the seeded shuffle
# below would still produce a different split on every run.
all_central_slices = [
    (f"mild/{patient_ID}_slice_{central_slice_index}.npy", 0)
    for patient_ID in sorted(patients_mild)
] + [
    (f"severe/{patient_ID}_slice_{central_slice_index}.npy", 1)
    for patient_ID in sorted(patients_severe)
]

# Fail early with a clear message instead of an obscure DataLoader error.
if not all_central_slices:
    raise FileNotFoundError(
        f"No '<patient_ID>_slice_<index>.npy' files found under {main_path_to_data!r}"
    )

# Fixed seed so the shuffled train/validation/test split is reproducible.
np.random.seed(42)
np.random.shuffle(all_central_slices)

# Split into train (60%), validation (20%) and test (20%) sets.
total_samples = len(all_central_slices)
train_end = int(0.6 * total_samples)
valid_end = int(0.8 * total_samples)
train_info = all_central_slices[:train_end]
valid_info = all_central_slices[train_end:valid_end]
test_info = all_central_slices[valid_end:]
############################################################################
#                        MODEL & HYPERPARAMETERS                           #
############################################################################

# Instantiate the model.
model = ModelCT()

# Train on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters.
learning_rate = 0.2e-3  # Base learning rate for the optimizer (= 2e-4)
weight_decay = 0.0001  # Weight-decay (L2) term passed to Adam
total_epoch = 10  # Total number of training epochs

# Destination of the weights with the best validation AUC so far.
model_dir = "trained_models"
best_model_path = os.path.join(model_dir, "trained_model_weights.pth")

############################################################################
#                         DATA LOADER CREATION                             #
############################################################################

# Dataset objects for training and validation, built with DataReader.
train_dataset = DataReader(main_path_to_data, train_info)
valid_dataset = DataReader(main_path_to_data, valid_info)

# Training batches (16) are reshuffled every epoch; validation batches (10)
# keep a fixed order. Each sample is one patient's central slice.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=10, shuffle=False)

############################################################################
#                          MODEL PREPARATION                               #
############################################################################

# Move the model to the selected device (GPU/CPU).
model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw logits.
criterion = torch.nn.BCEWithLogitsLoss()

# Adam optimizer with the learning rate and weight decay chosen above.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns the mean per-batch loss, or ``nan`` if ``loader`` yields no
    batches (instead of dividing by zero).
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # NOTE(review): BCEWithLogitsLoss requires y to be float with the same
        # shape as y_hat — presumably DataReader/ModelCT guarantee this; confirm.
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches if n_batches else float("nan")


def _validate(model, loader, device):
    """Return ``(probabilities, labels)`` for ``loader`` as flat Python lists.

    Probabilities are the sigmoid of the model's logits. Evaluation mode is
    set (dropout off, batch norm uses running statistics) and gradients are
    not tracked.
    """
    model.eval()
    predictions = []
    trues = []
    with torch.no_grad():
        for x, y in loader:
            probs = torch.sigmoid(model(x.to(device)))
            predictions.extend(probs.cpu().flatten().tolist())
            trues.extend(y.cpu().flatten().tolist())
    return predictions, trues


def _safe_auc(trues, predictions):
    """Return the ROC AUC of ``predictions`` against ``trues``, or ``nan``.

    ``roc_auc_score`` raises ``ValueError`` when the AUC is undefined (e.g.
    only one class present in ``trues``); that case yields ``nan`` plus a
    warning instead of aborting training.
    """
    try:
        return roc_auc_score(trues, predictions)
    except ValueError as err:
        print(f"Warning: validation AUC is undefined ({err})")
        return float("nan")


############################################################################
#                            TRAINING LOOP                                 #
############################################################################

if __name__ == '__main__':  # use this to avoid Broken Pipe Error in torch 1.9
    # Best validation AUC seen so far.
    best_auc = -np.inf

    # torch.save does not create missing directories.
    os.makedirs(model_dir, exist_ok=True)

    for epoch in range(total_epoch):
        start_time = time.perf_counter()
        print(f"Epoch: {epoch + 1}/{total_epoch}")

        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)

        ####################################################################
        #                        VALIDATION LOOP                           #
        ####################################################################
        predictions, trues = _validate(model, valid_loader, device)
        auc = _safe_auc(trues, predictions)

        # Report metrics for this epoch.
        elapsed_time = time.perf_counter() - start_time
        print(f"AUC: {auc:.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f} sec")

        # Keep the weights with the best validation AUC. A nan AUC compares
        # False here, so it never overwrites a previously saved model.
        if auc > best_auc:
            torch.save(model.state_dict(), best_model_path)
            best_auc = auc