Skip to content

Commit

Permalink
feat: add tests for main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 4d62685 commit eaef018
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import unittest
import torch
from main import Net

class TestNet(unittest.TestCase):
def test_net_initialization(self):
model = Net()
self.assertIsInstance(model.fc1, torch.nn.Linear)
self.assertIsInstance(model.fc2, torch.nn.Linear)
self.assertIsInstance(model.fc3, torch.nn.Linear)

def test_net_forward(self):
model = Net()
input_tensor = torch.randn(1, 1, 28, 28)
output = model(input_tensor)
self.assertIsInstance(output, torch.Tensor)
self.assertEqual(output.shape, (1, 10))

def test_training_loop(self):
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.NLLLoss()

# Small, fixed dataset
inputs = torch.randn(10, 1, 28, 28)
targets = torch.randint(0, 10, (10,))

# Training loop
initial_loss = float('inf')
for _ in range(10):
optimizer.zero_grad()
output = model(inputs)
loss = criterion(output, targets)
loss.backward()
optimizer.step()

self.assertLess(loss.item(), initial_loss)
initial_loss = loss.item()

if __name__ == '__main__':
unittest.main()

0 comments on commit eaef018

Please sign in to comment.