Skip to content

Commit

Permalink
feat: add tests using mocker to main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 23, 2023
1 parent c2e0e9a commit cc9abe2
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
from pytest_mock import MockerFixture
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from main import Net

def test_data_loading_and_preprocessing(mocker: MockerFixture):
mock_mnist = mocker.patch.object(datasets, 'MNIST')
mock_dataloader = mocker.patch.object(DataLoader, '__init__', return_value=None)

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

mock_mnist.assert_called_once_with('.', download=True, train=True, transform=transform)
mock_dataloader.assert_called_once_with(trainset, batch_size=64, shuffle=True)

assert isinstance(trainset, datasets.MNIST)
assert isinstance(trainloader, DataLoader)

def test_model_definition():
model = Net()

assert isinstance(model, Net)
assert isinstance(model.fc1, torch.nn.Linear)
assert isinstance(model.fc2, torch.nn.Linear)
assert isinstance(model.fc3, torch.nn.Linear)

input_data = torch.randn(64, 1, 28, 28)
output = model(input_data)

assert output.size() == (64, 10)
assert output.dtype == torch.float32

def test_forward_method(mocker: MockerFixture):
mock_relu = mocker.patch('torch.nn.functional.relu')
mock_log_softmax = mocker.patch('torch.nn.functional.log_softmax')

model = Net()
input_data = torch.randn(64, 1, 28, 28)
output = model(input_data)

mock_relu.assert_any_call(model.fc1(input_data.view(-1, 28 * 28)))
mock_relu.assert_any_call(model.fc2(mock_relu.return_value))
mock_log_softmax.assert_called_once_with(model.fc3(mock_relu.return_value), dim=1)

assert output.size() == (64, 10)
assert output.dtype == torch.float32

0 comments on commit cc9abe2

Please sign in to comment.