
Commit

feat: Updated src/main.py
sweep-nightly[bot] authored Nov 2, 2023
1 parent b73d090 commit 65a388d
Showing 1 changed file with 22 additions and 2 deletions.
src/main.py: 24 changes (22 additions & 2 deletions)
@@ -16,6 +16,10 @@
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
"""
This class defines the architecture of a neural network for digit recognition.
It consists of three fully connected layers.
"""
class Net(nn.Module):
    def __init__(self):
        super().__init__()
@@ -35,7 +39,7 @@ def forward(self, x):
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Training loop: This loop trains the neural network on the MNIST dataset.
epochs = 3
for epoch in range(epochs):
    for images, labels in trainloader:
@@ -45,4 +49,20 @@ def forward(self, x):
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
# Zero the gradients before a new iteration
# Forward propagation: Pass the images through the model to get the output
# Compute the loss between the output and the actual labels
# Backpropagation: Compute the gradients of the loss with respect to the model's parameters
# Optimizer step: Update the model's parameters
"""
Initialize the neural network with three fully connected layers.
The first layer (fc1) has 128 neurons and takes as input the flattened 28x28 pixel MNIST images.
The second layer (fc2) has 64 neurons.
The third layer (fc3) has 10 neurons, corresponding to the 10 possible digits, and will output the network's log-probabilities.
"""
"""
Defines the forward pass of the neural network.
The input images are first flattened and then passed through the three layers with ReLU activation functions applied after the first and second layers.
The output of the third layer is passed through a log softmax function to obtain the network's log-probabilities.
"""
