Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 4d62685 commit 76be9ad
Showing 1 changed file with 49 additions and 38 deletions.
87 changes: 49 additions & 38 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,54 @@
from torch.utils.data import DataLoader
import numpy as np

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
class MNISTTrainer:
def __init__(self):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
self.optimizer = None
self.criterion = nn.NLLLoss()
self.epochs = 3

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
def load_data(self):
trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
return trainloader

# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Training loop
epochs = 3
for epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)

def define_model(self):
model = self.Net()
self.optimizer = optim.SGD(model.parameters(), lr=0.01)
return model

def train_model(self, trainloader, model):
for epoch in range(self.epochs):
for images, labels in trainloader:
self.optimizer.zero_grad()
output = model(images)
loss = self.criterion(output, labels)
loss.backward()
self.optimizer.step()

def save_model(self, model):
torch.save(model.state_dict(), "mnist_model.pth")

trainer = MNISTTrainer()
trainloader = trainer.load_data()
model = trainer.define_model()
trainer.train_model(trainloader, model)
trainer.save_model(model)

0 comments on commit 76be9ad

Please sign in to comment.