From ba925fbc0468242f4cb70722de8620db059077c8 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 19:31:19 +0000 Subject: [PATCH] feat: Updated src/api.py --- src/api.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..376373b 100644 --- a/src/api.py +++ b/src/api.py @@ -2,12 +2,13 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py - -# Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +from main import MNISTTrainer # Importing MNISTTrainer class from main.py +trainer = MNISTTrainer() +trainloader = trainer.load_data() +model = trainer.define_model() +trainer.train_model(trainloader, model) +trainer.save_model(model) +model.eval() # Set the model to evaluation mode # Transform used for preprocessing the image transform = transforms.Compose([