diff --git a/haiku/_src/initializers.py b/haiku/_src/initializers.py index ab6bb49d8..1e3a32ac9 100644 --- a/haiku/_src/initializers.py +++ b/haiku/_src/initializers.py @@ -41,7 +41,7 @@ def _compute_fans(shape, fan_in_axes=None): fan_in = fan_out = 1 elif len(shape) == 1: fan_in = fan_out = shape[0] - elif len(shape) == 2: + elif len(shape) == 2 and fan_in_axes is None: fan_in, fan_out = shape else: if fan_in_axes is not None: diff --git a/haiku/_src/initializers_test.py b/haiku/_src/initializers_test.py index 486b7b06f..7ebbce8d2 100644 --- a/haiku/_src/initializers_test.py +++ b/haiku/_src/initializers_test.py @@ -103,12 +103,14 @@ def test_compute_fans(self): self.assertEqual(fan_in_out2, (2, 2)) fan_in_out3 = initializers._compute_fans([3, 4]) self.assertEqual(fan_in_out3, (3, 4)) - fan_in_out4 = initializers._compute_fans([1, 2, 3, 4]) - self.assertEqual(fan_in_out4, (6, 8)) - fan_in_out5 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0]) - self.assertEqual(fan_in_out5, (3, 45)) - fan_in_out6 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1]) - self.assertEqual(fan_in_out6, (15, 28)) + fan_in_out4 = initializers._compute_fans([3, 4], fan_in_axes=[1]) + self.assertEqual(fan_in_out4, (4, 3)) + fan_in_out5 = initializers._compute_fans([1, 2, 3, 4]) + self.assertEqual(fan_in_out5, (6, 8)) + fan_in_out6 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0]) + self.assertEqual(fan_in_out6, (3, 45)) + fan_in_out7 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1]) + self.assertEqual(fan_in_out7, (15, 28)) @test_utils.transform_and_run def test_orthogonal_invalid_shape(self):