"""
    Knowledgedump.org - Image Classification - define_model
    Defines a simple Convolutional Neural Network (CNN) model class for image classification.
    It uses three convolutional layers, each followed by a ReLU activation and max-pooling; the resulting
    feature maps are then flattened and passed through two fully connected layers.

    Required packages: torch
"""

import torch


class SimpleCNN(torch.nn.Module):
    """Small three-stage CNN for 32x32 RGB image classification (e.g. CIFAR-10).

    Architecture:
        [conv3x3 -> ReLU -> maxpool 2x2] x 3   (32x32 -> 16x16 -> 8x8 -> 4x4)
        flatten (128 * 4 * 4 = 2048 features)
        fc1 (2048 -> 512) -> ReLU
        fc2 (512 -> num_classes) -> softmax (optional, see ``forward``)

    Args:
        num_classes: Number of output classes. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        # Initialise torch.nn.Module so that submodules/parameters below get registered.
        super().__init__()
        self.num_classes = num_classes

        # Convolutional layer 1:
        # - 3 input channels (RGB images), 32 output channels (feature maps)
        # - Kernel size of 3x3 with padding=1 to maintain image dimensions (32x32)
        # - used to catch low-level features
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)

        # Convolutional layer 2:
        # - 32 input channels, 64 output channels
        # - used to capture mid-level features
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # Convolutional layer 3:
        # - 64 input channels, 128 output channels
        # - for high-level features
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Max pooling layer (shared by all three stages):
        # - 2x2 kernel with stride 2 halves the height and width of the feature maps,
        #   reducing the computational cost of later layers.
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Fully connected layer 1:
        # - After three conv + pool stages a 32x32 input becomes 128 x 4 x 4 feature maps,
        #   which are flattened to 2048 features and compressed to 512 neurons.
        # - This fixed size is why the model only accepts 32x32 inputs.
        self.fc1 = torch.nn.Linear(128 * 4 * 4, 512)

        # Fully connected layer 2 (output layer):
        # - Maps the 512 features to one raw score (logit) per class.
        # - These are NOT probabilities; softmax in forward() converts them.
        self.fc2 = torch.nn.Linear(512, num_classes)

        # ReLU activation, max(0, x), applied after each conv layer and after fc1
        # to introduce non-linearity.
        self.relu = torch.nn.ReLU()

        # Softmax over the class dimension (dim=1) turns logits into per-image
        # class probabilities at the end of the forward pass.
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, inp: torch.Tensor, *, return_logits: bool = False) -> torch.Tensor:
        """Run the input batch sequentially through all layers of the model.

        Args:
            inp: Input batch of shape (batch_size, 3, 32, 32). Other spatial sizes are
                not supported because fc1 expects exactly 128 * 4 * 4 input features;
                a mis-sized input raises a shape-mismatch RuntimeError.
            return_logits: If True, return the raw fc2 outputs instead of softmax
                probabilities. Use this with ``torch.nn.CrossEntropyLoss``, which applies
                log-softmax internally; feeding it probabilities would apply softmax twice.

        Returns:
            Tensor of shape (batch_size, num_classes): class probabilities (each row
            sums to 1) by default, or logits when ``return_logits`` is True.
        """
        # Stage 1: conv1 + ReLU, then pool: (B, 3, 32, 32) -> (B, 32, 16, 16).
        inp = self.pool(self.relu(self.conv1(inp)))
        # Stage 2: conv2 + ReLU, then pool: -> (B, 64, 8, 8).
        inp = self.pool(self.relu(self.conv2(inp)))
        # Stage 3: conv3 + ReLU, then pool: -> (B, 128, 4, 4).
        inp = self.pool(self.relu(self.conv3(inp)))

        # Flatten everything except the batch dimension: (B, 128, 4, 4) -> (B, 2048).
        # Unlike view(-1, 2048), this never changes the batch size, so a wrongly sized
        # input fails loudly in fc1 instead of producing extra "batch" rows. It also
        # works on non-contiguous (e.g. channels_last) tensors.
        inp = torch.flatten(inp, start_dim=1)

        # First fully connected layer + ReLU.
        inp = self.relu(self.fc1(inp))

        # Output layer produces one logit per class.
        logits = self.fc2(inp)
        if return_logits:
            return logits

        # Convert logits to class probabilities.
        return self.softmax(logits)
