From 340b634a0b5aeccd94a1e8f246bbbc7a8be02ccb Mon Sep 17 00:00:00 2001
From: Afroz Mohiuddin
Date: Tue, 22 Aug 2023 16:27:13 -0700
Subject: [PATCH] Make use of prefilled cache in beam_search when
 `predict_batch_with_aux.prompt_with_targets` is set to True

Without this change, targets that are actually prompts are still decoded
AR-style in order to build the cache, and the logits are thrown away.

With this change, setting
`EncoderDecoder.predict_batch_with_aux.prompt_with_targets` to `True` instead
skips over the given prompts altogether, speeding up the decode.

PiperOrigin-RevId: 559256044
---
 t5x/decoding.py      | 253 ++++++++++++++++++++++++++++++++++++++-----
 t5x/decoding_test.py | 157 +++++++++++++++++++++++++++
 2 files changed, 383 insertions(+), 27 deletions(-)
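A minimal sketch of the intended call pattern (editor's illustration, not part
of the patch): `model_cache` and `tokens_to_logits` are hypothetical
stand-ins, and the cache must already be prefilled for the prompt tokens by
model-specific code before `beam_search` is called.

    import jax.numpy as jnp
    from t5x import decoding

    # Prompts are left-padded with the dummy BOS = 0; `initial_index` holds
    # the per-example prompt lengths (counting the leading 0).
    inputs = jnp.array([[0, 7, 0, 0, 0],
                        [0, 4, 5, 0, 0]], dtype=jnp.int32)
    initial_index = jnp.array([1, 2], dtype=jnp.int32)
    seqs, scores = decoding.beam_search(
        inputs,
        model_cache,       # hypothetical: prefilled up to initial_index
        tokens_to_logits,  # hypothetical: the model's decode-step callable
        eos_id=1,
        num_decodes=4,
        initial_index=initial_index,
    )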
diff --git a/t5x/decoding.py b/t5x/decoding.py
index 63e2a075d..2c5e1f929 100644
--- a/t5x/decoding.py
+++ b/t5x/decoding.py
@@ -968,23 +968,43 @@ class BeamState:
   finished_flags: jax.Array  # bool: [batch_size, beam_size]
   # The current state of the autoregressive decoding caches.
   cache: PyTree  # Any pytree of arrays, e.g. flax attention Cache object
+  # Optional array of initial indices from which decoding starts; either all
+  # zeros if there is no prompt, or None.
+  initial_index: jax.Array | None
 
 
-def beam_init(batch_size: int,
-              beam_size: int,
-              max_decode_len: int,
-              cache: Mapping[str, jnp.ndarray],
-              offset: int = 0) -> BeamState:
+def beam_init(
+    batch_size: int,
+    beam_size: int,
+    max_decode_len: int,
+    cache: Mapping[str, jnp.ndarray],
+    offset: int = 0,
+    live_seqs: Optional[jnp.ndarray] = None,
+    initial_index: Optional[jnp.ndarray] = None,
+) -> BeamState:
   """Initializes the beam search state data structure."""
   cur_index0 = jnp.array(0)
   live_logprobs0 = jnp.tile(
       jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
   finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
-  live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  # If we prefill any part of the prompt, then the initial live sequences are
+  # provided. In practice these will be the last token of the prompt, or BOS
+  # if the prompt (in the batch) is empty.
+  live_seqs0 = (
+      live_seqs
+      if live_seqs is not None
+      else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  )
   finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
   finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
   # add beam dimension to attention cache pytree elements
-  beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache)
+  # We will have to expand the cache_index if we're given an initial prompt
+  # that we prefill.
+  beam_cache0 = cache_map(
+      lambda x: add_beam_dim(x, beam_size, offset),
+      cache,
+      apply_to_index=live_seqs is not None,
+  )
   return BeamState(
       cur_index=cur_index0,
       live_logprobs=live_logprobs0,
@@ -992,23 +1012,87 @@
       live_seqs=live_seqs0,
       finished_seqs=finished_seqs0,
       finished_flags=finished_flags0,
-      cache=beam_cache0)
+      cache=beam_cache0,
+      initial_index=initial_index,
+  )
 
 
-# Beam search routine:
+def _right_align_prompts(prompts):
+  """Right align the prompts."""
+
+  # Implementation note:
+  #
+  # A very short way to do this right-aligning would be to vmap a jnp.roll by
+  # the amount of padding in each example, i.e. max_len - prompt_max_index
+  # (+-1); however, this is slow.
+  #
+  # A faster way, courtesy of Jeremiah Willcock, is to shift rows by bitmasking
+  # the gap and iterating for 1, 2, 4, ... log2(len) bitmasks.
+  #
+  # This gives a ~3x speedup over vmapping a roll.
+
+  max_len = prompts.shape[1]
+  nbits = np.ceil(np.log2(max_len)).astype(np.int32)
+  indices = jnp.arange(max_len)
+  prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
+  shifts = max_len - prompt_max_index - 1
+  for i in range(0, nbits + 1):
+    bitmask = 2**i
+    prompts = jnp.where(
+        jnp.expand_dims(shifts & bitmask, 1),
+        jnp.pad(prompts, ((0, 0), (bitmask, 0)))[:, :-bitmask],
+        prompts,
+    )
+  return prompts
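To see why the bitmask loop right-aligns a row, consider a single prompt with
shift 4 = 0b100; only the mask-4 iteration fires (editor's sketch, consistent
with `test_align_prompt` below):

    import jax.numpy as jnp
    from t5x import decoding

    prompts = jnp.array([[0, 1, 2, 0, 0, 0, 0]], dtype=jnp.int32)
    # prompt_max_index = 2, so shifts = 7 - 2 - 1 = 4 = 0b100:
    #   mask 1: 4 & 1 == 0 -> row unchanged
    #   mask 2: 4 & 2 == 0 -> row unchanged
    #   mask 4: 4 & 4 != 0 -> pad 4 zeros on the left, drop 4 on the right
    print(decoding._right_align_prompts(prompts))  # [[0 0 0 0 0 1 2]]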
+
+
+def _left_align_prompts(prompts):
+  """Left align the prompts."""
+  # See implementation notes in `_right_align_prompts`.
+
+  max_len = prompts.shape[1]
+  # [0, 1, 2, ... L - 1]
+  indices = jnp.arange(max_len)
+  # Indices of non-padding positions, 1-based, since `indices` is 0-based.
+  non_padding_positions = (prompts != 0) * (indices[None, :] + 1)
+  # Replace all padding with `max_len + 1`.
+  m = jnp.where(non_padding_positions, non_padding_positions, max_len + 1)
+  # Index of the first prompt token.
+  shifts = jnp.argmin(m, axis=1)
+  temp = prompts
+  nbits = np.ceil(np.log2(max_len)).astype(np.int32)
+  for i in range(0, nbits + 1):
+    bitmask = 2**i
+    temp = jnp.where(
+        jnp.expand_dims(shifts & bitmask, 1),
+        jnp.pad(temp, ((0, 0), (0, bitmask)))[:, bitmask:],
+        temp,
+    )
+  return temp
 
 
-def beam_search(inputs: jnp.ndarray,
-                cache: Mapping[str, jnp.ndarray],
-                tokens_to_logits: Callable[[DecodingState],
-                                           Tuple[jnp.ndarray,
-                                                 Mapping[str, jnp.ndarray]]],
-                eos_id: int,
-                num_decodes: int = 4,
-                alpha: float = 0.6,
-                max_decode_len: Optional[int] = None,
-                decode_rng: Optional[jnp.ndarray] = None,
-                cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
+def _pick_last_prompt_token(prompts):
+  # prompts: i32[batch, length]
+  prompt_lengths = jnp.sum(prompts != 0, axis=1)
+  # return value: i32[batch,]
+  return prompts[jnp.arange(prompts.shape[0]), prompt_lengths]
+
+
+# Beam search routine:
+def beam_search(
+    inputs: jnp.ndarray,
+    cache: Mapping[str, jnp.ndarray],
+    tokens_to_logits: Callable[
+        [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]
+    ],
+    eos_id: int,
+    num_decodes: int = 4,
+    alpha: float = 0.6,
+    max_decode_len: Optional[int] = None,
+    decode_rng: Optional[jnp.ndarray] = None,
+    cache_offset: int = 0,
+    initial_index: Optional[jnp.ndarray] = None,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Beam search for transformer machine translation.
 
   If `inputs` has non-zero entries, those values are not modified, i.e.,
@@ -1028,6 +1112,11 @@
       None, it uses `inputs.shape[1]` as `max_decode_len`.
     decode_rng: Unused decoder RNG seed.
     cache_offset: axis offset for cache, arising from scanned layers.
+    initial_index: Optional[jnp.ndarray], the index from which to start
+      decoding autoregressively if set. If unset, the prefix is still
+      teacher-forced, but token by token (so it will be slow). When set, this
+      also assumes that the cache is appropriately populated. Since inputs
+      are padded on the left with BOS = 0, these are also the prompt lengths.
 
   Returns:
     Tuple of:
@@ -1046,15 +1135,55 @@
   # We start with a dummy token in the beginning so extend the maximum length.
   max_decode_len += 1
 
+  right_aligned_input = None
+  live_seqs = None
+  if initial_index is not None:
+    # The inputs, but "right aligned" so as to end with the last
+    # prompt token.
+    # [batch_size, length]
+    right_aligned_input = _right_align_prompts(inputs)
+    # `inputs` is now just the last token of the prompt, right-padded to the
+    # same length as before.
+    length = inputs.shape[1]
+    inputs = jnp.pad(
+        right_aligned_input[:, -1][:, None],
+        ((0, 0), (0, length - 1)),
+        constant_values=0,
+    )
+
+    # Sized [batch, max_decode_len].
+    live_seqs = jnp.pad(
+        right_aligned_input[:, -1][:, None],
+        ((0, 0), (0, max_decode_len - 1)),
+        constant_values=0,
+    )
+    live_seqs = jnp.expand_dims(live_seqs, axis=1)
+    live_seqs = jnp.broadcast_to(
+        live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1])
+    )
+  else:
+    initial_index = jnp.zeros((batch_size,), dtype=jnp.int32)
+
   # initialize beam search state
-  beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
-                                     cache, cache_offset)
+  beam_search_init_state = beam_init(
+      batch_size,
+      beam_size,
+      max_decode_len,
+      cache,
+      cache_offset,
+      live_seqs=live_seqs,
+      initial_index=initial_index,
+  )
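To make the prefill bookkeeping concrete, an editor's sketch (not part of the
patch) of what the block above produces for a single two-token prompt:

    import jax.numpy as jnp
    from t5x import decoding

    inputs = jnp.array([[0, 4, 5, 0, 0]], dtype=jnp.int32)  # prompt length 2
    right_aligned = decoding._right_align_prompts(inputs)
    print(right_aligned)         # [[0 0 0 4 5]]
    print(right_aligned[:, -1])  # [5] -- decoding resumes from this token
    # `inputs` becomes [[5, 0, 0, 0, 0]] and every beam's live_seqs row
    # starts as [5, 0, 0, ...], so position 1 onwards is immediately
    # "out of prompt" in the loop body below.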
 
   def beam_search_loop_cond_fn(state: BeamState) -> bool:
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
+
+    # Since we might be starting at different points in the prompts, use
+    # the minimum prompt length to stop conservatively.
+    cur_index = state.cur_index + jnp.min(state.initial_index)
     # Because we mutate the "i+1" position, we stop one token before the end.
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = cur_index < max_decode_len - 1
 
     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
@@ -1134,12 +1263,20 @@
     topk_ids = topk_indices % vocab_size
     # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
     # PADs this is a no-op.
+    #
+    # Also note that when `initial_index` is set, we've already set up the
+    # inputs so that from position 1 onwards (i.e. state.cur_index + 1 >= 1)
+    # the tokens are 0 and we'll immediately be "out of prompt".
+    # --> [batch_size, 1]
     next_input_token = jnp.expand_dims(
         inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
+    # --> [batch_size, 1]
     out_of_prompt = (next_input_token == 0)
 
     # When forcing prompts, update log probabilities to `0` for the top of the
     # beam and -INF for the rest, effectively keeping only one beam alive.
+    # This is necessary because if two beams have the same prefix, they will
+    # both decode exactly the same sequences, and that's redundant.
     # --> [batch, 2*beams]
     inside_prompt_log_probs = jnp.concatenate([
         jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
@@ -1171,6 +1308,19 @@
     # Did any of these sequences reach an end marker?
     # --> [batch, 2*beams]
     newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
+    # See if they've exceeded their `max_decode_len` allotment; this is not
+    # the same as checking for `state.cur_index` reaching the end, since
+    # `initial_index` will contain variable-length prompts whose lengths we
+    # need to accommodate.
+    # --> [batch, 2*beams]
+    topk_seq_lengths = jnp.sum(topk_seq != 0, axis=-1)
+    # Total lengths along with initial prompts.
+    # initial_index[:, None] is shaped [batch_size, 1]
+    topk_seq_lengths += initial_index[:, None]
+    # Update `newly_finished` with anything that's beyond its allowed length.
+    # NOTE: This might bump out some completed sequences in favor of an
+    # incomplete sequence whose score is higher.
+    newly_finished |= topk_seq_lengths >= max_decode_len
     # To prevent these newly finished sequences from being added to the LIVE
     # set of active beam search sequences, set their log probs to a very large
     # negative value.
@@ -1198,7 +1348,16 @@
     # Update FINISHED (reached end of sentence) sequences:
     # Calculate new seq scores from log probabilities.
-    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)  # pytype: disable=wrong-arg-types  # jax-devicearray
+    lengths = state.cur_index + 1
+    # Add the lengths of the prompts to the beams as well, to calculate the
+    # brevity penalty correctly.
+    # initial_index --> [batch_size,]
+    # topk_lengths --> [batch_size, 2*beams]
+    topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1)
+    # lengths is now: [batch_size, 2*beams]
+    lengths = topk_lengths + lengths
+
+    new_scores = topk_log_probs / brevity_penalty(alpha, lengths)  # pytype: disable=wrong-arg-types  # jax-devicearray
     # Mask out the still unfinished sequences by adding large negative value.
     # --> [batch, 2*beams]
     new_scores += (~newly_finished) * NEG_INF
@@ -1225,7 +1384,9 @@
         live_seqs=top_alive_seq,
         finished_seqs=top_finished_seq,
         finished_flags=top_finished_flags,
-        cache=top_alive_cache)
+        cache=top_alive_cache,
+        initial_index=initial_index,
+    )
 
   # Run while loop and get final beam search state.
   final_state = lax.while_loop(beam_search_loop_cond_fn,
@@ -1243,5 +1404,43 @@
                            None], final_state.finished_scores,
                           final_state.live_logprobs)
 
-  # Drop the first dummy 0 token.
-  return finished_seqs[:, :, 1:], finished_scores
+  # Reconstruct the finished sequences from the prompts that we kept
+  # separately in the right-aligned buffer.
+  if right_aligned_input is not None:
+    # Right now we have right-aligned inputs, and the last prompt tokens +
+    # completions in finished_seqs. We need to concatenate and get rid of the
+    # extra padding, while broadcasting in the beam dimension.
+
+    # Drop the first token, because it is also in `right_aligned_input`.
+    # [batch, beams, length]
+    finished_seqs = finished_seqs[:, :, 1:]
+    # right_aligned_input is [batch, length_prompt]; we need to create a new
+    # beams dimension and broadcast it along that.
+    # --> [batch, beams, length]
+    right_aligned_input = jnp.broadcast_to(
+        right_aligned_input[:, None, :],
+        (batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]),
+    )
+    # Now concatenate along the length dimension.
+    # --> [batch, beams, length]
+    finished_seqs = jnp.concatenate(
+        [right_aligned_input, finished_seqs], axis=-1
+    )
+
+    # Now we left align everything.
+
+    # First flatten to [batch_size * beams, length].
+    flat_finished_seqs = jnp.reshape(
+        finished_seqs, (-1, finished_seqs.shape[-1])
+    )
+    # Left align everything.
+    flat_finished_seqs = _left_align_prompts(flat_finished_seqs)
+    # Shape back to the original shape.
+    left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape)
+    # Cut to the desired length (-1: we extended max_decode_len by 1 above).
+    finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1]
+  else:
+    # Just drop the first dummy 0 token.
+    finished_seqs = finished_seqs[:, :, 1:]
+
+  return finished_seqs, finished_scores
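Why the prompt lengths must flow into the brevity penalty: a worked example
(editor's sketch; t5x's `brevity_penalty` is the GNMT-style
`((5 + length) / 6) ** alpha`, written out inline here):

    alpha = 0.6
    penalty = lambda length: ((5.0 + length) / 6.0) ** alpha
    log_prob = -4.0
    # A 3-token completion after a 10-token prompt must be penalized as a
    # 13-token sequence. With negative log-probs, a larger penalty raises the
    # score, so ignoring the prompt would under-score beams on long prompts:
    print(log_prob / penalty(3))   # ~ -3.37, prompt length ignored
    print(log_prob / penalty(13))  # ~ -2.07, prompt length included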
diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py
index ad563648d..f6a17cb8d 100644
--- a/t5x/decoding_test.py
+++ b/t5x/decoding_test.py
@@ -26,6 +26,7 @@
 import numpy as np
 
 from t5x import decoding
 
+PAD_ID = 0
 EOS_ID = 1
 NEG_INF = decoding.NEG_INF
@@ -957,6 +958,8 @@
       for token, prompt_token in zip(beam, prompt):
         if prompt_token != 0:
           beam_scores.append(0)
+        elif token == PAD_ID:
+          beam_scores.append(0)
         else:
           beam_scores.append(log_probs[token])
       beam_expected_scores.append(sum(beam_scores))
@@ -990,6 +993,160 @@
                          [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]])
     np.testing.assert_array_equal(expected, beam_search_sequences)
 
+  def test_align_prompt(self):
+    prompts = np.array(
+        [
+            [0, 0, 0, 0, 0, 0, 0],
+            [1, 2, 3, 4, 5, 6, 7],
+            [0, 1, 0, 0, 0, 0, 0],
+            [0, 1, 2, 0, 0, 0, 0],
+            [0, 1, 2, 3, 0, 0, 0],
+        ],
+        dtype=np.int32,
+    )
+    right_aligned_prompts = decoding._right_align_prompts(prompts)
+    left_aligned_prompts = decoding._left_align_prompts(right_aligned_prompts)
+    np.testing.assert_array_equal(
+        np.array(
+            [
+                [0, 0, 0, 0, 0, 0, 0],
+                [1, 2, 3, 4, 5, 6, 7],
+                [0, 0, 0, 0, 0, 0, 1],
+                [0, 0, 0, 0, 0, 1, 2],
+                [0, 0, 0, 0, 1, 2, 3],
+            ],
+            dtype=np.int32,
+        ),
+        right_aligned_prompts,
+    )
+    np.testing.assert_array_equal(
+        np.array(
+            [
+                [0, 0, 0, 0, 0, 0, 0],
+                [1, 2, 3, 4, 5, 6, 7],
+                [1, 0, 0, 0, 0, 0, 0],
+                [1, 2, 0, 0, 0, 0, 0],
+                [1, 2, 3, 0, 0, 0, 0],
+            ],
+            dtype=np.int32,
+        ),
+        left_aligned_prompts,
+    )
+
+  def test_beam_search_force_decode_prefix_with_initial_index(self):
+    beam_size = 2
+
+    record_decoding_states = []
+
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      # Record the decoding_state coming in.
+      record_decoding_states.append(decoding_state)
+
+      # Use id 2 then 3 for batch element 0 and id 3, 2 then EOS for element 1.
+      logits = np.repeat(
+          np.expand_dims(
+              np.array(
+                  [
+                      [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
+                      [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4],
+                  ],
+                  dtype=np.float32,
+              ),
+              axis=1,
+          ),
+          [beam_size],
+          axis=1,
+      )
+
+      logits = decoding.flatten_beam_dim(logits)
+      # Return the cache as-is.
+      return logits, decoding_state.cache
+
+    # Batch element 0 has length 1 and element 1 has length 2.
+    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
+    batch_size = inputs.shape[0]
+    rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32)
+    initial_index = np.array([1, 2], dtype=np.int32)
+    REST_OF_THE_SHAPE = 1024  # dummy  pylint: disable=invalid-name
+    dummy_cache = {
+        'cached_bias': np.ones((1, REST_OF_THE_SHAPE), dtype=np.float32),
+        'decoder/layers_0/self_attention/cached_key': np.ones(
+            (batch_size, REST_OF_THE_SHAPE), dtype=np.float32
+        ),
+        'decoder/layers_0/self_attention/cache_index': np.ones(
+            (batch_size,), dtype=np.float32
+        ),
+    }
+
+    # Disable jit so that the recorded decoding states hold concrete values.
+    with jax.disable_jit():
+      beam_search_sequences, decoding_scores = decoding.beam_search(
+          inputs,
+          dummy_cache,
+          token_to_logits,
+          EOS_ID,
+          num_decodes=beam_size,
+          alpha=0,
+          initial_index=initial_index,
+      )
+
+    # Since we're sending in a decode prefix, the first tokens that get
+    # decoded are the last tokens in the prompt, broadcast to the beam size.
+ expected_first_tokens = np.array([[7], [7], [5], [5]], dtype=np.int32) + np.testing.assert_array_equal( + expected_first_tokens, record_decoding_states[0].cur_token + ) + + # Assert on the expected cache shapes that `token_to_logits` should see. + first_cache = record_decoding_states[0].cache + + # This shouldn't expand. + self.assertEqual( + dummy_cache['cached_bias'].shape, first_cache['cached_bias'].shape + ) + # These should expand. + self.assertEqual( + (batch_size * beam_size, REST_OF_THE_SHAPE), + first_cache['decoder/layers_0/self_attention/cached_key'].shape, + ) + self.assertEqual( + (batch_size * beam_size,), + first_cache['decoder/layers_0/self_attention/cache_index'].shape, + ) + + # Prefixes are forced depending on inputs. + # Beam search sequences and corresponding scores are in reverse order. + self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) + expected = np.array( + [[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], [[4, 5, 3, 3, 3], [4, 5, 1, 0, 0]]] + ) + np.testing.assert_array_equal(expected, beam_search_sequences) + + expected_scores = [] + batch_logits = np.array( + [ + [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], + ], + dtype=np.float32, + ) + for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): + beam_expected_scores = [] + for beam in batch: + log_probs = jax.nn.log_softmax(logits) + # Add them directly since they are static. + beam_scores = [] + for token, prompt_token in zip(beam, prompt): + if prompt_token != 0: + beam_scores.append(0) + elif token == PAD_ID: + beam_scores.append(0) + else: + beam_scores.append(log_probs[token]) + beam_expected_scores.append(sum(beam_scores)) + expected_scores.append(beam_expected_scores) + np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) + if __name__ == '__main__': absltest.main()
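To sanity-check the expected scores above by hand (editor's sketch): with
alpha = 0 the brevity penalty is 1, so a beam's score is just the sum of
per-step log-probs of its non-prompt, non-PAD tokens.

    import jax
    import numpy as np

    logits = np.array(
        [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], dtype=np.float32
    )
    log_probs = jax.nn.log_softmax(logits)
    # Beam [4, 5, 3, 3, 3] of batch element 1: tokens 4 and 5 are prompt
    # (contributing 0), and the three generated 3s each add log_probs[3].
    print(3 * log_probs[3])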