Make use of prefilled cache in beam_search when `predict_batch_with_aux.prompt_with_targets` is set to True

Without this change, targets that are actually prompts are still decoded AR-style in order to build the cache, and the logits are thrown away.

With this change, setting `EncoderDecoder.predict_batch_with_aux.prompt_with_targets` to `True` instead skips over the given prompts altogether, speeding up the decode.
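As an illustration, a minimal sketch of calling the updated `decoding.beam_search` directly; `decoder_inputs`, `prefilled_cache`, `tokens_to_logits_fn`, and `prompt_lengths` are assumed to be supplied by the caller (e.g. from a prefill pass over the prompt):

# Sketch only: decode with an attention cache already prefilled for the prompt.
# `prompt_lengths` is i32[batch]; since `decoder_inputs` starts with BOS = 0,
# the prompt lengths are also the indices from which to resume decoding.
decodes, scores = decoding.beam_search(
    inputs=decoder_inputs,          # i32[batch, len], prompt tokens, 0-padded
    cache=prefilled_cache,          # cache populated by the prefill pass
    tokens_to_logits=tokens_to_logits_fn,
    eos_id=1,
    num_decodes=4,
    initial_index=prompt_lengths,   # skip teacher-forcing the prompt
)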

PiperOrigin-RevId: 559256044
afrozenator authored and t5-copybara committed Oct 4, 2023
1 parent 5107638 commit 340b634
Showing 2 changed files with 386 additions and 27 deletions.
253 changes: 226 additions & 27 deletions t5x/decoding.py
@@ -968,47 +968,131 @@ class BeamState:
finished_flags: jax.Array # bool: [batch_size, beam_size]
# The current state of the autoregressive decoding caches.
cache: PyTree # Any pytree of arrays, e.g. flax attention Cache object
# Optional array of initial indices from which decoding starts; will be all
# zeros if there is no prompt, or None.
initial_index: jax.Array | None


def beam_init(
batch_size: int,
beam_size: int,
max_decode_len: int,
cache: Mapping[str, jnp.ndarray],
offset: int = 0,
live_seqs: Optional[jnp.ndarray] = None,
initial_index: Optional[jnp.ndarray] = None,
) -> BeamState:
"""Initializes the beam search state data structure."""
cur_index0 = jnp.array(0)
live_logprobs0 = jnp.tile(
jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
# If we prefill any part of the prompt, then the initial live sequences are
# provided. In reality these will be the last token of the prompt or BOS if
# the prompt (in the batch) is empty.
live_seqs0 = (
live_seqs
if live_seqs is not None
else jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
)
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
# add beam dimension to attention cache pytree elements
# We will have to expand the cache_index if we're given an initial prompt that
# we prefill.
beam_cache0 = cache_map(
lambda x: add_beam_dim(x, beam_size, offset),
cache,
apply_to_index=live_seqs is not None,
)
return BeamState(
cur_index=cur_index0,
live_logprobs=live_logprobs0,
finished_scores=finished_scores0,
live_seqs=live_seqs0,
finished_seqs=finished_seqs0,
finished_flags=finished_flags0,
cache=beam_cache0,
initial_index=initial_index,
)


def _right_align_prompts(prompts):
"""Right align the prompts."""

# Implementation note:
#
# A very short way to do this right alignment would be to vmap a jnp.roll by
# the amount of padding in each example, i.e. max_len - prompt_max_index
# (+-1); however, this is slow.
#
# A faster way, courtesy of Jeremiah Willcock, is to shift rows by bitmasking
# the shift amount and iterating over the 1, 2, 4, ..., 2**log2(len) bitmasks.
#
# This gives a ~3x speedup over vmapping a roll.

max_len = prompts.shape[1]
nbits = np.ceil(np.log2(max_len)).astype(np.int32)
indices = jnp.arange(max_len)
prompt_max_index = jnp.argmax((prompts != 0) * indices[None, :], axis=1)
shifts = max_len - prompt_max_index - 1
for i in range(0, nbits + 1):
bitmask = 2**i
prompts = jnp.where(
jnp.expand_dims(shifts & bitmask, 1),
jnp.pad(prompts, ((0, 0), (bitmask, 0)))[:, :-bitmask],
prompts,
)
return prompts
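# For example (illustrative):
#   _right_align_prompts(jnp.array([[5, 6, 7, 0, 0],
#                                   [8, 0, 0, 0, 0]]))
#   == [[0, 0, 5, 6, 7],
#       [0, 0, 0, 0, 8]]
# Each row's shift (its padding amount) is decomposed into powers of two, so
# only O(log2(len)) masked shifts are needed regardless of the batch size.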


def _left_align_prompts(prompts):
"""Left align the prompts."""
# See implementation notes in `_right_align_prompts`.

max_len = prompts.shape[1]
# [0, 1, 2, ... L - 1]
indices = jnp.arange(max_len)
# Indices of non-padding positions, 1-based, since `indices` is 0-based.
non_padding_positions = (prompts != 0) * (indices[None, :] + 1)
# Replace all padding with `max_len + 1`
m = jnp.where(non_padding_positions, non_padding_positions, max_len + 1)
# Index of the first prompt token.
shifts = jnp.argmin(m, axis=1)
temp = prompts
nbits = np.ceil(np.log2(max_len)).astype(np.int32)
for i in range(0, nbits + 1):
bitmask = 2**i
temp = jnp.where(
jnp.expand_dims(shifts & bitmask, 1),
jnp.pad(temp, ((0, 0), (0, bitmask)))[:, bitmask:],
temp,
)
return temp
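# For example (illustrative): _left_align_prompts(jnp.array([[0, 0, 5, 6]]))
# == [[5, 6, 0, 0]]. Here `shifts` is the index of the first non-padding
# token (2), and the same power-of-two masked-shift trick moves each row left.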


def _pick_last_prompt_token(prompts):
# prompts: i32[batch, length]
prompt_lengths = jnp.sum(prompts != 0, axis=1)
# return value: i32[batch,]
return prompts[jnp.arange(prompts.shape[0]), prompt_lengths]
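# For example (illustrative): with prompts = [[0, 11, 12, 0]], prompt_lengths
# is 2, and the returned token is prompts[0, 2] == 12. This relies on the
# convention that position 0 holds BOS = 0 and prompts have no internal zeros.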


# Beam search routine:
def beam_search(
inputs: jnp.ndarray,
cache: Mapping[str, jnp.ndarray],
tokens_to_logits: Callable[
[DecodingState], Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]
],
eos_id: int,
num_decodes: int = 4,
alpha: float = 0.6,
max_decode_len: Optional[int] = None,
decode_rng: Optional[jnp.ndarray] = None,
cache_offset: int = 0,
initial_index: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Beam search for transformer machine translation.
If `inputs` has non-zero entries, those values are not modified, i.e.,
@@ -1028,6 +1112,11 @@ def beam_search(inputs: jnp.ndarray,
None, it uses `inputs.shape[1]` as `max_decode_len`.
decode_rng: Unused decoder RNG seed.
cache_offset: axis offset for cache, arising from scanned layers.
initial_index: Optional[jnp.ndarray], the indices from which to start decoding
autoregressively, if set. If unset, the prefix is still teacher-forced, but
token by token through the autoregressive loop (so it will be slow). When
set, this also assumes that the cache is appropriately prefilled. Since
inputs are padded on the left with BOS = 0, these indices are also the
lengths of the prompts.
Returns:
Tuple of:
@@ -1046,15 +1135,55 @@ def beam_search(inputs: jnp.ndarray,
# We start with a dummy token at the beginning, so extend the maximum length.
max_decode_len += 1

right_aligned_input = None
live_seqs = None
if initial_index is not None:
# Now contains the inputs, but "right aligned" so as to end with the last
# prompt token.
# [batch_size, length]
right_aligned_input = _right_align_prompts(inputs)
# `inputs` is now just the last token of the prompt, right-padded to the same
# length as before.
length = inputs.shape[1]
inputs = jnp.pad(
right_aligned_input[:, -1][:, None],
((0, 0), (0, length - 1)),
constant_values=0,
)

# Sized [batch, max_decode_len]
live_seqs = jnp.pad(
right_aligned_input[:, -1][:, None],
((0, 0), (0, max_decode_len - 1)),
constant_values=0,
)
live_seqs = jnp.expand_dims(live_seqs, axis=1)
live_seqs = jnp.broadcast_to(
live_seqs, (live_seqs.shape[0], num_decodes, live_seqs.shape[-1])
)
else:
initial_index = jnp.zeros((batch_size,), dtype=jnp.int32)
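# For example (illustrative), with max_decode_len = 6 and
# inputs = [[0, 11, 12, 0, 0]]: right_aligned_input == [[0, 0, 0, 11, 12]],
# the new inputs == [[12, 0, 0, 0, 0]] (just the last prompt token), and each
# beam's live_seqs row starts as [12, 0, 0, 0, 0, 0], so position 1 onwards
# is immediately "out of prompt".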

# initialize beam search state
beam_search_init_state = beam_init(
batch_size,
beam_size,
max_decode_len,
cache,
cache_offset,
live_seqs=live_seqs,
initial_index=initial_index,
)

def beam_search_loop_cond_fn(state: BeamState) -> bool:
"""Beam search loop termination condition."""
# Have we reached max decoding length?

# Since we might be starting at different points in the prompts, let's use
# the minimum prompt length to stop conservatively.
cur_index = state.cur_index + jnp.min(state.initial_index)
# Because we mutate the "i+1" position, we stop one token before the end.
not_at_end = cur_index < max_decode_len - 1
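# For example (illustrative): with initial_index = [3, 5] and
# max_decode_len = 10, jnp.min(state.initial_index) is 3, so the loop runs
# while state.cur_index + 3 < 9; entries with longer prompts are additionally
# cut off by the per-sequence length check in the loop body.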

# Is no further progress in the beam search possible?
# Get the best possible scores from alive sequences.
@@ -1134,12 +1263,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_ids = topk_indices % vocab_size
# Force decode `inputs` into topk_ids up until PAD. When `inputs` is all
# PADs this is a no-op.
#
# Also note that when `initial_index` is set, we've already set up the
# inputs so that at position 1 onwards (i.e. state.cur_index + 1 >= 1)
# the tokens are 0 and we'll immediately be "out of prompt".
# --> [batch_size, 1]
next_input_token = jnp.expand_dims(
inputs, axis=1).astype(jnp.int32)[:, :, state.cur_index + 1]
# --> [batch_size, 1]
out_of_prompt = (next_input_token == 0)

# When forcing prompts, update log probabilities to `0` for the top of the
# beam and -INF for the rest, effectively keeping only one beam alive.
# This is necessary, because if two beams have the same prefix, then they
# will both decode the exact same sequences and that's redundant.
# --> [batch, 2*beams]
inside_prompt_log_probs = jnp.concatenate([
jnp.zeros((batch_size, 1), dtype=topk_log_probs.dtype),
@@ -1171,6 +1308,19 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
# See if they've exceeded their `max_decode_len` allotment; this is not the
# same as checking whether `state.cur_index` reached the end, since
# `initial_index` contains variable-length prompts whose lengths we need to
# accommodate.
# --> [batch, 2*beams]
topk_seq_lengths = jnp.sum(topk_seq != 0, axis=-1)
# Total lengths along with the initial prompts.
# initial_index[:, None] is shaped [batch_size, 1]
topk_seq_lengths += initial_index[:, None]
# Update `newly_finished` with anything that's beyond its allowed length.
# NOTE: This might bump out some completed sequences, with an incomplete
# sequence if the incomplete sequence's score is higher.
newly_finished |= topk_seq_lengths >= max_decode_len
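# For example (illustrative): with initial_index = [4] and
# max_decode_len = 8, a beam whose topk_seq has 4 non-zero tokens has total
# length 4 + 4 = 8 and is forced into the finished set even if it never
# emitted EOS.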
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -1198,7 +1348,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:

# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
lengths = state.cur_index + 1
# We should add the lengths of the prompts to the decoded lengths as well,
# to calculate the brevity penalty correctly.
# initial_index --> [batch_size,]
# topk_lengths --> [batch_size, 2*beams]
topk_lengths = jnp.repeat(initial_index[:, None], beams_to_keep, axis=1)
# lengths is now: [batch_size, 2*beams]
lengths = topk_lengths + lengths

new_scores = topk_log_probs / brevity_penalty(alpha, lengths) # pytype: disable=wrong-arg-types # jax-devicearray
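# Here brevity_penalty follows the GNMT length normalization,
# ((5 + length) / 6) ** alpha, so including the prompt length in `lengths`
# keeps scores comparable across batch entries with different prompt lengths.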
# Mask out the still unfinished sequences by adding large negative value.
# --> [batch, 2*beams]
new_scores += (~newly_finished) * NEG_INF
@@ -1225,7 +1384,9 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
live_seqs=top_alive_seq,
finished_seqs=top_finished_seq,
finished_flags=top_finished_flags,
cache=top_alive_cache,
initial_index=initial_index,
)

# Run while loop and get final beam search state.
final_state = lax.while_loop(beam_search_loop_cond_fn,
@@ -1243,5 +1404,43 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
None], final_state.finished_scores,
final_state.live_logprobs)

# Reconstruct the finished sequences from the prompts that we kept
# separately in the right-aligned buffer.
if right_aligned_input is not None:
# Right now we have right aligned inputs, and then the last tokens +
# completions in finished_seqs. We need to concatenate and get rid of the
# extra padding, while broadcasting in the beam dimension.

# Drop the first token, because it is also in the `right_aligned_input`
# [batch, beams, length]
finished_seqs = finished_seqs[:, :, 1:]
# right_aligned_input is [batch, length_prompt], we need to create a new
# beams dimension and broadcast it along that.
# --> [batch, beams, length]
right_aligned_input = jnp.broadcast_to(
right_aligned_input[:, None, :],
(batch_size, finished_seqs.shape[1], right_aligned_input.shape[-1]),
)
# Now concatenate along the length dimension.
# --> [batch, beams, length]
finished_seqs = jnp.concatenate(
[right_aligned_input, finished_seqs], axis=-1
)

# Now we left align everything.

# First flatten to [batch_size * beams, length]
flat_finished_seqs = jnp.reshape(
finished_seqs, (-1, finished_seqs.shape[-1])
)
# Left align everything.
flat_finished_seqs = _left_align_prompts(flat_finished_seqs)
# Shape back to the original shape.
left_aligned_seqs = jnp.reshape(flat_finished_seqs, finished_seqs.shape)
# Cut to the desired length (-1 because we added 1 to max_decode_len at the
# start).
finished_seqs = left_aligned_seqs[:, :, : max_decode_len - 1]
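# For example (illustrative): with right_aligned_input = [[0, 0, 11, 12]] and
# a finished beam [12, 13, 14, 1, 0, ...] (the leading 12 is dropped because
# it duplicates the last prompt token), concatenation gives
# [0, 0, 11, 12, 13, 14, 1, 0, ...] and left-aligning yields
# [11, 12, 13, 14, 1, 0, ...] before cutting to max_decode_len - 1.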
else:
# Just drop the first dummy 0 token.
finished_seqs = finished_seqs[:, :, 1:]

return finished_seqs, finished_scores