From 31cf64147b9ab4a3d68849bef0ea59bdb0c113d6 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 23 May 2024 16:21:47 +0200 Subject: [PATCH] Add a couple kv-cache helper functions. (#2206) --- candle-nn/src/kv_cache.rs | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/candle-nn/src/kv_cache.rs b/candle-nn/src/kv_cache.rs index 684053dc73..10e9fe5abc 100644 --- a/candle-nn/src/kv_cache.rs +++ b/candle-nn/src/kv_cache.rs @@ -47,6 +47,10 @@ impl Cache { self.all_data.narrow(self.dim, 0, self.current_seq_len) } + pub fn reset(&mut self) { + self.current_seq_len = 0 + } + pub fn append(&mut self, src: &Tensor) -> Result<()> { let seq_len = src.dim(self.dim)?; if self.current_seq_len + seq_len > self.max_seq_len { @@ -83,6 +87,22 @@ impl KvCache { Ok(Self { k, v }) } + pub fn k_cache(&self) -> &Cache { + &self.k + } + + pub fn v_cache(&self) -> &Cache { + &self.v + } + + pub fn k_cache_mut(&mut self) -> &mut Cache { + &mut self.k + } + + pub fn v_cache_mut(&mut self) -> &mut Cache { + &mut self.v + } + pub fn k(&self) -> Result { self.k.current_data() } @@ -98,4 +118,13 @@ impl KvCache { let v = self.v.current_data()?; Ok((k, v)) } + + pub fn current_seq_len(&self) -> usize { + self.k.current_seq_len() + } + + pub fn reset(&mut self) { + self.k.reset(); + self.v.reset(); + } }