Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 4d62685 commit 84d3eb4
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from PIL import Image
import torch
import torch.nn as nn
Expand All @@ -6,6 +7,10 @@
from torch.utils.data import DataLoader
import numpy as np

# Configure logging
logging.basicConfig(filename='training_errors.log', level=logging.ERROR,
format='%(asctime)s %(levelname)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
Expand Down Expand Up @@ -38,11 +43,14 @@ def forward(self, x):
# Training loop
epochs = 3
for epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
for batch_idx, (images, labels) in enumerate(trainloader):
try:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
except Exception as e:
logging.error(f'Error at epoch {epoch} batch {batch_idx}: {e}', exc_info=True)

torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 84d3eb4

Please sign in to comment.