Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and docstrings to main.py and api.py #108

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
from fastapi import FastAPI, UploadFile, File
from PIL import Image
"""
This script creates a FastAPI application for making predictions using the PyTorch model defined in main.py.
"""
# torch is the main PyTorch library
import torch

# FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.6+ based on standard Python type hints.
from fastapi import FastAPI, File, UploadFile

# PIL is used for opening, manipulating, and saving many different image file formats
from PIL import Image

# torchvision.transforms provides classes for transforming images
from torchvision import transforms
from main import Net # Importing Net class from main.py

# Load the model
# Importing Net class from main.py
from main import Net

# Load the trained PyTorch model from a file
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# Define a sequence of preprocessing steps to be applied to the input images
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Create an instance of the FastAPI application
app = FastAPI()


# Define a route handler for making predictions using the model
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
"""
This function is a route handler for making predictions using the model.
It takes an image file as input and returns a prediction.

Parameters:
file (UploadFile): The image file to predict.

Returns:
dict: A dictionary with the prediction.
"""
image = Image.open(file.file).convert("L")
image = transform(image)
image = image.unsqueeze(0) # Add batch dimension
Expand Down
29 changes: 27 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,54 @@
from PIL import Image
"""
This script loads and preprocesses the MNIST dataset, defines a PyTorch model, and trains the model.
"""
# PIL is used for opening, manipulating, and saving many different image file formats
import PIL.Image
# torch is the main PyTorch library
import torch
# torch.nn provides classes for building neural networks
import torch.nn as nn
# torch.optim provides classes for implementing various optimization algorithms
import torch.optim as optim
from torchvision import datasets, transforms
# torchvision.datasets provides classes for loading and using various popular datasets
from torchvision import datasets
# torchvision.transforms provides classes for transforming images
from torchvision import transforms
# torch.utils.data provides classes for loading data in parallel
from torch.utils.data import DataLoader
# numpy is used for numerical operations
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# This is a sequence of preprocessing steps to be applied to the images in the MNIST dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

# This represents the MNIST training dataset
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
# This is a data loader for batching and shuffling the training data
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
# This class defines the architecture of the PyTorch model
class Net(nn.Module):
"""
This class defines the architecture of the PyTorch model.
"""
def __init__(self):
"""
This method initializes the model.
"""
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
"""
This method defines the forward pass of the model.
"""
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
Expand Down
Loading