diff --git a/src/api.py b/src/api.py index db42381..88105e8 100644 --- a/src/api.py +++ b/src/api.py @@ -1,22 +1,23 @@ import torch from fastapi import FastAPI, File, UploadFile -from main import Net # Importing Net class from main.py -from main import Trainer # Importing Trainer class from main.py from PIL import Image from torchvision import transforms +from main import Net # Importing Net class from main.py +from main import Trainer # Importing Trainer class from main.py + # Load the model trainer = Trainer() -trainer.load_model("mnist_model.pth") # Assuming load_model method exists +trainer.load_model("mnist_model.pth") # Assuming load_model method exists # Transform used for preprocessing the image -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) app = FastAPI() + @app.post("/predict/") async def predict(file: UploadFile = File(...)): image = Image.open(file.file).convert("L")