Skip to content

Commit

Permalink
Support batched indices in PositionEmbed. This is useful to support p…
Browse files Browse the repository at this point in the history
…refilling caches for prompted decoding with batches containing prompts of different lengths.

PiperOrigin-RevId: 557602024
  • Loading branch information
adarob authored and t5-copybara committed Aug 17, 2023
1 parent 94a4f05 commit 2384722
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions t5x/decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,10 +774,10 @@ def test_cache_map_with_index(self):
'cached_key': jnp.ones([10, 12, 2]),
'cached_values': jnp.ones([4, 7, 2]),
'cache_index': jnp.ones([4, 5, 6]),
}
},
},
'position_embedder': {
'position_embedder_index': jnp.array(-1),
'position_embedder_index': jnp.array([-1]),
},
}

Expand All @@ -787,9 +787,11 @@ def test_cache_map_with_index(self):
'layers_0': {
'cached_key': fn(jnp.ones([3, 6])),
'cached_values': fn(jnp.ones([3, 6])),
'cache_index': fn(jnp.ones([
3,
])),
'cache_index': fn(
jnp.ones([
3,
])
),
},
'layers_1': {
'relpos_bias': {
Expand All @@ -804,10 +806,10 @@ def test_cache_map_with_index(self):
'cached_key': fn(jnp.ones([10, 12, 2])),
'cached_values': fn(jnp.ones([4, 7, 2])),
'cache_index': fn(jnp.ones([4, 5, 6])),
}
},
},
'position_embedder': {
'position_embedder_index': jnp.array(-1),
'position_embedder_index': jnp.array([-1]),
},
}

Expand Down

0 comments on commit 2384722

Please sign in to comment.