Reserve KV cache capacity after the first model run #408

Merged 1 commit on Nov 15, 2024
rten-generate/src/generator.rs (113 changes: 99 additions & 14 deletions)
@@ -86,6 +86,76 @@ enum KvCacheData {
     BatchHeadSeqChans(NdTensor<f32, 4>),
 }
 
+impl KvCacheData {
+    /// Allocate a KV cache buffer with the given batch size, number of heads
+    /// and embed size.
+    ///
+    /// The buffer initially has capacity to be extended to a sequence length
+    /// of `seq_len_capacity`.
+    fn with_capacity(
+        batch_size: usize,
+        n_heads: Option<usize>,
+        size: usize,
+        seq_len_capacity: usize,
+    ) -> KvCacheData {
+        if let Some(n_heads) = n_heads {
+            KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
+                [batch_size, n_heads, seq_len_capacity, size],
+                2, /* seq dim */
+            ))
+        } else {
+            KvCacheData::BatchSeqChans(NdTensor::with_capacity(
+                [batch_size, seq_len_capacity, size],
+                1, /* seq dim */
+            ))
+        }
+    }
+
+    /// Return the current sequence length of the cache.
+    fn sequence_len(&self) -> usize {
+        match self {
+            KvCacheData::BatchSeqChans(data) => data.size(1),
+            KvCacheData::BatchHeadSeqChans(data) => data.size(2),
+        }
+    }
+
+    /// Return true if the KV cache has capacity for a given sequence length.
+    fn has_capacity(&self, sequence_len: usize) -> bool {
+        match self {
+            KvCacheData::BatchSeqChans(data) => {
+                data.has_capacity(1 /* seq dim */, sequence_len)
+            }
+            KvCacheData::BatchHeadSeqChans(data) => {
+                data.has_capacity(2 /* seq dim */, sequence_len)
+            }
+        }
+    }
+
+    /// Clone this cache into a new buffer with space to store sequences of
+    /// a given size.
+    fn clone_with_capacity(&self, max_sequence_len: usize) -> KvCacheData {
+        let max_sequence_len = max_sequence_len.max(self.sequence_len());
+        match self {
+            KvCacheData::BatchSeqChans(data) => {
+                let [batch, _seq, chans] = data.shape();
+                let mut new_data =
+                    NdTensor::with_capacity([batch, max_sequence_len, chans], 1 /* seq dim */);
+                new_data.append(1, data).expect("should have capacity");
+                KvCacheData::BatchSeqChans(new_data)
+            }
+            KvCacheData::BatchHeadSeqChans(data) => {
+                let [batch, n_heads, _seq, chans] = data.shape();
+                let mut new_data = NdTensor::with_capacity(
+                    [batch, n_heads, max_sequence_len, chans],
+                    2, /* seq dim */
+                );
+                new_data.append(2, data).expect("should have capacity");
+                KvCacheData::BatchHeadSeqChans(new_data)
+            }
+        }
+    }
+}
+
 /// Key-value cache for a single layer of a transformer model.
 struct KvCache {
     /// Input ID for this cache entry.
@@ -440,23 +510,28 @@ impl<'a> Generator<'a> {
                 .find_node(&output_name)
                 .ok_or(GeneratorError::OutputNotFound(output_name))?;
 
-            // This value should be configurable.
-            let max_seq_len = 512;
+            // Initial sequence length capacity for KV cache buffer.
+            //
+            // For models that execute different operations on the first vs
+            // subsequent iterations (eg. Hugging Face "merged" models with
+            // past and no-past branches) the input buffer may not be used in
+            // the first iteration. Instead we need to reserve capacity once
+            // the model returns the initial KV cache.
+            //
+            // For other simpler models the input KV cache buffer is used for
+            // all iterations, in which case we would ideally reserve capacity
+            // up-front based on the max expected sequence length.
+            let max_seq_len = 1;
 
             let kv_cache_entry = KvCache {
                 input_id,
                 output_id,
-                cache: if let Some(n_heads) = n_heads {
-                    Some(KvCacheData::BatchHeadSeqChans(NdTensor::with_capacity(
-                        [batch_size, n_heads, max_seq_len, size],
-                        2, /* seq dim */
-                    )))
-                } else {
-                    Some(KvCacheData::BatchSeqChans(NdTensor::with_capacity(
-                        [batch_size, max_seq_len, size],
-                        1, /* seq dim */
-                    )))
-                },
+                cache: Some(KvCacheData::with_capacity(
+                    batch_size,
+                    n_heads,
+                    size,
+                    max_seq_len,
+                )),
             };
 
             if kv_pattern.encoder {
Expand Down Expand Up @@ -717,7 +792,7 @@ impl<'a> Generator<'a> {
let output = outputs.remove(0);

let err_context = "failed to save self-attention KV-cache";
let kv_cache = match output.ndim() {
let mut kv_cache = match output.ndim() {
3 => KvCacheData::BatchSeqChans(
output.try_into().map_err(|e| wrap_error(e, err_context))?,
),
@@ -731,6 +806,16 @@
                     ));
                 }
             };
+
+            // Grow the KV cache buffer if it has reached the limit of its
+            // pre-allocated sequence length.
+            //
+            // Double the capacity each time to amortize the costs of copying
+            // the previous buffer.
+            if !kv_cache.has_capacity(kv_cache.sequence_len() + 1) {
+                kv_cache = kv_cache.clone_with_capacity(kv_cache.sequence_len() * 2);
+            }
+
             cache_entry.cache = Some(kv_cache);
         }
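
The growth step in the last hunk doubles the reserved sequence length whenever the cache runs out of room, which amortizes the cost of copying: starting from a capacity of 1, reaching sequence length n takes roughly log2(n) reallocations, and the total number of elements copied across all of them stays below about twice the final size. Below is a minimal standalone sketch of the same grow-on-demand pattern using a plain `Vec` rather than rten's `NdTensor`; it is not part of this diff, and `GrowableSeq` and its methods are hypothetical names chosen for illustration.

/// Illustrative stand-in for a KV cache buffer that grows along one dimension.
struct GrowableSeq {
    data: Vec<f32>,
}

impl GrowableSeq {
    /// Create an empty buffer with room for `seq_len_capacity` steps.
    fn with_capacity(seq_len_capacity: usize) -> GrowableSeq {
        GrowableSeq {
            data: Vec::with_capacity(seq_len_capacity),
        }
    }

    /// Current sequence length.
    fn sequence_len(&self) -> usize {
        self.data.len()
    }

    /// Whether the buffer can hold `sequence_len` steps without reallocating.
    fn has_capacity(&self, sequence_len: usize) -> bool {
        self.data.capacity() >= sequence_len
    }

    /// Append one step, doubling the reserved capacity when it runs out.
    ///
    /// This mirrors the `clone_with_capacity(sequence_len() * 2)` call in the
    /// diff above: the existing contents are copied once into a larger buffer.
    fn append_step(&mut self, value: f32) {
        if !self.has_capacity(self.sequence_len() + 1) {
            let new_capacity = (self.sequence_len() * 2).max(1);
            let mut grown = Vec::with_capacity(new_capacity);
            grown.extend_from_slice(&self.data);
            self.data = grown;
        }
        self.data.push(value);
    }
}

fn main() {
    let mut cache = GrowableSeq::with_capacity(1);
    for step in 0..1000 {
        cache.append_step(step as f32);
    }
    // Capacity grew by doubling (1, 2, 4, ...), so only ~log2(1000) copies happened.
    assert_eq!(cache.sequence_len(), 1000);
    println!("len = {}, capacity = {}", cache.sequence_len(), cache.data.capacity());
}

The real implementation applies the same idea along the tensor's sequence dimension via `KvCacheData::clone_with_capacity`, while the other dimensions (batch, heads, channels) keep their sizes.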
