Skip to content

Commit

Permalink
feat: Updated src/api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 29, 2023
1 parent 00a2da7 commit 4030bd0
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from PIL import Image
import torch
from torchvision import transforms
from main import Net # Importing Net class from main.py
from main import Net, TrainModel # Importing Net and TrainModel classes from main.py

# Load the model
model = Net()
Expand All @@ -14,6 +14,9 @@
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Create an instance of TrainModel
train_model = TrainModel(model, criterion, optimizer, trainloader)
train_model.train(3)

app = FastAPI()

Expand All @@ -23,6 +26,6 @@ async def predict(file: UploadFile = File(...)):
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = model(image)
output = train_model.model(image)
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}

0 comments on commit 4030bd0

Please sign in to comment.