From a499f7430af66b77d81a0009fe8eeb1913862f63 Mon Sep 17 00:00:00 2001
From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com>
Date: Tue, 24 Oct 2023 20:47:26 +0000
Subject: [PATCH] Sandbox run src/main.py

---
 src/main.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/main.py b/src/main.py
index 3957829..d0fd614 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,21 +1,21 @@
-from PIL import Image
+import logging
+
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
-import numpy as np
-import logging
+from torchvision import datasets, transforms
 
 # Step 1: Load MNIST Data and Preprocess
-transform = transforms.Compose([
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
+transform = transforms.Compose(
+    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
+)
 
-trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
 trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
 
+
 # Step 2: Define the PyTorch Model
 class Net(nn.Module):
     def __init__(self):
@@ -23,7 +23,7 @@ def __init__(self):
         self.fc1 = nn.Linear(28 * 28, 128)
         self.fc2 = nn.Linear(128, 64)
         self.fc3 = nn.Linear(64, 10)
-    
+
     def forward(self, x):
         x = x.view(-1, 28 * 28)
         x = nn.functional.relu(self.fc1(x))
@@ -31,12 +31,17 @@ def forward(self, x):
         x = self.fc3(x)
         return nn.functional.log_softmax(x, dim=1)
 
+
 # Step 3: Train the Model
 model = Net()
 optimizer = optim.SGD(model.parameters(), lr=0.01)
 criterion = nn.NLLLoss()
 
-logging.basicConfig(filename='training_errors.log', level=logging.ERROR, format='%(asctime)s %(levelname)s %(message)s')
+logging.basicConfig(
+    filename="training_errors.log",
+    level=logging.ERROR,
+    format="%(asctime)s %(levelname)s %(message)s",
+)
 
 # Training loop
 epochs = 3
@@ -49,6 +54,8 @@ def forward(self, x):
             loss.backward()
             optimizer.step()
         except Exception as e:
-            logging.error('Error at epoch %s, batch %s: %s', epoch, i, str(e), exc_info=True)
+            logging.error(
+                "Error at epoch %s, batch %s: %s", epoch, i, str(e), exc_info=True
+            )
 
-torch.save(model.state_dict(), "mnist_model.pth")
\ No newline at end of file
+torch.save(model.state_dict(), "mnist_model.pth")
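
For reference (not part of the patch): the patched script ends by saving its weights to "mnist_model.pth". Below is a minimal sketch of reloading that checkpoint for inference; it reuses the Net definition from the diff above, assumes the checkpoint file is in the working directory, and assumes the usual super().__init__() call, since that line of __init__ falls outside the hunks shown.

import torch
import torch.nn as nn


# Same Net as in the patched src/main.py; the super().__init__() line is an
# assumption, as it is not visible in the hunks above.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)


model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))  # checkpoint path assumed
model.eval()  # switch to inference mode

with torch.no_grad():
    sample = torch.zeros(1, 1, 28, 28)  # one blank MNIST-sized image
    log_probs = model(sample)           # forward() returns log-probabilities
    print(log_probs.argmax(dim=1))      # index of the predicted digit

Importing main.py instead of copying Net would also work, but it would re-run the whole training loop, because the script keeps everything at module level with no __main__ guard.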