diff --git a/src/api.py b/src/api.py index c386220..112aae2 100644 --- a/src/api.py +++ b/src/api.py @@ -22,16 +22,16 @@ app = FastAPI() @app.post("/predict/") -""" -This function takes an uploaded image file, preprocesses it, and makes a prediction using the loaded PyTorch model. +async def predict(file: UploadFile = File(...)): + """ + This function takes an uploaded image file, preprocesses it, and makes a prediction using the loaded PyTorch model. -Parameters: -file: The uploaded image file. + Parameters: + file: The uploaded image file. -Returns: -A dictionary with the key 'prediction' and the predicted digit as the value. -""" -async def predict(file: UploadFile = File(...)): + Returns: + A dictionary with the key 'prediction' and the predicted digit as the value. + """ # Open the image file and convert it to grayscale image = Image.open(file.file).convert("L") # Apply the defined transformations to the image