Skip to content

Commit

Permalink
feat: add unit tests for main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 2, 2023
1 parent b73d090 commit 8fa5ea3
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import unittest
from unittest.mock import patch
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from main import Net, transform

class TestMain(unittest.TestCase):
def setUp(self):
self.model = Net()
self.transform = transform

def test_Net(self):
# Test the forward pass
input_tensor = torch.randn(1, 1, 28, 28)
output = self.model(input_tensor)
self.assertEqual(output.size(), (1, 10))

@patch('torchvision.datasets.MNIST')
@patch('torch.utils.data.DataLoader')
def test_data_loading(self, mock_dataloader, mock_dataset):
# Mock the MNIST dataset
mock_dataset.return_value = datasets.MNIST('.', download=True, train=True, transform=self.transform)
# Mock the DataLoader
mock_dataloader.return_value = DataLoader(mock_dataset, batch_size=64, shuffle=True)
# Assert that the DataLoader is called with the correct arguments
mock_dataloader.assert_called_with(mock_dataset, batch_size=64, shuffle=True)

if __name__ == '__main__':
unittest.main()

0 comments on commit 8fa5ea3

Please sign in to comment.