Make use of prefilled cache in beam_search when `predict_batch_with_aux.prompt_with_targets` is set to True

Without this change, targets that are actually prompts are still decoded AR-style just to build the cache, and the logits are thrown away.

With this change, and with `EncoderDecoder.predict_batch_with_aux.prompt_with_targets` set to `True`, we instead skip over the given prompts altogether, speeding up the decode.
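A toy sketch of the resulting call (not part of this commit): the stub `tokens_to_logits` and the empty cache are placeholders for illustration; a real call passes the model-bound callable and an attention cache that has already been prefilled over the prompt.

import jax.numpy as jnp
from t5x import decoding

VOCAB = 8

def tokens_to_logits(state):
  # Uniform logits; a real model would attend over the prefilled cache.
  return jnp.zeros((state.cur_token.shape[0], VOCAB)), state.cache

# Decoder inputs: BOS (= 0) followed by the prompt, zero-padded on the right.
inputs = jnp.array([[0, 5, 6, 0, 0, 0]], dtype=jnp.int32)
prompt_lengths = jnp.sum(inputs != 0, axis=1)  # per the docstring, also the start index

seqs, scores = decoding.beam_search(
    inputs=inputs,
    cache={},  # placeholder; normally the model's prefilled attention cache
    tokens_to_logits=tokens_to_logits,
    eos_id=1,
    num_decodes=4,
    initial_index=prompt_lengths,  # new: skip AR teacher-forcing of the prompt
)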

PiperOrigin-RevId: 559256044
afrozenator authored and t5-copybara committed Sep 3, 2023
1 parent 0728d84 commit aac9b3f
Showing 2 changed files with 320 additions and 25 deletions.
198 changes: 173 additions & 25 deletions t5x/decoding.py
@@ -970,21 +970,37 @@ class BeamState:
cache: PyTree # Any pytree of arrays, e.g. flax attention Cache object


def beam_init(batch_size: int,
beam_size: int,
max_decode_len: int,
cache: Mapping[str, jnp.ndarray],
offset: int = 0) -> BeamState:
def beam_init(
batch_size: int,
beam_size: int,
max_decode_len: int,
cache: Mapping[str, jnp.ndarray],
offset: int = 0,
live_seqs: Optional[jnp.ndarray] = None,
) -> BeamState:
"""Initializes the beam search state data structure."""
cur_index0 = jnp.array(0)
live_logprobs0 = jnp.tile(
jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
# If we prefill any part of the prompt, then the initial live sequences are
# provided. In practice these hold the last token of each prompt, or BOS if
# that prompt (in the batch) is empty.
live_seqs0 = (
live_seqs
if live_seqs is not None
else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
)
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache)
# We will have to expand the cache_index if we're given an initial prompt that
# we prefill.
beam_cache0 = cache_map(
lambda x: add_beam_dim(x, beam_size, offset),
cache,
apply_to_index=live_seqs is not None,
)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
@@ -995,20 +1011,54 @@ def beam_init(batch_size: int,
cache=beam_cache0)


# Beam search routine:
def right_align_prompts(prompts):
max_len = prompts.shape[1]
# [0, 1, 2, ... L - 1]
indices = jnp.arange(max_len)
# Index of the last prompt token.
prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
# Gap to the end of the array.
gap = max_len - 1 - prompt_max_index
return jax.vmap(lambda x, y: jnp.roll(x, y, axis=0))(prompts, gap)


def left_align_prompts(prompts):
max_len = prompts.shape[1]
# [0, 1, 2, ... L - 1]
indices = jnp.arange(max_len)
# Indices of non-padding positions.
non_padding_indices = (prompts != 0) * indices[None, :]
# Replace all padding positions with `max_len + 1`.
m = jnp.where(non_padding_indices, non_padding_indices, max_len + 1)
# Index of the first prompt token.
prompt_min_index = jnp.argmin(m, axis=1)
return jax.vmap(lambda x, y: jnp.roll(x, -1 * y, axis=0))(
prompts, prompt_min_index
)


def pick_last_prompt_token(prompts):
# prompts: i32[batch, length]
prompt_lengths = jnp.sum(prompts != 0, axis=1)
# return value: i32[batch,]
return prompts[jnp.arange(prompts.shape[0]), prompt_lengths]
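A tiny illustration of the three helpers above, assuming (as elsewhere in this file) that prompts start with BOS = 0 and that 0 is also the pad id; run it with the helpers in scope:

import jax.numpy as jnp

prompts = jnp.array([[0, 5, 6, 0],
                     [0, 7, 0, 0]], dtype=jnp.int32)

right_align_prompts(prompts)
# [[0, 0, 5, 6],
#  [0, 0, 0, 7]]
left_align_prompts(jnp.array([[0, 0, 5, 6]], dtype=jnp.int32))
# [[5, 6, 0, 0]]
pick_last_prompt_token(prompts)
# [6, 7]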

def beam_search(inputs: jnp.ndarray,
cache: Mapping[str, jnp.ndarray],
tokens_to_logits: Callable[[DecodingState],
Tuple[jnp.ndarray,
Mapping[str, jnp.ndarray]]],
eos_id: int,
num_decodes: int = 4,
alpha: float = 0.6,
max_decode_len: Optional[int] = None,
decode_rng: Optional[jnp.ndarray] = None,
cache_offset: int = 0) -> Tuple[jnp.ndarray, jnp.ndarray]:

# Beam search routine:
def beam_search(
inputs: jnp.ndarray,
cache: Mapping[str, jnp.ndarray],
tokens_to_logits: Callable[
[DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]
],
eos_id: int,
num_decodes: int = 4,
alpha: float = 0.6,
max_decode_len: Optional[int] = None,
decode_rng: Optional[jnp.ndarray] = None,
cache_offset: int = 0,
initial_index: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Beam search for transformer machine translation.
If `inputs` has non-zero entries, those values are not modified, i.e.,
@@ -1028,6 +1078,11 @@ def beam_search(inputs: jnp.ndarray,
None, it uses `inputs.shape[1]` as `max_decode_len`.
decode_rng: Unused decoder RNG seed.
cache_offset: axis offset for cache, arising from scanned layers.
initial_index: Optional[jnp.ndarray], the index from which to start decoding
autoregressively if set. If unset, the prefix is still teacher-forced, but
token by token (i.e. autoregressively), which is slow. When set, this also
assumes that the cache is appropriately populated. Since inputs are padded
on the left with BOS = 0, these values are also the lengths of the prompts.
Returns:
Tuple of:
@@ -1046,15 +1101,40 @@
# We start with a dummy token at the beginning, so extend the maximum length.
max_decode_len += 1

right_aligned_input = None
live_seqs = None
if initial_index is not None:
# Now contains the inputs, but "right aligned" so as to end with the last
# prompt token.
# [batch_size, length]
right_aligned_input = right_align_prompts(inputs)
# Override `inputs` to be just the last prompt token at position 0, with the
# rest of the length dimension as 0s.
inputs = jnp.zeros_like(inputs)
inputs = inputs.at[:, 0].set(right_aligned_input[:, -1])

live_seqs = jnp.zeros((batch_size, max_decode_len), dtype=jnp.int32)
live_seqs = live_seqs.at[:, 0].set(right_aligned_input[:, -1])
live_seqs = jnp.expand_dims(live_seqs, axis=1)
live_seqs = jnp.broadcast_to(
live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1])
)

# initialize beam search state
beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
cache, cache_offset)
beam_search_init_state = beam_init(
batch_size,
beam_size,
max_decode_len,
cache,
cache_offset,
live_seqs=live_seqs,
)

def beam_search_loop_cond_fn(state: BeamState) -> bool:
"""Beam search loop termination condition."""
# Have we reached max decoding length?
# Because we mutate the "i+1" position, we stop one token before the end.
not_at_end = (state.cur_index < max_decode_len - 1)
not_at_end = state.cur_index < max_decode_len - 1

# Is no further progress in the beam search possible?
# Get the best possible scores from alive sequences.
@@ -1134,12 +1214,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_ids = topk_indices % vocab_size
# Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
# PADs this is a no-op.
#
# Also note that when `initial_index` is set, we've already set up the
# inputs so that from position 1 onwards (i.e. state.cur_index + 1 >= 1)
# the tokens are 0 and we'll immediately be "out of prompt".
# --> [batch_size, 1]
next_input_token = jnp.expand_dims(
inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
# --> [batch_size, 1]
out_of_prompt = (next_input_token == 0)

# When forcing prompts, update log probabilities to `0` for the top of the
# beam and -INF for the rest, effectively keeping only one beam alive.
# This is necessary because if two beams had the same prefix, they would
# decode exactly the same sequences, which is redundant.
# --> [batch, 2*beams]
inside_prompt_log_probs = jnp.concatenate([
jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
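The diff truncates this expression at the hunk boundary; as a standalone sketch of the mask it builds (the NEG_INF value and the shapes are assumptions):

import jax.numpy as jnp

NEG_INF = jnp.array(-1.0e7)  # assumed value of the module-level constant
batch_size, beams_to_keep = 2, 4  # beams_to_keep = 2 * beam_size

# Beam 0 keeps log prob 0; every other beam gets NEG_INF while still inside
# the prompt, so only one copy of the forced prefix stays alive.
inside_prompt_log_probs = jnp.concatenate(
    [jnp.zeros((batch_size, 1)),
     jnp.full((batch_size, beams_to_keep - 1), NEG_INF)], axis=1)
# [[ 0.e+00 -1.e+07 -1.e+07 -1.e+07]
#  [ 0.e+00 -1.e+07 -1.e+07 -1.e+07]]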
@@ -1171,6 +1259,18 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
# See if they've exceeded their `max_decode_len` allotment. This is not the
# same as checking whether `state.cur_index` reached the end, since in the
# prompt-prefill mode the variable-length prefixes (prompts) are not counted
# by `cur_index`.
if initial_index is not None:
# --> [batch, 2*beams]
topk_seq_lengths = jnp.sum(topk_seq != 0, axis=-1)
# Total lengths including the initial prompts.
# initial_index[:, None] is shaped [batch_size, 1].
topk_seq_lengths += initial_index[:, None]
# Update `newly_finished` with anything that's beyond its allowed length.
newly_finished |= topk_seq_lengths >= max_decode_len
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -1198,7 +1298,17 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:

# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) # pytype: disable=wrong-arg-types # jax-devicearray
lengths = state.cur_index + 1
if initial_index is not None:
# We should add the lengths of the prompts to the beams as well to
# calculate the brevity penalty correctly.
# initial_index --> [batch_size,]
# topk_lengths --> [batch_size, 2*beams]
topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1)
# lengths is now: [batch_size, 2*beams]
lengths = topk_lengths + lengths

new_scores = topk_log_probs / brevity_penalty(alpha, lengths) # pytype: disable=wrong-arg-types # jax-devicearray
# Mask out the still unfinished sequences by adding large negative value.
# --> [batch, 2*beams]
new_scores += (~newly_finished) * NEG_INF
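To see why the correction matters, a hedged sketch assuming the GNMT-style length normalization ((5 + L) / 6) ** alpha (the actual `brevity_penalty` is defined elsewhere in decoding.py):

import jax.numpy as jnp

def brevity_penalty(alpha, length):
  # Assumed GNMT-style definition, for illustration only.
  return jnp.power((5.0 + length) / 6.0, alpha)

alpha, log_prob = 0.6, jnp.array(-4.0)
decoded_len, prompt_len = 5, 7  # state.cur_index + 1 and initial_index

log_prob / brevity_penalty(alpha, decoded_len)               # ~= -2.94 (prompt ignored)
log_prob / brevity_penalty(alpha, prompt_len + decoded_len)  # ~= -2.14 (full length)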
@@ -1243,5 +1353,43 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
None], final_state.finished_scores,
final_state.live_logprobs)

# Drop the first dummy 0 token.
return finished_seqs[:, :, 1:], finished_scores
# Construct the finished sequences back from the prompts that we kept
# separately in the right aligned buffer.
if initial_index is not None:
# Right now we have right aligned inputs, and then the last tokens +
# completions in finished_seqs. We need to concatenate and get rid of the
# extra padding, while broadcasting in the beam dimension.

# Drop the first token, because it is also in the `right_aligned_input`
# [batch, beams, length]
finished_seqs = finished_seqs[:, :, 1:]
# right_aligned_input is [batch, length_prompt], we need to create a new
# beams dimension and broadcast it along that.
# --> [batch, beams, length]
right_aligned_input = jnp.broadcast_to(
right_aligned_input[:, None, :],
(batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]),
)
# Now concatenate along the length dimension.
# --> [batch, beams, length]
finished_seqs = jnp.concatenate(
[right_aligned_input, finished_seqs], axis=-1
)

# Now we left align everything.

# First flatten to [batch_size * beams, length]
flat_finished_seqs = jnp.reshape(
finished_seqs, (-1, finished_seqs.shape[-1])
)
# Left align everything.
flat_finished_seqs = left_align_prompts(flat_finished_seqs)
# Shape back to the original shape.
left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape)
# Cut to the desired length (-1 because we added 1 right off the bat)
finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1]
else:
# Just drop the first dummy 0 token.
finished_seqs = finished_seqs[:, :, 1:]

return finished_seqs, finished_scores
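A toy walk-through of this reconstruction (the helper is repeated from above so the sketch runs standalone; token values, including EOS = 2, are made up):

import jax
import jax.numpy as jnp

def left_align_prompts(prompts):
  # Same helper as defined earlier in this file.
  max_len = prompts.shape[1]
  indices = jnp.arange(max_len)
  non_padding_indices = (prompts != 0) * indices[None, :]
  m = jnp.where(non_padding_indices, non_padding_indices, max_len + 1)
  prompt_min_index = jnp.argmin(m, axis=1)
  return jax.vmap(lambda x, y: jnp.roll(x, -1 * y, axis=0))(prompts, prompt_min_index)

# Prompt "5 6", right-aligned; a single beam whose decode started from the
# last prompt token (6) and produced 8, 9, EOS.
right_aligned_input = jnp.array([[0, 0, 5, 6]], dtype=jnp.int32)
finished_seqs = jnp.array([[[6, 8, 9, 2, 0]]], dtype=jnp.int32)

finished_seqs = finished_seqs[:, :, 1:]  # drop the token duplicated in the prompt
batch_size, beams = finished_seqs.shape[:2]
right_aligned_input = jnp.broadcast_to(
    right_aligned_input[:, None, :],
    (batch_size, beams, right_aligned_input.shape[-1]))
merged = jnp.concatenate([right_aligned_input, finished_seqs], axis=-1)
# merged == [[[0, 0, 5, 6, 8, 9, 2, 0]]]
flat = merged.reshape(-1, merged.shape[-1])
print(left_align_prompts(flat).reshape(merged.shape))
# [[[5 6 8 9 2 0 0 0]]]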
