From 2384722bc6a9c42da868ce85cc5e2dcbc21f257d Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 16 Aug 2023 14:24:30 -0700 Subject: [PATCH] Support batched indices in PositionEmbed. This is useful to support prefilling caches for prompted decoding with batches containing prompts of different lengths. PiperOrigin-RevId: 557602024 --- t5x/decoding_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py index 685b44fe6..ad563648d 100644 --- a/t5x/decoding_test.py +++ b/t5x/decoding_test.py @@ -774,10 +774,10 @@ def test_cache_map_with_index(self): 'cached_key': jnp.ones([10, 12, 2]), 'cached_values': jnp.ones([4, 7, 2]), 'cache_index': jnp.ones([4, 5, 6]), - } + }, }, 'position_embedder': { - 'position_embedder_index': jnp.array(-1), + 'position_embedder_index': jnp.array([-1]), }, } @@ -787,9 +787,11 @@ def test_cache_map_with_index(self): 'layers_0': { 'cached_key': fn(jnp.ones([3, 6])), 'cached_values': fn(jnp.ones([3, 6])), - 'cache_index': fn(jnp.ones([ - 3, - ])), + 'cache_index': fn( + jnp.ones([ + 3, + ]) + ), }, 'layers_1': { 'relpos_bias': { @@ -804,10 +806,10 @@ def test_cache_map_with_index(self): 'cached_key': fn(jnp.ones([10, 12, 2])), 'cached_values': fn(jnp.ones([4, 7, 2])), 'cache_index': fn(jnp.ones([4, 5, 6])), - } + }, }, 'position_embedder': { - 'position_embedder_index': jnp.array(-1), + 'position_embedder_index': jnp.array([-1]), }, }