diff --git a/src/api.py b/src/api.py index 36c257a..258bdff 100644 --- a/src/api.py +++ b/src/api.py @@ -2,10 +2,11 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import Trainer # Importing Trainer class from main.py -# Load the model -model = Net() +# Create a Trainer instance and load the model +trainer = Trainer(None, None, None, None) +model = trainer.model model.load_state_dict(torch.load("mnist_model.pth")) model.eval() diff --git a/src/main.py b/src/main.py index 243a31e..315a685 100644 --- a/src/main.py +++ b/src/main.py @@ -30,19 +30,49 @@ def forward(self, x): x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) +# Define the Trainer class +class Trainer: + def __init__(self, model, optimizer, criterion, dataloader): + self.model = model + self.optimizer = optimizer + self.criterion = criterion + self.dataloader = dataloader + + def train_epoch(self): + for images, labels in self.dataloader: + self.optimizer.zero_grad() + output = self.model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() + + def evaluate(self): + total_loss = 0 + total_correct = 0 + with torch.no_grad(): + for images, labels in self.dataloader: + output = self.model(images) + loss = self.criterion(output, labels) + total_loss += loss.item() + _, predicted = torch.max(output.data, 1) + total_correct += (predicted == labels).sum().item() + average_loss = total_loss / len(self.dataloader) + accuracy = total_correct / len(self.dataloader.dataset) + return average_loss, accuracy + + def train(self, epochs): + for epoch in range(epochs): + self.train_epoch() + average_loss, accuracy = self.evaluate() + print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss}, Accuracy: {accuracy}') + # Step 3: Train the Model model = Net() optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() +# Create a Trainer instance and train the model +trainer = Trainer(model, optimizer, criterion, trainloader) +trainer.train(3) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file