From 64ca4c517b6ae3ec173cff6c1206896f0d47b853 Mon Sep 17 00:00:00 2001 From: Haiku Contributor Date: Sat, 20 Jan 2024 08:22:23 -0800 Subject: [PATCH] Fix _compute_fans when fan_in_axes is specified and len(shape) is 2. PiperOrigin-RevId: 600094371 --- haiku/_src/initializers.py | 2 +- haiku/_src/initializers_test.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/haiku/_src/initializers.py b/haiku/_src/initializers.py index ab6bb49d8..1e3a32ac9 100644 --- a/haiku/_src/initializers.py +++ b/haiku/_src/initializers.py @@ -41,7 +41,7 @@ def _compute_fans(shape, fan_in_axes=None): fan_in = fan_out = 1 elif len(shape) == 1: fan_in = fan_out = shape[0] - elif len(shape) == 2: + elif len(shape) == 2 and fan_in_axes is None: fan_in, fan_out = shape else: if fan_in_axes is not None: diff --git a/haiku/_src/initializers_test.py b/haiku/_src/initializers_test.py index 486b7b06f..7ebbce8d2 100644 --- a/haiku/_src/initializers_test.py +++ b/haiku/_src/initializers_test.py @@ -103,12 +103,14 @@ def test_compute_fans(self): self.assertEqual(fan_in_out2, (2, 2)) fan_in_out3 = initializers._compute_fans([3, 4]) self.assertEqual(fan_in_out3, (3, 4)) - fan_in_out4 = initializers._compute_fans([1, 2, 3, 4]) - self.assertEqual(fan_in_out4, (6, 8)) - fan_in_out5 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0]) - self.assertEqual(fan_in_out5, (3, 45)) - fan_in_out6 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1]) - self.assertEqual(fan_in_out6, (15, 28)) + fan_in_out4 = initializers._compute_fans([3, 4], fan_in_axes=[1]) + self.assertEqual(fan_in_out4, (4, 3)) + fan_in_out5 = initializers._compute_fans([1, 2, 3, 4]) + self.assertEqual(fan_in_out5, (6, 8)) + fan_in_out6 = initializers._compute_fans([3, 5, 9], fan_in_axes=[0]) + self.assertEqual(fan_in_out6, (3, 45)) + fan_in_out7 = initializers._compute_fans([3, 5, 7, 4], fan_in_axes=[0, 1]) + self.assertEqual(fan_in_out7, (15, 28)) @test_utils.transform_and_run def test_orthogonal_invalid_shape(self):