Skip to content

Commit

Permalink
fix test bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 14, 2024
1 parent 6967505 commit 06c835b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions brainstate/nn/_poolings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_flatten1(self):
(10, 20, 30),
]:
arr = bst.random.rand(*size)
f = nn.Flatten(start_dim=0)
f = nn.Flatten(start_axis=0)
out = f(arr)
self.assertTrue(out.shape == (np.prod(size),))

Expand All @@ -29,21 +29,21 @@ def test_flatten2(self):
(10, 20, 30),
]:
arr = bst.random.rand(*size)
f = nn.Flatten(start_dim=1)
f = nn.Flatten(start_axis=1)
out = f(arr)
self.assertTrue(out.shape == (size[0], np.prod(size[1:])))

def test_flatten3(self):
size = (16, 32, 32, 8)
arr = bst.random.rand(*size)
f = nn.Flatten(start_dim=0, in_size=(32, 8))
f = nn.Flatten(start_axis=0, in_size=(32, 8))
out = f(arr)
self.assertTrue(out.shape == (16, 32, 32 * 8))

def test_flatten4(self):
size = (16, 32, 32, 8)
arr = bst.random.rand(*size)
f = nn.Flatten(start_dim=1, in_size=(32, 32, 8))
f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
out = f(arr)
self.assertTrue(out.shape == (16, 32, 32 * 8))

Expand Down

0 comments on commit 06c835b

Please sign in to comment.