diff --git a/src/test_main.py b/src/test_main.py
new file mode 100644
index 0000000..5363417
--- /dev/null
+++ b/src/test_main.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn.functional as F
+from pytest_mock import MockerFixture
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+from main import Net
+
+
+def test_data_loading_and_preprocessing(mocker: MockerFixture):
+    # Replace the MNIST constructor so no data is downloaded during the test.
+    mock_mnist = mocker.patch.object(datasets, 'MNIST')
+    # Give the fake dataset a length so a real DataLoader can sample from it.
+    mock_mnist.return_value.__len__.return_value = 60000
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,)),
+    ])
+
+    trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
+    trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
+
+    mock_mnist.assert_called_once_with('.', download=True, train=True, transform=transform)
+    assert trainset is mock_mnist.return_value
+    assert isinstance(trainloader, DataLoader)
+    assert trainloader.batch_size == 64
+
+
+def test_model_definition():
+    model = Net()
+
+    # The network is a three-layer fully connected classifier.
+    assert isinstance(model, Net)
+    assert isinstance(model.fc1, torch.nn.Linear)
+    assert isinstance(model.fc2, torch.nn.Linear)
+    assert isinstance(model.fc3, torch.nn.Linear)
+
+    # A batch of 64 MNIST-shaped images maps to 64 ten-class outputs.
+    input_data = torch.randn(64, 1, 28, 28)
+    output = model(input_data)
+
+    assert output.size() == (64, 10)
+    assert output.dtype == torch.float32
+
+
+def test_forward_method(mocker: MockerFixture):
+    # Spy on the functional ops so the real computation still runs
+    # while the calls made inside forward() are recorded.
+    spy_relu = mocker.spy(F, 'relu')
+    spy_log_softmax = mocker.spy(F, 'log_softmax')
+
+    model = Net()
+    input_data = torch.randn(64, 1, 28, 28)
+    output = model(input_data)
+
+    # Two hidden layers pass through ReLU, the output layer through log_softmax.
+    assert spy_relu.call_count == 2
+    assert spy_log_softmax.call_count == 1
+
+    assert output.size() == (64, 10)
+    assert output.dtype == torch.float32
+    # log_softmax over dim=1 means each row exponentiates and sums to one.
+    assert torch.allclose(output.exp().sum(dim=1), torch.ones(64), atol=1e-5)
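
For reference, these tests import Net from a main.py that is not part of this diff. Below is a minimal sketch of the kind of model the assertions assume; the hidden-layer sizes (128 and 64) are assumptions, while the flattened 784-unit input, the 10-class output, the two ReLU activations, and log_softmax over dim=1 are implied by the tests.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 28x28 grayscale images are flattened to 784 input features.
        self.fc1 = nn.Linear(28 * 28, 128)  # hidden size 128 is an assumption
        self.fc2 = nn.Linear(128, 64)       # hidden size 64 is an assumption
        self.fc3 = nn.Linear(64, 10)        # 10 MNIST classes

    def forward(self, x):
        x = x.view(-1, 28 * 28)             # flatten (N, 1, 28, 28) -> (N, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x), dim=1)

Because forward calls F.relu and F.log_softmax through the module attribute, the spies in test_forward_method can observe those calls while still running the real computation.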