From de7577e8651ba4e4e759deb126d5f41911204036 Mon Sep 17 00:00:00 2001
From: Afroz Mohiuddin
Date: Tue, 22 Aug 2023 16:27:13 -0700
Subject: [PATCH] Make use of prefilled cache in beam_search when
 `predict_batch_with_aux.prompt_with_targets` is set to True

Without this change, the targets that are actually prompts will still be
decoded AR-style in order to build the cache, and the logits are thrown away.
With this change, and with
`EncoderDecoder.predict_batch_with_aux.prompt_with_targets` set to `True`, we
instead skip over the given prompts altogether, speeding up the decode.

PiperOrigin-RevId: 559256044
---
 t5x/decoding.py      | 282 +++++++++++++++++++++++++++++++++++++------
 t5x/decoding_test.py | 160 ++++++++++++++++++++++++
 2 files changed, 403 insertions(+), 39 deletions(-)

diff --git a/t5x/decoding.py b/t5x/decoding.py
index 63e2a075d..cc0e69d74 100644
--- a/t5x/decoding.py
+++ b/t5x/decoding.py
@@ -416,7 +416,7 @@ def _temperature_sample_single_trial(
     rescale_log_probs: bool = True,
     state_callback_fn: Optional[StateCallbackFn] = None,
     logit_callback_fn: Optional[LogitCallbackFn] = None,
-) -> jnp.ndarray:
+) -> Tuple[jax.Array, jax.Array]:
   """A helper function for `temperature_sample`."""

   # We can check the values of topp and topk only if they are not dynamic.
@@ -494,13 +494,13 @@ def _temperature_sample_single_trial(
   # [batch, 1]
   initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True)

-  def sampling_loop_cond_fn(state: SamplingLoopState) -> bool:
+  def sampling_loop_cond_fn(state: SamplingLoopState) -> jax.Array:
     """Sampling loop termination condition."""
     # Have all sampled sequences reached an end marker?
     # Different elements in the batch can be at different loop indices, if any
     # of our examples are not at the end, keep going.
     all_sequences_ended = jnp.all(state.ended)
-    return ~all_sequences_ended  # pytype: disable=bad-return-type # jnp-type
+    return ~all_sequences_ended

   def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState:
     """Sampling loop state update."""
@@ -681,7 +681,7 @@ def map_logits_with_different_temperatures(logits_batch_item,
   log_prob = final_state.log_prob
   # Drop the first position because they are dummy bos tokens. Drop the new
   # garbage collection token at the end too.
-  return final_sequences[:, 1:-1], log_prob  # pytype: disable=bad-return-type # jax-ndarray
+  return final_sequences[:, 1:-1], log_prob


 # ------------------------------------------------------------------------------
@@ -968,23 +968,43 @@ class BeamState:
   finished_flags: jax.Array  # bool: [batch_size, beam_size]
   # The current state of the autoregressive decoding caches.
   cache: PyTree  # Any pytree of arrays, e.g. flax attention Cache object
+  # Optional array of initial indices from which decoding starts; all zeros
+  # when there is no prompt, or None.
+  initial_index: jax.Array | None


-def beam_init(batch_size: int,
-              beam_size: int,
-              max_decode_len: int,
-              cache: Mapping[str, jnp.ndarray],
-              offset: int = 0) -> BeamState:
+def beam_init(
+    batch_size: int,
+    beam_size: int,
+    max_decode_len: int,
+    cache: Mapping[str, jnp.ndarray],
+    offset: int = 0,
+    live_seqs: Optional[jnp.ndarray] = None,
+    initial_index: Optional[jnp.ndarray] = None,
+) -> BeamState:
   """Initializes the beam search state data structure."""
   cur_index0 = jnp.array(0)
   live_logprobs0 = jnp.tile(
       jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
   finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
-  live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  # If we prefill any part of the prompt, then the initial live sequences are
+  # provided. In practice these will be the last token of the prompt, or BOS
+  # if the prompt (in the batch) is empty.
+  live_seqs0 = (
+      live_seqs
+      if live_seqs is not None
+      else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  )
   finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
   finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
   # add beam dimension to attention cache pytree elements
-  beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache)
+  # We will also have to expand the cache_index if we're given an initial
+  # prompt that we prefill.
+  beam_cache0 = cache_map(
+      lambda x: add_beam_dim(x, beam_size, offset),
+      cache,
+      apply_to_index=live_seqs is not None,
+  )
   return BeamState(
       cur_index=cur_index0,
       live_logprobs=live_logprobs0,
@@ -992,29 +1012,106 @@ def beam_init(batch_size: int,
       live_seqs=live_seqs0,
       finished_seqs=finished_seqs0,
       finished_flags=finished_flags0,
-      cache=beam_cache0)
+      cache=beam_cache0,
+      initial_index=initial_index,
+  )


-# Beam search routine:
+def _right_align_prompts(prompts):
+  """Right align the prompts."""
+
+  # Implementation note:
+  #
+  # A very short way to do this right-aligning would be to vmap a jnp.roll by
+  # the amount of padding in each example, i.e. max_len - prompt_max_index
+  # (+-1) - however this is slow.
+  #
+  # A faster way, courtesy Jeremiah Willcock, is to shift rows by bitmasking
+  # the gap and iterating for 1, 2, 4, ... log2(len) bitmasks.
+  #
+  # This gives a ~3x speedup over vmapping a roll.
+
+  max_len = prompts.shape[1]
+  nbits = np.ceil(np.log2(max_len)).astype(np.int32)
+  indices = jnp.arange(max_len)
+  prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
+  shifts = max_len - prompt_max_index - 1
+  for i in range(0, nbits + 1):
+    bitmask = 2**i
+    prompts = jnp.where(
+        jnp.expand_dims(shifts & bitmask, 1),
+        jnp.pad(prompts, ((0, 0), (bitmask, 0)))[:, :-bitmask],
+        prompts,
+    )
+  return prompts
+
+
+def _left_align_prompts(prompts):
+  """Left align the prompts."""
+  # See implementation notes in `_right_align_prompts`.
+
+  max_len = prompts.shape[1]
+  # [0, 1, 2, ... L - 1]
+  indices = jnp.arange(max_len)
+  # Indices of non-padding positions, 1-based since `indices` is 0-based.
+  non_padding_positions = (prompts != 0) * (indices[None, :] + 1)
+  # Replace all padding with `max_len + 1`.
+  m = jnp.where(non_padding_positions, non_padding_positions, max_len + 1)
+  # Index of the first prompt token.
+  shifts = jnp.argmin(m, axis=1)
+  temp = prompts
+  nbits = np.ceil(np.log2(max_len)).astype(np.int32)
+  for i in range(0, nbits + 1):
+    bitmask = 2**i
+    temp = jnp.where(
+        jnp.expand_dims(shifts & bitmask, 1),
+        jnp.pad(temp, ((0, 0), (0, bitmask)))[:, bitmask:],
+        temp,
+    )
+  return temp


-def beam_search(inputs: jnp.ndarray,
-                cache: Mapping[str, jnp.ndarray],
-                tokens_to_logits: Callable[[DecodingState],
-                                           Tuple[jnp.ndarray,
-                                                 Mapping[str, jnp.ndarray]]],
-                eos_id: int,
-                num_decodes: int = 4,
-                alpha: float = 0.6,
-                max_decode_len: Optional[int] = None,
-                decode_rng: Optional[jnp.ndarray] = None,
-                cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
+def _pick_last_prompt_token(prompts):
+  # prompts: i32[batch, length]
+  prompt_lengths = jnp.sum(prompts != 0, axis=1)
+  # Because prompts are left-padded with BOS = 0, `prompt_lengths` is also the
+  # position of the last prompt token.
+  # return value: i32[batch,]
+  return prompts[jnp.arange(prompts.shape[0]), prompt_lengths]
+
+
+# Beam search routine:
+def beam_search(
+    inputs: jnp.ndarray,
+    cache: Mapping[str, jnp.ndarray],
+    tokens_to_logits: Callable[
+        [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]
+    ],
+    eos_id: int,
+    num_decodes: int = 4,
+    alpha: float = 0.6,
+    max_decode_len: Optional[int] = None,
+    decode_rng: Optional[jnp.ndarray] = None,
+    cache_offset: int = 0,
+    initial_index: Optional[jnp.ndarray] = None,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Beam search for transformer machine translation.

   If `inputs` has non-zero entries, those values are not modified, i.e.,
   the sampled values for those positions are discarded. This simulates the
   teacher forcing on the prefix positions.

+  NOTE: When using initial_index with prompts of variable lengths: to comply
+  with the max_decode_len requirement, we might now return sequences that
+  were still live (i.e. EOS not yet decoded) when they exceeded their length
+  allowance, along with sequences that finished (i.e. EOS was decoded).
+  Furthermore, there might be sequences that finished decoding only after
+  their max_decode_len was exceeded; these will appear truncated at
+  max_decode_len in the output.
+
+  TODO(afrozm): Solve this, if needed, by having a third class of sequences
+  apart from live and finished, called "truncated"; then, after beam search
+  completes, order them as finished > truncated > live, rather than
+  finished > live as happens right now.
+
   Args:
     inputs: array: [batch_size, length] int32 sequence of tokens.
     cache: flax attention cache.
@@ -1028,6 +1125,11 @@ def beam_search(inputs: jnp.ndarray,
       None, it uses `inputs.shape[1]` as `max_decode_len`.
     decode_rng: Unused decoder RNG seed.
     cache_offset: axis offset for cache, arising from scanned layers.
+    initial_index: Optional[jnp.ndarray], the index from which to start
+      decoding autoregressively if set. If unset, the prefix is teacher-forced
+      but still decoded autoregressively (so it will be slow). When set, this
+      also assumes that the cache is appropriately populated. Since inputs are
+      padded on the left with BOS = 0, these are also the prompt lengths.

   Returns:
     Tuple of:
@@ -1046,15 +1148,55 @@ def beam_search(inputs: jnp.ndarray,

   # We start with a dummy token in the beginning so extend the maximum length.
   max_decode_len += 1

+  right_aligned_input = None
+  live_seqs = None
+  if initial_index is not None:
+    # Now contains the inputs, but "right aligned" so as to end with the last
+    # prompt token.
+    # [batch_size, length]
+    right_aligned_input = _right_align_prompts(inputs)
+    # `inputs` is now just the last token of the prompt, right-padded to the
+    # same length as before.
+    length = inputs.shape[1]
+    inputs = jnp.pad(
+        right_aligned_input[:, -1][:, None],
+        ((0, 0), (0, length - 1)),
+        constant_values=0,
+    )
+
+    # Sized [batch, max_decode_len]
+    live_seqs = jnp.pad(
+        right_aligned_input[:, -1][:, None],
+        ((0, 0), (0, max_decode_len - 1)),
+        constant_values=0,
+    )
+    live_seqs = jnp.expand_dims(live_seqs, axis=1)
+    live_seqs = jnp.broadcast_to(
+        live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1])
+    )
+  else:
+    initial_index = jnp.zeros((batch_size,), dtype=jnp.int32)
+
   # initialize beam search state
-  beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
-                                     cache, cache_offset)
+  beam_search_init_state = beam_init(
+      batch_size,
+      beam_size,
+      max_decode_len,
+      cache,
+      cache_offset,
+      live_seqs=live_seqs,
+      initial_index=initial_index,
+  )

-  def beam_search_loop_cond_fn(state: BeamState) -> bool:
+  def beam_search_loop_cond_fn(state: BeamState) -> jax.Array:
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
+
+    # Since we might be starting at different points in the prompts, let's use
+    # the minimum prompt length to stop conservatively.
+    cur_index = state.cur_index + jnp.min(state.initial_index)
     # Because we mutate the "i+1" position, we stop one token before the end.
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = cur_index < max_decode_len - 1

     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
@@ -1072,7 +1214,7 @@ def beam_search_loop_cond_fn(state: BeamState) -> bool:

     # If we're not at the max decode length, and the search hasn't terminated,
     # continue looping.
-    return not_at_end & (~search_terminated)  # pytype: disable=bad-return-type # jax-devicearray
+    return not_at_end & (~search_terminated)

   def beam_search_loop_body_fn(state: BeamState) -> BeamState:
     """Beam search loop state update function."""
@@ -1134,12 +1276,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
     topk_ids = topk_indices % vocab_size
     # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
     # PADs this is a no-op.
+    #
+    # Also note that when `initial_index` is set, we've already set up the
+    # inputs so that from position 1 onwards (i.e. state.cur_index + 1 >= 1)
+    # the tokens are 0, and we'll immediately be "out of prompt".
+    #
     # --> [batch_size, 1]
     next_input_token = jnp.expand_dims(
         inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
+    # --> [batch_size, 1]
     out_of_prompt = (next_input_token == 0)

     # When forcing prompts, update log probabilities to `0` for the top of the
     # beam and -INF for the rest, effectively keeping only one beam alive.
+    # This is necessary because if two beams have the same prefix, they will
+    # both decode exactly the same sequences, which is redundant.
     # --> [batch, 2*beams]
     inside_prompt_log_probs = jnp.concatenate([
         jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
@@ -1198,7 +1348,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:

     # Update FINISHED (reached end of sentence) sequences:
     # Calculate new seq scores from log probabilities.
-    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)  # pytype: disable=wrong-arg-types # jax-devicearray
+    lengths = state.cur_index + 1
+    # We should add the lengths of the prompts to the beams as well to
+    # calculate the brevity penalty correctly.
+ # initial_index --> [batch_size,] + # topk_lengths --> [batch_size, 2*beams] + topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1) + # lengths is now: [batch_size, 2*beams] + lengths = topk_lengths + lengths + + new_scores = topk_log_probs / brevity_penalty(alpha, lengths) # pytype: disable=wrong-arg-types # jax-devicearray # Mask out the still unfinished sequences by adding large negative value. # --> [batch, 2*beams] new_scores += (~newly_finished) * NEG_INF @@ -1225,7 +1384,9 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: live_seqs=top_alive_seq, finished_seqs=top_finished_seq, finished_flags=top_finished_flags, - cache=top_alive_cache) + cache=top_alive_cache, + initial_index=initial_index, + ) # Run while loop and get final beam search state. final_state = lax.while_loop(beam_search_loop_cond_fn, @@ -1234,14 +1395,57 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] - none_finished = jnp.any(final_state.finished_flags, axis=1) + any_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] - finished_seqs = jnp.where(none_finished[:, None, None], - final_state.finished_seqs, final_state.live_seqs) + finished_seqs = jnp.where( + any_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = jnp.where(none_finished[:, - None], final_state.finished_scores, - final_state.live_logprobs) + finished_scores = jnp.where( + any_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) + + # Construct the finished sequences back from the prompts that we kept + # separately in the right aligned buffer. + if right_aligned_input is not None: + # Right now we have right aligned inputs, and then the last tokens + + # completions in finished_seqs. We need to concatenate and get rid of the + # extra padding, while broadcasting in the beam dimension. + + # Drop the first token, because it is also in the `right_aligned_input` + # [batch, beams, length] + finished_seqs = finished_seqs[:, :, 1:] + # right_aligned_input is [batch, length_prompt], we need to create a new + # beams dimension and broadcast it along that. + # --> [batch, beams, length] + right_aligned_input = jnp.broadcast_to( + right_aligned_input[:, None, :], + (batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]), + ) + # Now concatenate along the length dimension. + # --> [batch, beams, length] + finished_seqs = jnp.concatenate( + [right_aligned_input, finished_seqs], axis=-1 + ) + + # Now we left align everything. + + # First flatten to [batch_size * beams, length] + flat_finished_seqs = jnp.reshape( + finished_seqs, (-1, finished_seqs.shape[-1]) + ) + # Left align everything. + flat_finished_seqs = _left_align_prompts(flat_finished_seqs) + # Shape back to the original shape. + left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape) + # Cut to the desired length (-1 because we added 1 right off the bat) + finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1] + else: + # Just drop the first dummy 0 token. + finished_seqs = finished_seqs[:, :, 1:] - # Drop the first dummy 0 token. 
-  return finished_seqs[:, :, 1:], finished_scores
+  return finished_seqs, finished_scores

diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py
index ad563648d..5488b2377 100644
--- a/t5x/decoding_test.py
+++ b/t5x/decoding_test.py
@@ -26,6 +26,7 @@ import numpy as np

 from t5x import decoding

+PAD_ID = 0
 EOS_ID = 1
 NEG_INF = decoding.NEG_INF

@@ -957,6 +958,8 @@ def token_to_logits(decoding_state: decoding.DecodingState):
         for token, prompt_token in zip(beam, prompt):
           if prompt_token != 0:
             beam_scores.append(0)
+          elif token == PAD_ID:
+            beam_scores.append(0)
           else:
             beam_scores.append(log_probs[token])
         beam_expected_scores.append(sum(beam_scores))
@@ -990,6 +993,163 @@ def token_to_logits(decoding_state: decoding.DecodingState):
                           [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]])
     np.testing.assert_array_equal(expected, beam_search_sequences)

+  def test_align_prompt(self):
+    prompts = np.array(
+        [
+            [0, 0, 0, 0, 0, 0, 0],
+            [1, 2, 3, 4, 5, 6, 7],
+            [0, 1, 0, 0, 0, 0, 0],
+            [0, 1, 2, 0, 0, 0, 0],
+            [0, 1, 2, 3, 0, 0, 0],
+        ],
+        dtype=np.int32,
+    )
+    right_aligned_prompts = decoding._right_align_prompts(prompts)
+    left_aligned_prompts = decoding._left_align_prompts(right_aligned_prompts)
+    np.testing.assert_array_equal(
+        np.array(
+            [
+                [0, 0, 0, 0, 0, 0, 0],
+                [1, 2, 3, 4, 5, 6, 7],
+                [0, 0, 0, 0, 0, 0, 1],
+                [0, 0, 0, 0, 0, 1, 2],
+                [0, 0, 0, 0, 1, 2, 3],
+            ],
+            dtype=np.int32,
+        ),
+        right_aligned_prompts,
+    )
+    np.testing.assert_array_equal(
+        np.array(
+            [
+                [0, 0, 0, 0, 0, 0, 0],
+                [1, 2, 3, 4, 5, 6, 7],
+                [1, 0, 0, 0, 0, 0, 0],
+                [1, 2, 0, 0, 0, 0, 0],
+                [1, 2, 3, 0, 0, 0, 0],
+            ],
+            dtype=np.int32,
+        ),
+        left_aligned_prompts,
+    )
+
+  def test_beam_search_force_decode_prefix_with_initial_index(self):
+    beam_size = 2
+
+    record_decoding_states = []
+
+    def token_to_logits(decoding_state: decoding.DecodingState):
+      # Record the decoding_state coming in.
+      record_decoding_states.append(decoding_state)
+
+      # Use id 2 then 3 for batch element 0 and id 3, 2 then EOS for element 1.
+      logits = np.repeat(
+          np.expand_dims(
+              np.array(
+                  [
+                      [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4],
+                      [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4],
+                  ],
+                  dtype=np.float32,
+              ),
+              axis=1,
+          ),
+          [beam_size],
+          axis=1,
+      )
+
+      logits = decoding.flatten_beam_dim(logits)
+      # Return the cache as-is.
+      return logits, decoding_state.cache
+
+    # Batch element 0 has length 1 and element 1 has length 2.
+    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)
+    batch_size = inputs.shape[0]
+    rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32)
+    initial_index = np.array([1, 2], dtype=np.int32)
+    REST_OF_THE_SHAPE = 1024  # dummy  pylint: disable=invalid-name
+    dummy_cache = {
+        'cached_bias': np.ones((1, REST_OF_THE_SHAPE), dtype=np.float32),
+        'decoder/layers_0/self_attention/cached_key': np.ones(
+            (batch_size, REST_OF_THE_SHAPE), dtype=np.float32
+        ),
+        'decoder/layers_0/self_attention/cache_index': np.ones(
+            (batch_size,), dtype=np.float32
+        ),
+    }
+
+    # Disable jit since we are capturing the intermediate decoding states and
+    # caches in Python lists.
+    with jax.disable_jit():
+      beam_search_sequences, decoding_scores = decoding.beam_search(
+          inputs,
+          dummy_cache,
+          token_to_logits,
+          EOS_ID,
+          num_decodes=beam_size,
+          alpha=0,
+          initial_index=initial_index,
+      )
+
+    # Since we're sending in a decode prefix, the first tokens that should get
+    # decoded are the last tokens in the prompt, broadcast to the beam size.
+ expected_first_tokens = np.array([[7], [7], [5], [5]], dtype=np.int32) + np.testing.assert_array_equal( + expected_first_tokens, record_decoding_states[0].cur_token + ) + + # Assert on the expected cache shapes that `token_to_logits` should see. + first_cache = record_decoding_states[0].cache + + # This shouldn't expand. + self.assertEqual( + dummy_cache['cached_bias'].shape, first_cache['cached_bias'].shape + ) + # These should expand. + self.assertEqual( + (batch_size * beam_size, REST_OF_THE_SHAPE), + first_cache['decoder/layers_0/self_attention/cached_key'].shape, + ) + self.assertEqual( + (batch_size * beam_size,), + first_cache['decoder/layers_0/self_attention/cache_index'].shape, + ) + + # Prefixes are forced depending on inputs. + # Beam search sequences and corresponding scores are in reverse order. + self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) + expected = np.array( + [[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], [[4, 5, 3, 1, 0], [4, 5, 1, 0, 0]]] + ) + np.testing.assert_array_equal(expected, beam_search_sequences) + + expected_scores = [] + batch_logits = np.array( + [ + [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1.0, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], + ], + dtype=np.float32, + ) + for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): + beam_expected_scores = [] + for beam in batch: + log_probs = jax.nn.log_softmax(logits) + # Add them directly since they are static. + beam_scores = [] + for token, prompt_token in zip(beam, prompt): + if prompt_token != 0: + beam_scores.append(0) + elif token == PAD_ID: + beam_scores.append(0) + else: + beam_scores.append(log_probs[token]) + beam_expected_scores.append(sum(beam_scores)) + expected_scores.append(beam_expected_scores) + np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) + if __name__ == '__main__': absltest.main()
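
A note on the alignment trick used in `_right_align_prompts` / `_left_align_prompts` above: the loop is a per-row logarithmic shifter that decomposes each row's shift amount into powers of two. The following standalone sketch (not part of the patch; `right_align` is an illustrative name) mirrors the technique and runs on its own:

    import jax.numpy as jnp
    import numpy as np

    def right_align(prompts):
      # Right-align the non-zero tokens of each row, as in the patch's
      # `_right_align_prompts`.
      max_len = prompts.shape[1]
      nbits = int(np.ceil(np.log2(max_len)))
      indices = jnp.arange(max_len)
      # Index of the last non-padding token in each row.
      last = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
      # How far right each row has to move.
      shifts = max_len - last - 1
      for i in range(nbits + 1):
        bit = 2**i
        # Rows whose shift amount has this bit set move right by `bit`.
        prompts = jnp.where(
            jnp.expand_dims(shifts & bit, 1) > 0,
            jnp.pad(prompts, ((0, 0), (bit, 0)))[:, :-bit],
            prompts,
        )
      return prompts

    print(right_align(jnp.array([[0, 1, 2, 0, 0],
                                 [3, 0, 0, 0, 0]])))
    # [[0 0 0 1 2]
    #  [0 0 0 0 3]]

Any shift in [0, max_len) decomposes into at most ceil(log2(max_len)) + 1 power-of-two moves, so the loop does O(log L) whole-batch shifts instead of vmapping a per-row jnp.roll, which is where the claimed ~3x speedup comes from.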
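For callers, the intended flow with the new `initial_index` argument looks roughly like the following sketch. This is illustrative, not code from the patch: `prefill_cache` stands in for whatever model-specific step populates the attention cache with the prompt (what `prompt_with_targets=True` arranges), and `tokens_to_logits` for the model's single-step decode function.

    import jax.numpy as jnp
    from t5x import decoding

    # Prompts are left-padded with BOS = 0; these rows have lengths 1 and 2.
    inputs = jnp.array([[0, 7, 0, 0, 0],
                        [0, 4, 5, 0, 0]], dtype=jnp.int32)

    # Because of the leading BOS, the number of non-zero tokens per row is
    # also the index from which autoregressive decoding should resume.
    initial_index = jnp.sum(inputs != 0, axis=1).astype(jnp.int32)

    # Hypothetical model-specific prefill step that runs the prompt through
    # the decoder once and returns a populated key/value cache.
    # cache = prefill_cache(params, inputs)

    seqs, scores = decoding.beam_search(
        inputs,
        cache,
        tokens_to_logits,  # model-specific: DecodingState -> (logits, cache)
        eos_id=1,
        num_decodes=4,
        initial_index=initial_index,  # skip the prompt instead of re-decoding
    )

Without `initial_index`, the same call still works but pushes the prompt token by token through the decode loop and discards the logits, which is exactly the slow path this change avoids.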
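The brevity-penalty adjustment can also be sanity-checked in isolation. Assuming the GNMT-style penalty that t5x.decoding uses, ((5 + length) / 6) ** alpha, adding the prompt lengths to `lengths` makes prefilled prompt tokens count toward a beam's length, so a prefilled decode normalizes a finished sequence the same way a teacher-forced decode would:

    import jax.numpy as jnp

    def brevity_penalty(alpha, length):
      # GNMT-style length normalization, as in t5x.decoding.
      return jnp.power((5.0 + length) / 6.0, alpha)

    alpha = 0.6
    log_prob = -2.3   # log-probability of some finished continuation
    prompt_len = 4    # tokens that were prefilled rather than re-decoded
    decoded_len = 3   # tokens produced inside the beam search loop

    # Length without the prompt correction vs. with it.
    print(log_prob / brevity_penalty(alpha, decoded_len))               # ~ -1.94
    print(log_prob / brevity_penalty(alpha, prompt_len + decoded_len))  # ~ -1.52

With alpha = 0, as in the new test, the penalty is 1 either way, which is why the test's expected scores reduce to plain sums of token log-probabilities.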