Skip to content

Commit

Permalink
Add a couple kv-cache helper functions. (#2206)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored May 23, 2024
1 parent 77ea479 commit 31cf641
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions candle-nn/src/kv_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ impl Cache {
self.all_data.narrow(self.dim, 0, self.current_seq_len)
}

pub fn reset(&mut self) {
self.current_seq_len = 0
}

pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
if self.current_seq_len + seq_len > self.max_seq_len {
Expand Down Expand Up @@ -83,6 +87,22 @@ impl KvCache {
Ok(Self { k, v })
}

pub fn k_cache(&self) -> &Cache {
&self.k
}

pub fn v_cache(&self) -> &Cache {
&self.v
}

pub fn k_cache_mut(&mut self) -> &mut Cache {
&mut self.k
}

pub fn v_cache_mut(&mut self) -> &mut Cache {
&mut self.v
}

pub fn k(&self) -> Result<Tensor> {
self.k.current_data()
}
Expand All @@ -98,4 +118,13 @@ impl KvCache {
let v = self.v.current_data()?;
Ok((k, v))
}

pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}

pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}

0 comments on commit 31cf641

Please sign in to comment.