"""
    Knowledgedump.org - Image Classification - train_model
    Train a Convolutional Neural Network (CNN) model for image classification.
    Uses a cross-entropy loss and the Adam optimizer, with defaults of 10 epochs and a learning rate of 0.001.
    (These hyper-parameters may need fine-tuning after testing.)

    Required packages: torch
"""

import torch
import time


"""
    Function to train the model with CIFAR-10 data.
    Inputs:
        - CNN model (object of torch.nn.Module class), i.e. the model to be trained here.
        - DataLoaders for training and testing data.
        - Number of epochs to train the model (default of 10 here).
        - Learning rate for the optimizer (default is 0.001 here).
        - Device to train the model on - torch.device("cpu") for CPU or None for GPU (if available).
    Returns:
        - Epoch loss and accuracy % values for visualization afterwards.
"""

def train_model(model, trainloader, testloader, epochs=10, learning_rate=0.001, device=None):

    # Set device to CPU or GPU, depending on availability/setting.
    if device == None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Print the parameters used and whether CPU (False) or GPU is used (True):
    print(f"Epoch number: {epochs}; learning rate (Adam): {learning_rate}; Cuda available: {torch.cuda.is_available()}")

    # Initialize lists to store loss and accuracy for visualization
    train_epoch_losses = []
    train_epoch_accuracies = []
    test_epoch_accuracies = []

    # Move the model to the selected device (GPU or CPU).
    model.to(device)
    
    # Set the optimizer as Adam with respective learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Use cross-entropy loss function for the image classification problem.
    loss_fct = torch.nn.CrossEntropyLoss()

    # Loop through the epochs to train the model:
    for epoch in range(epochs):
        # Set model to training mode.
        model.train()

        # Variable to accumulate the loss for this epoch.
        current_loss = 0.0
        # Initialize Counter for correct predictions.
        correct_pred = 0
        # Total number of images processed.
        images_processed = 0
        # Record the start time for the epoch.
        start_time = time.time()
        
        # Loop through the batches in the training data:
        for inputs, labels in trainloader:
            # Move the inputs and labels to the selected device (GPU/CPU)
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Set gradients from the previous step to zero.
            optimizer.zero_grad()
            
            # Perform the forward pass, i.e. compute predicted outputs.
            outputs = model(inputs)
            
            # Calculate the loss.
            loss = loss_fct(outputs, labels)
            
            # Perform backpropagation (compute gradients).
            loss.backward()
            
            # Update the model's parameters, using the optimizer.
            optimizer.step()
            
            # Accumulate the loss for this batch.
            current_loss += loss.item()
            
            # Get the predicted class labels (highest predicted probability).
            _, predicted = torch.max(outputs.data, 1)
            
            # Update correct predictions count.
            images_processed += labels.size(0)
            correct_pred += (predicted == labels).sum().item()
        
        # Calculate average loss and accuracy in % for this epoch.
        epoch_loss = current_loss / len(trainloader)
        epoch_acc = 100 * correct_pred / images_processed
        epoch_time = time.time() - start_time
        
        # Append the performance values to the output lists:
        train_epoch_losses.append(epoch_loss)
        train_epoch_accuracies.append(epoch_acc)

        # Print out statistics for the current epoch.
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%, Time: {epoch_time:.2f}s")
        
        # Evaluate the model after each epoch, to track progress.
        test_epoch_accuracies.append(evaluate_model(model, testloader, device))

    return [train_epoch_losses, train_epoch_accuracies, test_epoch_accuracies]



"""
    Function to evaluate the trained model on the test set.
    Inputs:
        - model (torch.nn.Module object) to train.
        - DataLoader for test data.
        - Device to evaluate the model on (torch.device("cpu") or torch.device("cuda")).
"""

def evaluate_model(model, testloader, device):
    # Set the model to evaluation mode (normally used to disable dropouts and batch normalization, which aren't applied here).
    model.eval()
    
    # Initialize Counters for correct predictions and total number of processed images.
    correct_pred = 0
    images_processed = 0
    
    # Disable gradient computation for evaluation (faster and uses less memory)
    with torch.no_grad():
        # Loop through the test set:
        for inputs, labels in testloader:
            # Move inputs and labels to the selected device
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Perform the forward pass (compute predicted outputs).
            outputs = model(inputs)
            
            # Get the predicted class labels.
            _, predicted = torch.max(outputs.data, 1)
            
            # Update correct predictions count.
            images_processed += labels.size(0)
            correct_pred += (predicted == labels).sum().item()
    
    # Calculate the accuracy on the test set in %.
    accuracy = 100 * correct_pred / images_processed
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy
