Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 23, 2023
1 parent b0bd256 commit 26476dc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


class Net(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -31,7 +30,7 @@ def train(self, epochs=3):
optimizer = optim.SGD(self.parameters(), lr=0.01)
criterion = nn.NLLLoss()

for epoch in range(epochs):
for _epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
output = self(images)
Expand All @@ -45,6 +44,7 @@ def save(self, path="mnist_model.pth"):
def load(self, path="mnist_model.pth"):
self.load_state_dict(torch.load(path))


# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
Expand All @@ -60,4 +60,4 @@ def load(self, path="mnist_model.pth"):
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 26476dc

Please sign in to comment.