Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 4d62685 commit 7540c06
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
This module is used to train a PyTorch model on the MNIST dataset. It includes the necessary steps to preprocess the data,
define the model, and train the model.
"""

from PIL import Image
import torch
import torch.nn as nn
Expand All @@ -7,6 +12,7 @@
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# The transform variable is used to preprocess the MNIST data by converting the images to tensors and normalizing them.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
Expand All @@ -17,13 +23,34 @@

# Step 2: Define the PyTorch Model
class Net(nn.Module):
"""
This class defines a simple feed-forward neural network for the MNIST dataset. It includes three fully connected layers.
Attributes:
fc1: The first fully connected layer.
fc2: The second fully connected layer.
fc3: The third fully connected layer.
"""

def __init__(self):
"""
Initializes the Net class by defining the three fully connected layers.
"""
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
"""
Defines the forward pass of the network.
Parameters:
x: The input tensor.
Returns:
The output tensor after passing through the network.
"""
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
Expand All @@ -36,13 +63,26 @@ def forward(self, x):
criterion = nn.NLLLoss()

# Training loop
# Define the number of training epochs
epochs = 3

# Start the training loop
for epoch in range(epochs):
# For each batch of images and labels in the trainloader
for images, labels in trainloader:
# Zero the gradients
optimizer.zero_grad()

# Forward pass: compute the output of the model on the images
output = model(images)

# Compute the loss between the output and the labels
loss = criterion(output, labels)

# Backward pass: compute the gradients of the loss with respect to the model parameters
loss.backward()

# Update the model parameters
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 7540c06

Please sign in to comment.