from model import ModelCT           # Import your custom model architecture
import os                           # For file and directory operations
import numpy as np                  # For numerical operations and data shuffling
import torch                        # For PyTorch operations
from torch.utils.data import DataLoader  # To load data in batches
from datareader import DataReader   # Custom data reader for loading processed data
from sklearn.metrics import roc_auc_score  # To compute ROC AUC metric
import time                         # For time tracking during training

############################################################################
#                           DATA PREPARATION                               #
############################################################################
# Define the main folder containing all processed data
main_path_to_data = "/data/PSUF_naloge/5-naloga/processed/"  # Change this path as needed

# Get file lists from the subdirectories 'mild' and 'severe'
files_mild = os.listdir(os.path.join(main_path_to_data, "mild"))
files_severe = os.listdir(os.path.join(main_path_to_data, "severe"))

# Extract unique patient IDs from filenames (assumes format: <patientID>_slice_<number>.npy);
# sort them so the seeded shuffle below is reproducible across runs (set iteration order is not)
patients_mild = sorted({file.split("_slice_")[0] for file in files_mild})
patients_severe = sorted({file.split("_slice_")[0] for file in files_severe})
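# Example of the parsing above: "patient12_slice_4.npy".split("_slice_")[0] -> "patient12"
# (the ID "patient12" is purely illustrative; actual IDs depend on the dataset)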

# Build a list of (file_path, label) tuples for each patient's central slice (assumed to be slice_4)
# Label 0 for 'mild' and 1 for 'severe'
all_central_slices = (
    [("mild/" + patient_ID + "_slice_4.npy", 0) for patient_ID in patients_mild] +
    [("severe/" + patient_ID + "_slice_4.npy", 1) for patient_ID in patients_severe]
)

# Set a fixed random seed for reproducibility and shuffle the list of samples
np.random.seed(42)
np.random.shuffle(all_central_slices)

# Split data into train (60%), validation (20%), and test (20%) sets
total_samples = len(all_central_slices)
train_info = all_central_slices[: int(0.6 * total_samples)]
valid_info = all_central_slices[int(0.6 * total_samples) : int(0.8 * total_samples)]
test_info  = all_central_slices[int(0.8 * total_samples) :]
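
# Optional sanity check on the 60/20/20 split (the exact counts depend on the dataset)
print(f"Samples: {total_samples} total, {len(train_info)} train, "
      f"{len(valid_info)} validation, {len(test_info)} test")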

############################################################################
#                           MODEL & HYPERPARAMETERS                        #
############################################################################
# Instantiate the model
model = ModelCT()
    
# Set the device for training (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
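
# Optional: report the selected device (handy when debugging CUDA availability)
print(f"Training on device: {device}")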

# Set the training hyperparameters
learning_rate = 2e-4   # Base learning rate for the optimizer
weight_decay = 1e-4    # L2 weight decay passed to the optimizer for regularization
total_epoch = 10       # Total number of training epochs

############################################################################
#                        DATA LOADER CREATION                              #
############################################################################
# Create dataset objects for training and validation using the DataReader
train_dataset = DataReader(main_path_to_data, train_info)
valid_dataset = DataReader(main_path_to_data, valid_info)

# Create DataLoader objects for training and validation
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=10, shuffle=False)
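
# Optional sanity check: inspect one batch to confirm the tensor shapes the
# custom DataReader yields (the exact dimensions depend on its implementation):
#   x_batch, y_batch = next(iter(train_loader))
#   print(x_batch.shape, y_batch.shape)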

############################################################################
#                           MODEL PREPARATION                              #
############################################################################
# Move the model to the selected device (GPU/CPU)
model.to(device)

# Define the loss function (BCEWithLogitsLoss combines a sigmoid layer with binary cross-entropy in a numerically stable way)
criterion = torch.nn.BCEWithLogitsLoss()

# Define the optimizer (Adam) with the specified learning rate and weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
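
# Note: Adam applies weight_decay as classic L2 regularization added to the
# gradients; torch.optim.AdamW would decouple weight decay from the gradient
# update, but plain Adam matches the setup used here.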

############################################################################
#                           TRAINING LOOP                                  #
############################################################################
if __name__ == '__main__':  # Guard the training code so DataLoader worker processes can spawn cleanly (avoids BrokenPipeError, e.g. on Windows with torch 1.9)
    
    # Initialize the variable to track the best validation AUC
    best_auc = -np.inf
    # Loop over epochs
    for epoch in range(total_epoch):
        start_time = time.time()  # Start time for the current epoch
        print(f"Epoch: {epoch + 1}/{total_epoch}")
        
        running_loss = 0.0  # Initialize the running loss for this epoch
        
        # Set the model to training mode
        model.train()
        
        # Iterate over the training data batches
        for x, y in train_loader:
            # Move data to the correct device
            x, y = x.to(device), y.to(device)
            
            # Zero the gradients from the previous iteration
            optimizer.zero_grad()
            
            # Forward pass: Compute predictions using the model
            y_hat = model(x)  # model(x) invokes forward() via __call__, which also runs registered hooks
            loss = criterion(y_hat, y)  # Compute the loss
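            # Note: BCEWithLogitsLoss expects y to be a float tensor with the
            # same shape as y_hat; if the custom DataReader returns integer
            # labels, a y.float() (and possibly y.unsqueeze(1)) would be needed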
            
            # Backward pass: Compute gradient of the loss with respect to model parameters
            loss.backward()
            
            # Update model weights
            optimizer.step()
            
            # Accumulate loss for tracking
            running_loss += loss.item()
        
        # Compute average loss for the epoch
        avg_loss = running_loss / len(train_loader)
        
        ############################################################################
        #                           VALIDATION LOOP                                #
        ############################################################################
        # Prepare to store predictions and ground truth labels
        predictions = []
        trues = []
        
        # Set the model to evaluation mode (disables dropout; batch norm uses its running statistics)
        model.eval()
        with torch.no_grad():
            # Iterate over the validation data batches
            for x, y in valid_loader:
                # Move data to the appropriate device
                x, y = x.to(device), y.to(device)
                
                # Forward pass to compute predictions
                y_hat = model(x)
                
                # Apply sigmoid to convert logits to probabilities
                probs = torch.sigmoid(y_hat)
                
                # Flatten and collect predictions and true labels
                predictions.extend(probs.cpu().numpy().flatten().tolist())
                trues.extend(y.cpu().numpy().flatten().tolist())
        
        # Compute ROC AUC for the validation set; roc_auc_score raises a
        # ValueError if the validation labels contain only a single class
        try:
            auc = roc_auc_score(trues, predictions)
        except ValueError:
            auc = float("nan")  # AUC is undefined when only one class is present

        # Print training metrics for the epoch
        elapsed_time = time.time() - start_time
        print(f"AUC: {auc:.4f}, Avg Loss: {avg_loss:.4f}, Time: {elapsed_time:.2f} sec")
        
        # Save the best model based on validation AUC
        if auc > best_auc:
            os.makedirs("trained_models", exist_ok=True)  # Create the output directory if it does not exist
            best_model_path = os.path.join("trained_models", "trained_model_weights.pth")
            torch.save(model.state_dict(), best_model_path)
            best_auc = auc
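
    ############################################################################
    #                      TEST EVALUATION (OPTIONAL SKETCH)                   #
    ############################################################################
    # A minimal sketch of scoring the best checkpoint on the held-out test set,
    # assuming DataReader and the model behave exactly as in the validation loop
    test_loader = DataLoader(DataReader(main_path_to_data, test_info),
                             batch_size=10, shuffle=False)
    model.load_state_dict(torch.load(os.path.join("trained_models",
                                                  "trained_model_weights.pth")))
    model.eval()
    test_preds, test_trues = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            probs = torch.sigmoid(model(x))
            test_preds.extend(probs.cpu().numpy().flatten().tolist())
            test_trues.extend(y.cpu().numpy().flatten().tolist())
    print(f"Test AUC: {roc_auc_score(test_trues, test_preds):.4f}")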