From 2b655af2e882e5d3b82bda4b71c19a13fc8b2d24 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:39:27 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..1df5d6f 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +import logging from PIL import Image import torch import torch.nn as nn @@ -6,6 +7,8 @@ from torch.utils.data import DataLoader import numpy as np +logging.basicConfig(filename='training.log', level=logging.ERROR) + # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ transforms.ToTensor(), @@ -39,10 +42,13 @@ def forward(self, x): epochs = 3 for epoch in range(epochs): for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + try: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + except Exception as e: + logging.exception("Exception occurred during training: %s", e) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file