@@ -0,0 +1,156 @@
+from model import ModelCT  # Import your custom model architecture
+import os  # For file and directory operations
+import numpy as np  # For numerical operations and data shuffling
+import torch  # For PyTorch operations
+from torch.utils.data import DataLoader  # To load data in batches
+from datareader import DataReader  # Custom data reader for loading processed data
+from sklearn.metrics import roc_auc_score  # To compute ROC AUC metric
+import time  # For time tracking during training
+
+############################################################################
+#                            DATA PREPARATION                              #
+############################################################################
+# Define the main folder containing all processed data
+main_path_to_data = r"C:\Users\Klanecek\Desktop\processed"  # Change this path as needed
+
+# Get file lists from the subdirectories 'mild' and 'severe'
+files_mild = os.listdir(os.path.join(main_path_to_data, "mild"))
+files_severe = os.listdir(os.path.join(main_path_to_data, "severe"))
+
+# Extract unique patient IDs from filenames (assumes format: <patientID>_slice_<number>.npy)
+# Sorting makes the order deterministic, so the seeded shuffle below is
+# reproducible (set iteration order can vary between Python processes)
+patients_mild = sorted({file.split("_slice_")[0] for file in files_mild})
+patients_severe = sorted({file.split("_slice_")[0] for file in files_severe})
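+# e.g. "patient017_slice_4.npy".split("_slice_")[0] yields "patient017"
+# (the filename here is hypothetical, for illustration only)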
+
+# Build a list of tuples (file_path, label) for the central slice (assumed slice_4)
+# Label 0 for 'mild' and 1 for 'severe'
+all_central_slices = (
+    [("mild/" + patient_ID + "_slice_4.npy", 0) for patient_ID in patients_mild] +
+    [("severe/" + patient_ID + "_slice_4.npy", 1) for patient_ID in patients_severe]
+)
+
+# Set a fixed random seed for reproducibility and shuffle the list of samples
+np.random.seed(42)
+np.random.shuffle(all_central_slices)
+
+# Split data into train (60%), validation (20%), and test (20%) sets
+total_samples = len(all_central_slices)
+train_info = all_central_slices[: int(0.6 * total_samples)]
+valid_info = all_central_slices[int(0.6 * total_samples) : int(0.8 * total_samples)]
+test_info = all_central_slices[int(0.8 * total_samples) :]
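+# e.g. with 100 patients: indices [0:60] -> train, [60:80] -> validation, [80:100] -> test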
+
+############################################################################
+#                         MODEL & HYPERPARAMETERS                          #
+############################################################################
+# Instantiate the model
+model = ModelCT()
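+# ModelCT is assumed to output one raw logit per sample, matching the shape
+# of the labels returned by DataReader; check model.py if your architecture differs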
+
+# Set the device for training (GPU if available, otherwise CPU)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Define the training hyperparameters
+learning_rate = 2e-4  # Base learning rate for the optimizer
+weight_decay = 0.0001  # Weight decay (L2 regularization) coefficient
+total_epoch = 10  # Total number of training epochs
+
+############################################################################
+#                           DATA LOADER CREATION                           #
+############################################################################
+# Create dataset objects for training and validation using the DataReader
+train_dataset = DataReader(main_path_to_data, train_info)
+valid_dataset = DataReader(main_path_to_data, valid_info)
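+
+# This script relies on DataReader implementing the standard torch Dataset
+# protocol. A minimal sketch of the assumed interface (the real datareader.py
+# may differ; the attribute names below are hypothetical):
+#
+#     class DataReader(torch.utils.data.Dataset):
+#         def __init__(self, main_path_to_data, info):
+#             self.main_path_to_data = main_path_to_data
+#             self.info = info  # list of (relative_path, label) tuples
+#
+#         def __len__(self):
+#             return len(self.info)
+#
+#         def __getitem__(self, idx):
+#             rel_path, label = self.info[idx]
+#             image = np.load(os.path.join(self.main_path_to_data, rel_path))
+#             x = torch.from_numpy(image).float().unsqueeze(0)  # add channel dim
+#             y = torch.tensor([label], dtype=torch.float32)  # float for BCEWithLogitsLoss
+#             return x, y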
+
+# Create DataLoader objects for training and validation
+# Batch size of 16 for training and 10 for validation; with one central
+# slice per patient, each validation batch covers up to 10 patients
+train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+valid_loader = DataLoader(valid_dataset, batch_size=10, shuffle=False)
+
+############################################################################
+#                            MODEL PREPARATION                             #
+############################################################################
+# Move the model to the selected device (GPU/CPU)
+model.to(device)
+
+# Define the loss function (BCEWithLogitsLoss combines sigmoid activation with binary cross-entropy loss)
+criterion = torch.nn.BCEWithLogitsLoss()
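+# Illustrative check of that equivalence (not part of training): for a logit z
+# and target y, BCEWithLogitsLoss()(z, y) equals BCELoss()(sigmoid(z), y), but
+# the fused version is more numerically stable for large |z|:
+#     z = torch.tensor([0.3]); y = torch.tensor([1.0])
+#     torch.nn.BCEWithLogitsLoss()(z, y)        # tensor(0.5544)
+#     torch.nn.BCELoss()(torch.sigmoid(z), y)   # tensor(0.5544)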
+
+# Define the optimizer (Adam) with the specified learning rate and weight decay
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+
+############################################################################
+#                              TRAINING LOOP                               #
+############################################################################
+if __name__ == '__main__':  # Guard so DataLoader worker processes don't re-run this script (avoids the Broken Pipe error in torch 1.9)
+
+    # Initialize the variable to track the best validation AUC
+    best_auc = -np.inf
+    # Loop over epochs
+    for epoch in range(total_epoch):
+        start_time = time.time()  # Start time for the current epoch
+        print(f"Epoch: {epoch + 1}/{total_epoch}")
+
+        running_loss = 0.0  # Initialize the running loss for this epoch
+
+        # Set the model to training mode
+        model.train()
+
+        # Iterate over the training data batches
+        for x, y in train_loader:
+            # Move data to the correct device
+            x, y = x.to(device), y.to(device)
+
+            # Zero the gradients from the previous iteration
+            optimizer.zero_grad()
+
+            # Forward pass: compute predictions using the model
+            y_hat = model(x)  # Calling model(x) invokes __call__, which runs forward()
+            loss = criterion(y_hat, y)  # Compute the loss
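+            # Note: BCEWithLogitsLoss expects y to be a float tensor with the
+            # same shape as y_hat; DataReader is assumed to return labels that way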
+
+            # Backward pass: compute gradients of the loss w.r.t. model parameters
+            loss.backward()
+
+            # Update model weights
+            optimizer.step()
+
+            # Accumulate loss for tracking
+            running_loss += loss.item()
+
+        # Compute average loss for the epoch
+        avg_loss = running_loss / len(train_loader)
+
+        ############################################################################
+        #                             VALIDATION LOOP                              #
+        ############################################################################
+        # Prepare to store predictions and ground truth labels
+        predictions = []
+        trues = []
+
+        # Set the model to evaluation mode (disables dropout, switches batch norm to running statistics, etc.)
+        model.eval()
+        with torch.no_grad():
+            # Iterate over the validation data batches
+            for x, y in valid_loader:
+                # Move data to the appropriate device
+                x, y = x.to(device), y.to(device)
+
+                # Forward pass to compute predictions
+                y_hat = model(x)
+
+                # Apply sigmoid to convert logits to probabilities
+                probs = torch.sigmoid(y_hat)
+
+                # Flatten and collect predictions and true labels
+                predictions.extend(probs.cpu().numpy().flatten().tolist())
+                trues.extend(y.cpu().numpy().flatten().tolist())
+
+        # Compute ROC AUC on the validation set; roc_auc_score raises a
+        # ValueError when the labels contain only one class, so guard for it
+        try:
+            auc = roc_auc_score(trues, predictions)
+        except ValueError:
+            auc = float("nan")  # AUC is undefined when only one class is present
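+        # e.g. roc_auc_score([0, 1, 1, 0], [0.2, 0.8, 0.4, 0.6]) -> 0.75
+        # (3 of the 4 positive/negative pairs are ranked correctly)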
+
+        # Print training metrics for the epoch
+        elapsed_time = time.time() - start_time
+        print(f"AUC: {auc:.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f} sec")
+
+        # Save the best model based on validation AUC
+        if auc > best_auc:
+            os.makedirs("trained_models", exist_ok=True)  # Ensure the output directory exists
+            best_model_path = os.path.join("trained_models", "trained_model_weights.pth")
+            torch.save(model.state_dict(), best_model_path)
+            best_auc = auc
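+
+# To reuse the saved weights later (a minimal sketch, assuming the same
+# ModelCT architecture):
+#     model = ModelCT()
+#     model.load_state_dict(torch.load(best_model_path, map_location=device))
+#     model.eval()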