Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 875c5c1 commit a499f74
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,47 @@
from PIL import Image
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import logging
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)


# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

logging.basicConfig(filename='training_errors.log', level=logging.ERROR, format='%(asctime)s %(levelname)s %(message)s')
logging.basicConfig(
filename="training_errors.log",
level=logging.ERROR,
format="%(asctime)s %(levelname)s %(message)s",
)

# Training loop
epochs = 3
Expand All @@ -49,6 +54,8 @@ def forward(self, x):
loss.backward()
optimizer.step()
except Exception as e:
logging.error('Error at epoch %s, batch %s: %s', epoch, i, str(e), exc_info=True)
logging.error(
"Error at epoch %s, batch %s: %s", epoch, i, str(e), exc_info=True
)

torch.save(model.state_dict(), "mnist_model.pth")
torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit a499f74

Please sign in to comment.