From 029e0cdbeebcd5d81a7d44b32a75e14448b26366 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:03:30 +0000 Subject: [PATCH 1/3] feat: Updated src/main.py --- src/main.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..8c7fe4e 100644 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,23 @@ trainset = datasets.MNIST('.', download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) +# Step 4: Define the TrainModel Class +class TrainModel: + def __init__(self, model, criterion, optimizer, dataloader): + self.model = model + self.criterion = criterion + self.optimizer = optimizer + self.dataloader = dataloader + + def train(self, epochs): + for epoch in range(epochs): + for images, labels in self.dataloader: + self.optimizer.zero_grad() + output = self.model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() + # Step 2: Define the PyTorch Model class Net(nn.Module): def __init__(self): @@ -35,14 +52,8 @@ def forward(self, x): optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() +# Create an instance of TrainModel and train +train_model = TrainModel(model, criterion, optimizer, trainloader) +train_model.train(3) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file From 00a2da713be124ff5f33e65ceac7495a689a90de Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:04:11 +0000 Subject: [PATCH 2/3] feat: Updated requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9679557..2f2271f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ certifi==2022.12.7 charset-normalizer==2.1.1 click==8.1.7 dill==0.3. -distutils exceptiongroup==1.1.3 fastapi==0.104.0 filelock==3.9.0 From 4030bd00716a6d3cd6f32d21d4d5ab7d361e62c7 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 09:05:26 +0000 Subject: [PATCH 3/3] feat: Updated src/api.py --- src/api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..74fee25 100644 --- a/src/api.py +++ b/src/api.py @@ -2,7 +2,7 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import Net, TrainModel # Importing Net and TrainModel classes from main.py # Load the model model = Net() @@ -14,6 +14,9 @@ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) +# Create an instance of TrainModel +train_model = TrainModel(model, criterion, optimizer, trainloader) +train_model.train(3) app = FastAPI() @@ -23,6 +26,6 @@ async def predict(file: UploadFile = File(...)): image = transform(image) image = image.unsqueeze(0) # Add batch dimension with torch.no_grad(): - output = model(image) + output = train_model.model(image) _, predicted = torch.max(output.data, 1) return {"prediction": int(predicted[0])}