From aac9b3f2e47399bb04f5afcdaa7ccf7a718da21c Mon Sep 17 00:00:00 2001
From: Afroz Mohiuddin
Date: Tue, 22 Aug 2023 16:27:13 -0700
Subject: [PATCH] Make use of prefilled cache in beam_search when
 `predict_batch_with_aux.prompt_with_targets` is set to True

Without this change, the targets that are actually prompts will still be
decoded AR-style in order to build the cache, and the logits are thrown away.

With this change, and with
`EncoderDecoder.predict_batch_with_aux.prompt_with_targets` set to `True`, we
instead skip over the given prompts altogether, speeding up decoding.

PiperOrigin-RevId: 559256044
---
 t5x/decoding.py      | 198 +++++++++++++++++++++++++++++++++++++------
 t5x/decoding_test.py | 147 ++++++++++++++++++++++++++++++++
 2 files changed, 320 insertions(+), 25 deletions(-)

diff --git a/t5x/decoding.py b/t5x/decoding.py
index c445c8837..116d06796 100644
--- a/t5x/decoding.py
+++ b/t5x/decoding.py
@@ -970,21 +970,37 @@ class BeamState:
   cache: PyTree  # Any pytree of arrays, e.g. flax attention Cache object
 
 
-def beam_init(batch_size: int,
-              beam_size: int,
-              max_decode_len: int,
-              cache: Mapping[str, jnp.ndarray],
-              offset: int = 0) -> BeamState:
+def beam_init(
+    batch_size: int,
+    beam_size: int,
+    max_decode_len: int,
+    cache: Mapping[str, jnp.ndarray],
+    offset: int = 0,
+    live_seqs: Optional[jnp.ndarray] = None,
+) -> BeamState:
   """Initializes the beam search state data structure."""
   cur_index0 = jnp.array(0)
   live_logprobs0 = jnp.tile(
       jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
   finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
-  live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  # If we prefill any part of the prompt, then the initial live sequences are
+  # provided. In reality these will be the last token of the prompt or BOS if
+  # the prompt (in the batch) is empty.
+  live_seqs0 = (
+      live_seqs
+      if live_seqs is not None
+      else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
+  )
   finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
   finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
   # add beam dimension to attention cache pytree elements
-  beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache)
+  # We will have to expand the cache_index if we're given an initial prompt that
+  # we prefill.
+  beam_cache0 = cache_map(
+      lambda x: add_beam_dim(x, beam_size, offset),
+      cache,
+      apply_to_index=live_seqs is not None,
+  )
   return BeamState(
       cur_index=cur_index0,
       live_logprobs=live_logprobs0,
@@ -995,20 +1011,54 @@ def beam_init(batch_size: int,
       cache=beam_cache0)
 
 
-# Beam search routine:
+def right_align_prompts(prompts):
+  max_len = prompts.shape[1]
+  # [0, 1, 2, ... L - 1]
+  indices = jnp.arange(max_len)
+  # Last prompt's index.
+  prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
+  # Gap to the end of the matrix.
+  gap = max_len - 1 - prompt_max_index
+  return jax.vmap(lambda x, y: jnp.roll(x, y, axis=0))(prompts, gap)
+
+
+def left_align_prompts(prompts):
+  max_len = prompts.shape[1]
+  # [0, 1, 2, ... L - 1]
+  indices = jnp.arange(max_len)
+  # Indices of non-padding positions.
+  non_padding_indices = (prompts != 0) * indices[None, :]
+  # Replace all padding with `max_len + 1`.
+  m = jnp.where(non_padding_indices, non_padding_indices, max_len + 1)
+  # First prompt's index.
+  prompt_min_index = jnp.argmin(m, axis=1)
+  return jax.vmap(lambda x, y: jnp.roll(x, -1 * y, axis=0))(
+      prompts, prompt_min_index
+  )
+
+def pick_last_prompt_token(prompts):
+  # prompts: i32[batch, length]
+  prompt_lengths = jnp.sum(prompts != 0, axis=1)
+  # return value: i32[batch,]
+  return prompts[jnp.arange(prompts.shape[0]), prompt_lengths]
 
 
-def beam_search(inputs: jnp.ndarray,
-                cache: Mapping[str, jnp.ndarray],
-                tokens_to_logits: Callable[[DecodingState],
-                                           Tuple[jnp.ndarray,
-                                                 Mapping[str, jnp.ndarray]]],
-                eos_id: int,
-                num_decodes: int = 4,
-                alpha: float = 0.6,
-                max_decode_len: Optional[int] = None,
-                decode_rng: Optional[jnp.ndarray] = None,
-                cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:
+
+# Beam search routine:
+def beam_search(
+    inputs: jnp.ndarray,
+    cache: Mapping[str, jnp.ndarray],
+    tokens_to_logits: Callable[
+        [DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]
+    ],
+    eos_id: int,
+    num_decodes: int = 4,
+    alpha: float = 0.6,
+    max_decode_len: Optional[int] = None,
+    decode_rng: Optional[jnp.ndarray] = None,
+    cache_offset: int = 0,
+    initial_index: Optional[jnp.ndarray] = None,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
   """Beam search for transformer machine translation.
 
   If `inputs` has non-zero entries, those values are not modified, i.e.,
@@ -1028,6 +1078,11 @@ def beam_search(inputs: jnp.ndarray,
       None, it uses `inputs.shape[1]` as `max_decode_len`.
     decode_rng: Unused decoder RNG seed.
     cache_offset: axis offset for cache, arising from scanned layers.
+    initial_index: Optional array of per-example indices from which to start
+      decoding autoregressively. If unset, the prefix is still teacher-forced,
+      but autoregressively (so it will be slow). When set, this also assumes
+      that the cache has already been prefilled for the prompt. Since inputs
+      are padded on the left with BOS = 0, these are also the prompt lengths.
 
   Returns:
     Tuple of:
@@ -1046,15 +1101,40 @@ def beam_search(inputs: jnp.ndarray,
   # We start with a dummy token in the beginning so extend the maximum length.
   max_decode_len += 1
 
+  right_aligned_input = None
+  live_seqs = None
+  if initial_index is not None:
+    # Now contains the inputs, but "right aligned" so as to end with the last
+    # prompt token.
+    # [batch_size, length]
+    right_aligned_input = right_align_prompts(inputs)
+    # Override `inputs` to be just the last prompt token at position 0, with
+    # the rest of the length dimension as 0s.
+    inputs = jnp.zeros_like(inputs)
+    inputs = inputs.at[:, 0].set(right_aligned_input[:, -1])
+
+    live_seqs = jnp.zeros((batch_size, max_decode_len), dtype=jnp.int32)
+    live_seqs = live_seqs.at[:, 0].set(right_aligned_input[:, -1])
+    live_seqs = jnp.expand_dims(live_seqs, axis=1)
+    live_seqs = jnp.broadcast_to(
+        live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1])
+    )
+
   # initialize beam search state
-  beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
-                                     cache, cache_offset)
+  beam_search_init_state = beam_init(
+      batch_size,
+      beam_size,
+      max_decode_len,
+      cache,
+      cache_offset,
+      live_seqs=live_seqs,
+  )
 
   def beam_search_loop_cond_fn(state: BeamState) -> bool:
     """Beam search loop termination condition."""
     # Have we reached max decoding length?
     # Because we mutate the "i+1" position, we stop one token before the end.
-    not_at_end = (state.cur_index < max_decode_len - 1)
+    not_at_end = state.cur_index < max_decode_len - 1
 
     # Is no further progress in the beam search possible?
     # Get the best possible scores from alive sequences.
@@ -1134,12 +1214,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
     topk_ids = topk_indices % vocab_size
     # Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
     # PADs this is a no-op.
+    #
+    # Also note that when `initial_index` is set, we've already set up the
+    # inputs so that at position 1 onwards (i.e. state.cur_index + 1 >= 1)
+    # the tokens are 0 and we'll immediately be "out of prompt".
+    # --> [batch_size, 1]
     next_input_token = jnp.expand_dims(
         inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
+    # --> [batch_size, 1]
     out_of_prompt = (next_input_token == 0)
 
     # When forcing prompts, update log probabilities to `0` for the top of the
     # beam and -INF for the rest, effectively keeping only one beam alive.
+    # This is necessary because if two beams have the same prefix, then they
+    # will both decode the exact same sequences and that's redundant.
     # --> [batch, 2*beams]
     inside_prompt_log_probs = jnp.concatenate([
         jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
@@ -1171,6 +1259,18 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
     # Did any of these sequences reach an end marker?
     # --> [batch, 2*beams]
     newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
+    # See if they've exceeded their `max_decode_len` allotment. This is not
+    # the same as checking whether `state.cur_index` reached the end, since
+    # in the prompt prefill mode the variable-length prefixes (prompts) are
+    # not counted by `state.cur_index`.
+    if initial_index is not None:
+      # --> [batch, 2*beams]
+      topk_seq_lengths = jnp.sum(topk_seq != 0, axis=-1)
+      # Total lengths along with the initial prompts.
+      # initial_index[:, None] is shaped [batch_size, 1]
+      topk_seq_lengths += initial_index[:, None]
+      # Update `newly_finished` with anything that's beyond its allowed length.
+      newly_finished |= topk_seq_lengths >= max_decode_len
     # To prevent these newly finished sequences from being added to the LIVE
     # set of active beam search sequences, set their log probs to a very large
     # negative value.
@@ -1198,7 +1298,17 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
 
     # Update FINISHED (reached end of sentence) sequences:
     # Calculate new seq scores from log probabilities.
-    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)  # pytype: disable=wrong-arg-types  # jax-devicearray
+    lengths = state.cur_index + 1
+    if initial_index is not None:
+      # We should add the lengths of the prompts to the beams as well to
+      # calculate the brevity penalty correctly.
+      # initial_index --> [batch_size,]
+      # topk_lengths --> [batch_size, 2*beams]
+      topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1)
+      # lengths is now: [batch_size, 2*beams]
+      lengths = topk_lengths + lengths
+
+    new_scores = topk_log_probs / brevity_penalty(alpha, lengths)  # pytype: disable=wrong-arg-types  # jax-devicearray
     # Mask out the still unfinished sequences by adding large negative value.
     # --> [batch, 2*beams]
     new_scores += (~newly_finished) * NEG_INF
@@ -1243,5 +1353,43 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
                                             None], final_state.finished_scores,
                               final_state.live_logprobs)
 
-  # Drop the first dummy 0 token.
-  return finished_seqs[:, :, 1:], finished_scores
+  # Construct the finished sequences back from the prompts that we kept
+  # separately in the right-aligned buffer.
+  if initial_index is not None:
+    # Right now we have right-aligned inputs, and then the last tokens +
+    # completions in finished_seqs. 
We need to concatenate and get rid of the + # extra padding, while broadcasting in the beam dimension. + + # Drop the first token, because it is also in the `right_aligned_input` + # [batch, beams, length] + finished_seqs = finished_seqs[:, :, 1:] + # right_aligned_input is [batch, length_prompt], we need to create a new + # beams dimension and broadcast it along that. + # --> [batch, beams, length] + right_aligned_input = jnp.broadcast_to( + right_aligned_input[:, None, :], + (batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]), + ) + # Now concatenate along the length dimension. + # --> [batch, beams, length] + finished_seqs = jnp.concatenate( + [right_aligned_input, finished_seqs], axis=-1 + ) + + # Now we left align everything. + + # First flatten to [batch_size * beams, length] + flat_finished_seqs = jnp.reshape( + finished_seqs, (-1, finished_seqs.shape[-1]) + ) + # Left align everything. + flat_finished_seqs = left_align_prompts(flat_finished_seqs) + # Shape back to the original shape. + left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape) + # Cut to the desired length (-1 because we added 1 right off the bat) + finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1] + else: + # Just drop the first dummy 0 token. + finished_seqs = finished_seqs[:, :, 1:] + + return finished_seqs, finished_scores diff --git a/t5x/decoding_test.py b/t5x/decoding_test.py index ad563648d..b2795fc21 100644 --- a/t5x/decoding_test.py +++ b/t5x/decoding_test.py @@ -990,6 +990,153 @@ def token_to_logits(decoding_state: decoding.DecodingState): [[2, 3, 3, 3, 3], [3, 3, 3, 3, 3]]]) np.testing.assert_array_equal(expected, beam_search_sequences) + def test_align_prompt(self): + prompts = np.array( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 1, 2, 0, 0, 0, 0], + [0, 1, 2, 3, 0, 0, 0], + ], + dtype=np.int32, + ) + right_aligned_prompts = decoding.right_align_prompts(prompts) + np.testing.assert_array_equal( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 2], + [0, 0, 0, 0, 1, 2, 3], + ], + dtype=np.int32, + ), + right_aligned_prompts, + ) + left_aligned_prompts = decoding.left_align_prompts(right_aligned_prompts) + np.testing.assert_array_equal( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0], + [1, 2, 0, 0, 0, 0, 0], + [1, 2, 3, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + left_aligned_prompts, + ) + + def test_beam_search_force_decode_prefix_with_initial_index(self): + beam_size = 2 + + record_decoding_states = [] + + def token_to_logits(decoding_state: decoding.DecodingState): + # Record the decoding_state coming in. + record_decoding_states.append(decoding_state) + + # Use id 2 then 3 for batch element 0 and id 3 then 2 for element 1. + logits = np.repeat( + np.expand_dims( + np.array( + [ + [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], + ], + dtype=np.float32, + ), + axis=1, + ), + [beam_size], + axis=1, + ) + logits = decoding.flatten_beam_dim(logits) + # Return the cache as-is. + return logits, decoding_state.cache + # return logits, {} + + # batch element 0 has length 1 and element 1 has length 2. 
+ inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32) + batch_size = inputs.shape[0] + rolled_inputs = np.array([[7, 0, 0, 0, 0], [4, 5, 0, 0, 0]], dtype=np.int32) + initial_index = np.array([1, 2], dtype=np.int32) + REST_OF_THE_SHAPE = 1024 # dummy pylint: disable=invalid-name + dummy_cache = { + 'cached_bias': np.ones((1, REST_OF_THE_SHAPE), dtype=np.float32), + 'decoder/layers_0/self_attention/cached_key': np.ones( + (batch_size, REST_OF_THE_SHAPE), dtype=np.float32 + ), + 'decoder/layers_0/self_attention/cache_index': np.ones( + (batch_size,), dtype=np.float32 + ), + } + + beam_search_sequences, decoding_scores = decoding.beam_search( + inputs, + dummy_cache, + token_to_logits, + EOS_ID, + num_decodes=beam_size, + alpha=0, + initial_index=initial_index, + ) + + # Since we're sending in a decode prefix, the first tokens that should get + # decoded are the last tokens in the prompt - broadcasted to the beam size. + expected_first_tokens = np.array([[7], [7], [5], [5]], dtype=np.int32) + np.testing.assert_array_equal( + expected_first_tokens, record_decoding_states[0].cur_token + ) + + # Assert on the expected cache shapes that `token_to_logits` should see. + first_cache = record_decoding_states[0].cache + + # This shouldn't expand. + self.assertEqual( + dummy_cache['cached_bias'].shape, first_cache['cached_bias'].shape + ) + # These should expand. + self.assertEqual( + (batch_size * beam_size, REST_OF_THE_SHAPE), + first_cache['decoder/layers_0/self_attention/cached_key'].shape, + ) + self.assertEqual( + (batch_size * beam_size,), + first_cache['decoder/layers_0/self_attention/cache_index'].shape, + ) + + # Prefixes are forced depending on inputs. + # Beam search sequences and corresponding scores are in reverse order. + self.assertTrue(np.all(np.diff(decoding_scores) >= 0)) + expected = np.array( + [[[7, 3, 2, 2, 2], [7, 2, 2, 2, 2]], [[4, 5, 2, 3, 3], [4, 5, 3, 3, 3]]] + ) + np.testing.assert_array_equal(expected, beam_search_sequences) + + expected_scores = [] + batch_logits = np.array( + [ + [-1e7, -1e10, -0.1, -0.9, -1e4, -1e4, -1e4, -1e4], + [-1e7, -1e10, -0.9, -0.1, -1e4, -1e4, -1e4, -1e4], + ], + dtype=np.float32, + ) + for batch, logits, prompt in zip(expected, batch_logits, rolled_inputs): + beam_expected_scores = [] + for beam in batch: + log_probs = jax.nn.log_softmax(logits) + # Add them directly since they are static. + beam_scores = [] + for token, prompt_token in zip(beam, prompt): + if prompt_token != 0: + beam_scores.append(0) + else: + beam_scores.append(log_probs[token]) + beam_expected_scores.append(sum(beam_scores)) + expected_scores.append(beam_expected_scores) + np.testing.assert_allclose(expected_scores, decoding_scores, atol=1e-5) + if __name__ == '__main__': absltest.main()
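
Illustrative sketch (not part of the patch): how the new prompt-alignment
helpers behave on the same toy inputs used by the new test, assuming the
patched `t5x.decoding` module is importable.

    import numpy as np
    from t5x import decoding

    # Inputs follow the beam_search convention: a leading dummy BOS = 0,
    # right-padded with 0s. Batch element 0 has prompt length 1, element 1
    # has prompt length 2.
    inputs = np.array([[0, 7, 0, 0, 0], [0, 4, 5, 0, 0]], dtype=np.int32)

    # right_align_prompts shifts each row so it ends with its last prompt
    # token:
    #   [[0, 0, 0, 0, 7],
    #    [0, 0, 0, 4, 5]]
    right_aligned = decoding.right_align_prompts(inputs)

    # left_align_prompts undoes the shift, moving the first non-padding token
    # to position 0:
    #   [[7, 0, 0, 0, 0],
    #    [4, 5, 0, 0, 0]]
    left_aligned = decoding.left_align_prompts(right_aligned)

    # The `initial_index` passed to beam_search is the per-example prompt
    # length (here [1, 2]); with it set, decoding resumes from these positions
    # and the prefilled cache is reused instead of teacher-forcing the prompt.
    initial_index = np.sum(inputs != 0, axis=1).astype(np.int32)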