Sweep: refactor the training loop from a script to a class in main.py #6

Open
4 tasks done
wwzeng1 opened this issue Oct 12, 2023 · 1 comment · May be fixed by #163

wwzeng1 commented Oct 12, 2023

Details

No response

Checklist
  • Modify src/main.py 35f37cd
  • Ran sandbox for src/main.py. ✓ 52cb808
  • Modify src/api.py c1333ec
  • Ran sandbox for src/api.py. ✓ b31fc72

sweep-nightly bot commented Oct 13, 2023

Here's the PR! #163.

💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: c863ad18ba)

Sandbox Execution ✓

Here are the sandbox execution logs prior to making any changes:

Sandbox logs for 7284908
trunk fmt src/main.py || exit 0 1/2 ✓
  NOTICES  
 src/main.py  ignored empty file [black]
 Hint: use --force to check ignored files
Checked 1 file
✔ No issues
trunk check --fix --filter=-ruff --print-failures src/main.py 2/2 ✓
  NOTICES  
 src/main.py  ignored empty file [black]
 Hint: use --force to check ignored files
Checked 1 file
✔ No issues
Run trunk upgrade to upgrade 1 linter

Sandbox passed on the latest main, so sandbox checks will be enabled for this issue.


Step 1: 🔎 Searching

I found the following snippets in your repository. I will now analyze these snippets and come up with a plan.

Some code snippets I looked at. If a file is missing from here, you can mention its path in the ticket description.

evals/src/main.py

Lines 1 to 46 in 7284908

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Training loop
epochs = 3
for epoch in range(epochs):
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

evals/src/api.py

Lines 1 to 27 in 7284908

from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net  # Importing Net class from main.py

# Load the model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

app = FastAPI()

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    image = Image.open(file.file).convert("L")
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)

evals/requirements.txt

Lines 1 to 15 in 7284908

annotated-types==0.6.0
anyio==3.7.1
astroid==3.0.1
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.7
dill==0.3.7
exceptiongroup==1.1.3
fastapi==0.104.0
filelock==3.9.0
fsspec==2023.4.0
h11==0.14.0
idna==3.4
iniconfig==2.0.0
isort==5.12.0


Step 2: ⌨️ Coding

Modify src/main.py with contents:
• Create a new class named `Trainer` at the top of the file.
• Move the model, optimizer, and loss function initialization from the global scope into the `__init__` method of the `Trainer` class. The `__init__` method should take the learning rate and the path to save the model as parameters, and use them to initialize the optimizer and save the model respectively.
• Move the training loop from the global scope into a new method of the `Trainer` class named `train`. This method should take the number of epochs as a parameter.
• Move the model saving code from the global scope into a new method of the `Trainer` class named `save_model`.
• Remove the global code for model initialization, training, and saving, as it is now encapsulated in the `Trainer` class.
• At the end of the file, create an instance of the `Trainer` class, and call its `train` and `save_model` methods to train and save the model.
  • Ran sandbox for src/main.py. ✓ 52cb808
Sandbox logs for https://github.com/sweepai/evals/commit/52cb808b1b8bc91787d586dd565b782319ec01c4
trunk fmt src/main.py || exit 0 1/2 ✓
 ✔ Formatted src/main.py
Re-checking autofixed files...


Checked 1 file
✔ No issues
Run trunk upgrade to upgrade 1 linter
trunk check --fix --filter=-ruff --print-failures src/main.py 2/2 ✓
Checked 1 file
✔ No issues
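
For illustration only, the plan for src/main.py could look roughly like the sketch below. It is not the code from PR #163. The `Net` class and the hyperparameters (SGD, lr=0.01, NLLLoss, 3 epochs, batch size 64) come from the snippet above; everything else is an assumption: passing the data loader into `train`, returning the last batch loss, the default path "mnist_model.pth" (borrowed from api.py), and wrapping the script part in a hypothetical `main()` function so that importing the module does not start a training run.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):
    """Same three-layer MLP as in the original main.py."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return nn.functional.log_softmax(self.fc3(x), dim=1)


class Trainer:
    """Holds the model, optimizer, and loss that used to be module globals."""

    def __init__(self, lr=0.01, model_path="mnist_model.pth"):
        self.model = Net()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        self.criterion = nn.NLLLoss()
        self.model_path = model_path

    def train(self, loader, epochs=3):
        # Assumption: the loader is passed in rather than read from a global.
        self.model.train()
        last_loss = None
        for _ in range(epochs):
            for images, labels in loader:
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(images), labels)
                loss.backward()
                self.optimizer.step()
                last_loss = loss.item()
        return last_loss  # loss of the final batch, or None if the loader was empty

    def save_model(self):
        torch.save(self.model.state_dict(), self.model_path)


def main():
    # Hypothetical entry point; downloads MNIST into the current directory.
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
    trainer = Trainer(lr=0.01, model_path="mnist_model.pth")
    trainer.train(DataLoader(trainset, batch_size=64, shuffle=True), epochs=3)
    trainer.save_model()
```

Calling `main()` from an `if __name__ == "__main__":` guard would keep the script behaviour while letting src/api.py import `Net` or `Trainer` without retraining. The original module-level loop runs on every import, including the `from main import Net` line in api.py.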
Modify src/api.py with contents:
• Import the `Trainer` class from main.py at the top of the file.
• Replace the global model loading code with the creation of a `Trainer` instance and a call to its `load_model` method.
• Replace the global model usage in the `predict` function with a call to a method of the `Trainer` instance that returns the model.
  • Ran sandbox for src/api.py. ✓ b31fc72
Sandbox logs for https://github.com/sweepai/evals/commit/b31fc724747dc0cd6fdae265be435aa36b06ab4a
trunk fmt src/api.py || exit 0 1/2 ✓
 ✔ Formatted src/api.py
Re-checking autofixed files...

 ✔ Formatted src/api.py
Re-checking autofixed files...


Checked 1 file
✔ No issues
Run trunk upgrade to upgrade 1 linter
trunk check --fix --filter=-ruff --print-failures src/api.py 2/2 ✓
Checked 1 file
✔ No issues
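
The src/api.py plan above relies on a `load_model` method and a method that returns the model, neither of which appears in the plan for src/main.py. A minimal sketch of what those two methods might look like follows. It is not the code from PR #163: the method name `get_model`, the `predict_digit` helper, and the `map_location="cpu"` argument are assumptions made for illustration, and `Net` is repeated here only so the block runs on its own.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same architecture as main.py, so the saved state dict keys match."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return nn.functional.log_softmax(self.fc3(x), dim=1)


class Trainer:
    """Only the parts api.py would need; training methods are omitted."""

    def __init__(self, model_path="mnist_model.pth"):
        self.model = Net()
        self.model_path = model_path

    def load_model(self):
        # Load saved weights and switch to inference mode.
        state = torch.load(self.model_path, map_location="cpu")
        self.model.load_state_dict(state)
        self.model.eval()

    def get_model(self):
        # Hypothetical accessor; the plan does not name this method.
        return self.model


def predict_digit(trainer, image):
    """Return the predicted class for a preprocessed (1, 1, 28, 28) tensor."""
    with torch.no_grad():
        output = trainer.get_model()(image)
    return int(output.argmax(dim=1).item())
```

In the endpoint, the existing preprocessing (`Image.open(...).convert("L")`, `transform`, `unsqueeze(0)`) would stay as it is, and only the forward pass would go through the `Trainer` instance, for example `predict_digit(trainer, image)`.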

Step 3: 🔁 Code Review

I have finished reviewing the code for completeness. I did not find errors for sweep/refactor-training-loop_27.


💡 To recreate the pull request, edit the issue title or description. To tweak the pull request, leave a comment on the pull request.

kevinlu1248 changed the title from "Sweep: refactor the training loop from a script to a class" to "Sweep: refactor the training loop from a script to a class in main.py" on Nov 26, 2023