Skip to content

Commit

Permalink
feat: Updated src/api.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent c7c2a7a commit 8c2e731
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""
This module defines a FastAPI application that serves a PyTorch model trained on the MNIST dataset.
It includes the necessary imports, model loading, image preprocessing, and prediction endpoint.
"""
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net # Importing Net class from main.py

# Load the model
# Load the trained PyTorch model from the saved state dictionary
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
# Define the transformations to be applied to the images for preprocessing
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
Expand All @@ -18,11 +22,25 @@
app = FastAPI()

@app.post("/predict/")
"""
This function takes an uploaded image file, preprocesses it, and makes a prediction using the loaded PyTorch model.

Parameters:
file: The uploaded image file.

Returns:
A dictionary with the key 'prediction' and the predicted digit as the value.
"""
async def predict(file: UploadFile = File(...)):
# Open the image file and convert it to grayscale
image = Image.open(file.file).convert("L")
# Apply the defined transformations to the image
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
# Add a batch dimension to the image tensor
image = image.unsqueeze(0)
# Make a prediction with the model without computing gradients
with torch.no_grad():
output = model(image)
# Get the digit with the highest prediction score
_, predicted = torch.max(output.data, 1)
return {"prediction": int(predicted[0])}

0 comments on commit 8c2e731

Please sign in to comment.