Skip to content

Commit

Permalink
Fix _compute_fans when fan_in_axes is specified and len(shape) is 2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 600094371
  • Loading branch information
Haiku Contributor authored and copybara-github committed Jan 23, 2024
1 parent 0898b7b commit cb1f796
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion haiku/_src/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _compute_fans(shape, fan_in_axes=None):
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
elif len(shape) == 2 and fan_in_axes is None:
fan_in, fan_out = shape
else:
if fan_in_axes is not None:
Expand Down
14 changes: 8 additions & 6 deletions haiku/_src/initializers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,14 @@ def test_compute_fans(self):
self.assertEqual(fan_in_out2, (2, 2))
fan_in_out3 = initializers._compute_fans([3, 4])
self.assertEqual(fan_in_out3, (3, 4))
fan_in_out4 = initializers._compute_fans([1, 2, 3, 4])
self.assertEqual(fan_in_out4, (6, 8))
fan_in_out5 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0])
self.assertEqual(fan_in_out5, (3, 45))
fan_in_out6 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1])
self.assertEqual(fan_in_out6, (15, 28))
fan_in_out4 = initializers._compute_fans([3, 4], fan_in_axes=[1])
self.assertEqual(fan_in_out4, (4, 3))
fan_in_out5 = initializers._compute_fans([1, 2, 3, 4])
self.assertEqual(fan_in_out5, (6, 8))
fan_in_out6 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0])
self.assertEqual(fan_in_out6, (3, 45))
fan_in_out7 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1])
self.assertEqual(fan_in_out7, (15, 28))

@test_utils.transform_and_run
def test_orthogonal_invalid_shape(self):
Expand Down

0 comments on commit cb1f796

Please sign in to comment.