diff --git a/Cargo.toml b/Cargo.toml index f50a4f2..2ccbf81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "web-rwkv" -version = "0.6.9" +version = "0.6.10" edition = "2021" authors = ["Zhenyuan Zhang "] license = "MIT OR Apache-2.0" diff --git a/examples/batch.rs b/examples/batch.rs index 137c265..adce846 100644 --- a/examples/batch.rs +++ b/examples/batch.rs @@ -6,6 +6,7 @@ use crossterm::terminal::{ }; #[cfg(not(debug_assertions))] use dialoguer::{theme::ColorfulTheme, Select}; +use half::f16; use itertools::Itertools; use memmap2::Mmap; #[cfg(not(debug_assertions))] @@ -148,7 +149,7 @@ async fn run(cli: Cli) -> Result<()> { match info.version { ModelVersion::V4 => { - let model: v4::Model = load_model( + let model: v4::Model = load_model( &context, &map, cli.lora, @@ -167,7 +168,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer, cli.batch).await } ModelVersion::V5 => { - let model: v5::Model = load_model( + let model: v5::Model = load_model( &context, &map, cli.lora, @@ -186,7 +187,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer, cli.batch).await } ModelVersion::V6 => { - let model: v6::Model = load_model( + let model: v6::Model = load_model( &context, &map, cli.lora, diff --git a/examples/chat.rs b/examples/chat.rs index 95a5b04..6ab8ab6 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -2,6 +2,7 @@ use anyhow::Result; use clap::{Args, Parser, ValueEnum}; #[cfg(not(debug_assertions))] use dialoguer::{theme::ColorfulTheme, Select}; +use half::f16; use itertools::Itertools; use memmap2::Mmap; use serde::Deserialize; @@ -188,7 +189,7 @@ async fn run(cli: Cli) -> Result<()> { match info.version { ModelVersion::V4 => { - let model: v4::Model = load_model( + let model: v4::Model = load_model( &context, &map, cli.lora, @@ -202,7 +203,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer, prompt, sampler).await } ModelVersion::V5 => { - let model: v5::Model = load_model( + let model: v5::Model = load_model( &context, &map, cli.lora, @@ -216,7 +217,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer, prompt, sampler).await } ModelVersion::V6 => { - let model: v6::Model = load_model( + let model: v6::Model = load_model( &context, &map, cli.lora, diff --git a/examples/gen.rs b/examples/gen.rs index 00ef645..41efe44 100644 --- a/examples/gen.rs +++ b/examples/gen.rs @@ -2,6 +2,7 @@ use anyhow::Result; use clap::{Parser, ValueEnum}; #[cfg(not(debug_assertions))] use dialoguer::{theme::ColorfulTheme, Select}; +use half::f16; #[cfg(not(debug_assertions))] use itertools::Itertools; use memmap2::Mmap; @@ -122,7 +123,7 @@ async fn run(cli: Cli) -> Result<()> { match info.version { ModelVersion::V4 => { - let model: v4::Model = load_model( + let model: v4::Model = load_model( &context, &map, cli.lora, @@ -136,7 +137,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer).await } ModelVersion::V5 => { - let model: v5::Model = load_model( + let model: v5::Model = load_model( &context, &map, cli.lora, @@ -150,7 +151,7 @@ async fn run(cli: Cli) -> Result<()> { run_internal(model, state, tokenizer).await } ModelVersion::V6 => { - let model: v6::Model = load_model( + let model: v6::Model = load_model( &context, &map, cli.lora, diff --git a/examples/inspector.rs b/examples/inspector.rs index 1000c1c..6186570 100644 --- a/examples/inspector.rs +++ b/examples/inspector.rs @@ -147,7 +147,7 @@ async fn run(cli: Cli) -> Result<()> { println!("{:#?}", info); let context = create_context(&info).await?; - let model: v5::Model = load_model( + let model: v5::Model = load_model( &context, &map, cli.lora, @@ -168,7 +168,11 @@ async fn run(cli: Cli) -> Result<()> { hooks.insert( v5::Hook::PostFfn(layer), Box::new( - move |_model, _state, runtime: &v5::Runtime| -> Result { + move |_model, + _state, + runtime: &v5::Runtime<_>, + _header| + -> Result { // figure out how many tokens this run has let shape = runtime.ffn_x.shape(); let num_token = shape[1]; @@ -222,7 +226,7 @@ async fn run(cli: Cli) -> Result<()> { &tensor.head.layer_norm.b, &buffer.ffn_x, None, - v5::Model::LN_EPS, + v5::Model::::LN_EPS, )?, tensor.head.w.matmul_mat_op( buffer.ffn_x.view(.., .., .., ..)?, diff --git a/src/model/mod.rs b/src/model/mod.rs index 5afded2..4a04fd0 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -163,12 +163,8 @@ pub trait ModelState: for<'a> FromBuilder = StateBuilder, Error = In } pub trait ModelBase { - type ModelTensor; - fn context(&self) -> &Context; fn info(&self) -> &ModelInfo; - - fn tensor(&self) -> &Self::ModelTensor; } pub trait Model: diff --git a/src/model/run.rs b/src/model/run.rs index e546531..36f509b 100644 --- a/src/model/run.rs +++ b/src/model/run.rs @@ -9,6 +9,7 @@ use super::{ }; use crate::{ context::Context, + num::{Float, Hom}, tensor::{ kind::{ReadBack, ReadWrite}, ops::TensorOp, @@ -18,13 +19,13 @@ use crate::{ }; #[derive(Debug)] -pub struct Header { - pub head_x: TensorGpu, +pub struct Header { + pub head_x: TensorGpu, pub head_o: TensorGpu, pub map: TensorGpu, } -impl Header { +impl Header { pub fn new(context: &Context, info: &ModelInfo, num_batch: usize) -> Self { let head_shape = Shape::new(info.num_emb, num_batch, 1, 1); let output_shape = Shape::new(info.num_vocab, num_batch, 1, 1); @@ -37,16 +38,20 @@ impl Header { } } -pub type HookMap = - HashMap Result>>; +pub type HookMap = + HashMap Result>>; pub(crate) trait ModelRunInternal: ModelBase { type Hook: Hash; type State: ModelState; + type Tensor; type Runtime; + type Header; + + fn tensor(&self) -> &Self::Tensor; fn checkout_runtime(&self, num_batch: usize) -> Arc; - fn checkout_header(&self, num_batch: usize) -> Arc
; + fn checkout_header(&self, num_batch: usize) -> Arc; /// To prevent the GPU device from lost, this limits the maximum batch-token it processes one time. fn token_chunk_size(&self) -> usize; @@ -60,14 +65,14 @@ pub(crate) trait ModelRunInternal: ModelBase { tokens: Vec>, state: &Self::State, outputs: Vec>, - hooks: &HookMap, + hooks: &HookMap, ) -> Result<(TensorGpu, Vec>), TensorError>; - fn create_input<'a>( + fn create_input<'a, F: Float>( &self, embed: &TensorCpu<'a, f16>, tokens: &[Vec], - ) -> Result, TensorError> { + ) -> Result, TensorError> { let info = self.info(); let context = self.context(); @@ -80,7 +85,8 @@ pub(crate) trait ModelRunInternal: ModelBase { .map(|&token| embed.slice(.., token as usize, .., ..)) .try_collect()?, ) - .unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1))); + .unwrap_or_else(|_| context.zeros(Shape::new(info.num_emb, 1, 0, 1))) + .map(|x| Hom::co_hom(*x)); stack.reshape( TensorDimension::Full, TensorDimension::Auto, @@ -96,7 +102,11 @@ pub(crate) trait ModelRunInternal: ModelBase { pub trait ModelRun { type Hook: Hash; type State: ModelState; + type Tensor; type Runtime; + type Header; + + fn tensor(&self) -> &Self::Tensor; /// Run the model for a batch of tokens as input. /// The length of `tokens` must match the number of batches in `state`. @@ -110,23 +120,37 @@ pub trait ModelRun { /// Run the model for a batch of tokens as input, but with custom hooks. /// The length of `tokens` must match the number of batches in `state`. /// `tokens` may have slots with no tokens, for which `run` won't compute that batch and will return an empty vector in that corresponding slot. + #[allow(clippy::type_complexity)] fn run_with_hooks( &self, tokens: &mut Vec, state: &Self::State, - hooks: &HookMap, + hooks: &HookMap, ) -> impl Future, TensorError>>; } -impl ModelRun for Model +impl ModelRun for Model where Hook: Hash, - Model: ModelRunInternal, State: super::ModelState, + Model: ModelRunInternal< + Hook = Hook, + Tensor = Tensor, + State = State, + Runtime = Runtime, + Header = Header, + >, { type Hook = Hook; - type Runtime = Runtime; type State = State; + type Tensor = Tensor; + type Runtime = Runtime; + type Header = Header; + + #[inline] + fn tensor(&self) -> &Self::Tensor { + ::tensor(self) + } async fn run( &self, @@ -141,7 +165,7 @@ where &self, tokens: &mut Vec, state: &Self::State, - hooks: &HookMap, + hooks: &HookMap, ) -> Result, TensorError> { let num_token: usize = tokens.iter().map(|input| input.tokens.len()).sum(); let num_batch = state.num_batch(); diff --git a/src/model/v4.rs b/src/model/v4.rs index 5a0821a..266e891 100644 --- a/src/model/v4.rs +++ b/src/model/v4.rs @@ -15,6 +15,7 @@ use super::{ use crate::{ context::Context, model::RESCALE_LAYER, + num::{Float, Hom}, tensor::{ cache::ResourceCache, kind::{ReadBack, ReadWrite}, @@ -26,7 +27,7 @@ use crate::{ }; #[derive(Debug)] -pub struct Model<'a> { +pub struct Model<'a, F: Float> { context: Context, info: ModelInfo, @@ -36,8 +37,8 @@ pub struct Model<'a> { token_chunk_size: usize, tensor: ModelTensor<'a>, - runtime_cache: ResourceCache, - header_cache: ResourceCache, + runtime_cache: ResourceCache>, + header_cache: ResourceCache>, softmax_cache: ResourceCache, } @@ -102,31 +103,31 @@ pub struct Head { /// Runtime buffers. #[derive(Debug)] -pub struct Runtime { +pub struct Runtime { pub tokens: TensorGpu, pub cursors: TensorGpu, - pub input: TensorGpu, + pub input: TensorGpu, - pub att_x: TensorGpu, - pub att_kx: TensorGpu, - pub att_vx: TensorGpu, - pub att_rx: TensorGpu, + pub att_x: TensorGpu, + pub att_kx: TensorGpu, + pub att_vx: TensorGpu, + pub att_rx: TensorGpu, pub att_k: TensorGpu, pub att_v: TensorGpu, pub att_r: TensorGpu, - pub att_o: TensorGpu, + pub att_o: TensorGpu, - pub ffn_x: TensorGpu, - pub ffn_kx: TensorGpu, - pub ffn_rx: TensorGpu, - pub ffn_k: TensorGpu, - pub ffn_v: TensorGpu, - pub ffn_r: TensorGpu, + pub ffn_x: TensorGpu, + pub ffn_kx: TensorGpu, + pub ffn_rx: TensorGpu, + pub ffn_k: TensorGpu, + pub ffn_v: TensorGpu, + pub ffn_r: TensorGpu, pub aux_x: TensorGpu, } -impl Runtime { +impl Runtime { pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self { let shape = Shape::new(info.num_emb, num_token, 1, 1); let tokens_shape = Shape::new(num_token, 1, 1, 1); @@ -399,12 +400,12 @@ impl super::BackedState for BackedState { } } -impl<'a> Model<'a> { +impl<'a, F: Float> Model<'a, F> { pub const LN_EPS: f32 = 1.0e-5; pub const GN_EPS: f32 = 64.0e-5; } -impl FromBuilder for Model<'_> { +impl FromBuilder for Model<'_, F> { type Builder<'a> = ModelBuilder<'a>; type Error = anyhow::Error; @@ -562,9 +563,7 @@ impl FromBuilder for Model<'_> { } } -impl<'a> ModelBase for Model<'a> { - type ModelTensor = ModelTensor<'a>; - +impl<'a, F: Float> ModelBase for Model<'a, F> { #[inline] fn context(&self) -> &Context { &self.context @@ -574,14 +573,9 @@ impl<'a> ModelBase for Model<'a> { fn info(&self) -> &ModelInfo { &self.info } - - #[inline] - fn tensor(&self) -> &Self::ModelTensor { - &self.tensor - } } -impl ModelSoftmaxInternal for Model<'_> { +impl ModelSoftmaxInternal for Model<'_, F> { #[inline] fn checkout_softmax(&self, num_batch: usize) -> Arc { self.softmax_cache.checkout(num_batch, || { @@ -590,20 +584,27 @@ impl ModelSoftmaxInternal for Model<'_> { } } -impl ModelRunInternal for Model<'_> { +impl<'a, F: Float + Hom> ModelRunInternal for Model<'a, F> { type Hook = Hook; - type Runtime = Runtime; type State = ModelState; + type Tensor = ModelTensor<'a>; + type Runtime = Runtime; + type Header = Header; + + #[inline] + fn tensor(&self) -> &Self::Tensor { + &self.tensor + } #[inline] - fn checkout_runtime(&self, num_token: usize) -> Arc { + fn checkout_runtime(&self, num_token: usize) -> Arc { self.runtime_cache.checkout(num_token, || { Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size) }) } #[inline] - fn checkout_header(&self, num_batch: usize) -> Arc
{ + fn checkout_header(&self, num_batch: usize) -> Arc { self.header_cache.checkout(num_batch, || { Header::new(&self.context, &self.info, num_batch) }) @@ -624,7 +625,7 @@ impl ModelRunInternal for Model<'_> { tokens: Vec>, state: &ModelState, outputs: Vec>, - hooks: &HookMap, + hooks: &HookMap, ) -> Result<(TensorGpu, Vec>), TensorError> { let context = &self.context; let tensor = &self.tensor; @@ -666,7 +667,7 @@ impl ModelRunInternal for Model<'_> { let hook_op = |hook: Hook| -> Result { hooks .get(&hook) - .map(|f| f(self, state, &buffer)) + .map(|f| f(&self.tensor, state, &buffer, &header)) .unwrap_or_else(|| Ok(TensorOp::List(vec![]))) }; diff --git a/src/model/v5.rs b/src/model/v5.rs index 1e46c9a..8e9887b 100644 --- a/src/model/v5.rs +++ b/src/model/v5.rs @@ -14,6 +14,7 @@ use super::{ use crate::{ context::Context, model::{OutputType, RESCALE_LAYER}, + num::Float, tensor::{ cache::ResourceCache, kind::{ReadBack, ReadWrite}, @@ -26,7 +27,7 @@ use crate::{ }; #[derive(Debug)] -pub struct Model<'a> { +pub struct Model<'a, F: Float> { context: Context, info: ModelInfo, @@ -36,8 +37,8 @@ pub struct Model<'a> { token_chunk_size: usize, tensor: ModelTensor<'a>, - runtime_cache: ResourceCache, - header_cache: ResourceCache, + runtime_cache: ResourceCache>, + header_cache: ResourceCache>, softmax_cache: ResourceCache, } @@ -106,33 +107,33 @@ pub struct Head { /// Runtime buffers. #[derive(Debug)] -pub struct Runtime { +pub struct Runtime { pub tokens: TensorGpu, pub cursors: TensorGpu, - pub input: TensorGpu, + pub input: TensorGpu, - pub att_x: TensorGpu, - pub att_kx: TensorGpu, - pub att_vx: TensorGpu, - pub att_rx: TensorGpu, - pub att_gx: TensorGpu, + pub att_x: TensorGpu, + pub att_kx: TensorGpu, + pub att_vx: TensorGpu, + pub att_rx: TensorGpu, + pub att_gx: TensorGpu, pub att_k: TensorGpu, pub att_v: TensorGpu, pub att_r: TensorGpu, - pub att_g: TensorGpu, - pub att_o: TensorGpu, + pub att_g: TensorGpu, + pub att_o: TensorGpu, - pub ffn_x: TensorGpu, - pub ffn_kx: TensorGpu, - pub ffn_rx: TensorGpu, - pub ffn_k: TensorGpu, - pub ffn_v: TensorGpu, - pub ffn_r: TensorGpu, + pub ffn_x: TensorGpu, + pub ffn_kx: TensorGpu, + pub ffn_rx: TensorGpu, + pub ffn_k: TensorGpu, + pub ffn_v: TensorGpu, + pub ffn_r: TensorGpu, pub aux_x: TensorGpu, } -impl Runtime { +impl Runtime { pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self { let shape = Shape::new(info.num_emb, num_token, 1, 1); let tokens_shape = Shape::new(num_token, 1, 1, 1); @@ -463,12 +464,12 @@ impl super::BackedState for BackedState { } } -impl<'a> Model<'a> { +impl<'a, F: Float> Model<'a, F> { pub const LN_EPS: f32 = 1.0e-5; pub const GN_EPS: f32 = 64.0e-5; } -impl<'a> FromBuilder for Model<'a> { +impl<'a, F: Float> FromBuilder for Model<'a, F> { type Builder<'b> = ModelBuilder<'b>; type Error = anyhow::Error; @@ -649,9 +650,7 @@ impl<'a> FromBuilder for Model<'a> { } } -impl<'a> ModelBase for Model<'a> { - type ModelTensor = ModelTensor<'a>; - +impl<'a, F: Float> ModelBase for Model<'a, F> { #[inline] fn context(&self) -> &Context { &self.context @@ -661,14 +660,9 @@ impl<'a> ModelBase for Model<'a> { fn info(&self) -> &ModelInfo { &self.info } - - #[inline] - fn tensor(&self) -> &Self::ModelTensor { - &self.tensor - } } -impl ModelSoftmaxInternal for Model<'_> { +impl ModelSoftmaxInternal for Model<'_, F> { #[inline] fn checkout_softmax(&self, num_batch: usize) -> Arc { self.softmax_cache.checkout(num_batch, || { @@ -677,20 +671,27 @@ impl ModelSoftmaxInternal for Model<'_> { } } -impl ModelRunInternal for Model<'_> { +impl<'a, F: Float> ModelRunInternal for Model<'a, F> { type Hook = Hook; - type Runtime = Runtime; type State = ModelState; + type Tensor = ModelTensor<'a>; + type Runtime = Runtime; + type Header = Header; + + #[inline] + fn tensor(&self) -> &Self::Tensor { + &self.tensor + } #[inline] - fn checkout_runtime(&self, num_token: usize) -> Arc { + fn checkout_runtime(&self, num_token: usize) -> Arc { self.runtime_cache.checkout(num_token, || { Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size) }) } #[inline] - fn checkout_header(&self, num_batch: usize) -> Arc
{ + fn checkout_header(&self, num_batch: usize) -> Arc { self.header_cache.checkout(num_batch, || { Header::new(&self.context, &self.info, num_batch) }) @@ -711,7 +712,7 @@ impl ModelRunInternal for Model<'_> { tokens: Vec>, state: &ModelState, outputs: Vec>, - hooks: &HookMap, + hooks: &HookMap, ) -> Result<(TensorGpu, Vec>), TensorError> { let context = &self.context; let tensor = &self.tensor; @@ -754,7 +755,7 @@ impl ModelRunInternal for Model<'_> { let hook_op = |hook: Hook| -> Result { hooks .get(&hook) - .map(|f| f(self, state, &buffer)) + .map(|f| f(&self.tensor, state, &buffer, &header)) .unwrap_or_else(|| Ok(TensorOp::List(vec![]))) }; diff --git a/src/model/v6.rs b/src/model/v6.rs index 5172fe5..18751ff 100644 --- a/src/model/v6.rs +++ b/src/model/v6.rs @@ -14,6 +14,7 @@ use super::{ use crate::{ context::Context, model::RESCALE_LAYER, + num::Float, tensor::{ cache::ResourceCache, kind::{ReadBack, ReadWrite}, @@ -26,7 +27,7 @@ use crate::{ }; #[derive(Debug)] -pub struct Model<'a> { +pub struct Model<'a, F: Float> { context: Context, info: ModelInfo, @@ -36,8 +37,8 @@ pub struct Model<'a> { token_chunk_size: usize, tensor: ModelTensor<'a>, - runtime_cache: ResourceCache, - header_cache: ResourceCache, + runtime_cache: ResourceCache>, + header_cache: ResourceCache>, softmax_cache: ResourceCache, } @@ -113,55 +114,55 @@ pub struct Head { /// Runtime buffers. #[derive(Debug)] -pub struct Runtime { +pub struct Runtime { pub tokens: TensorGpu, pub cursors: TensorGpu, - pub input: TensorGpu, + pub input: TensorGpu, - pub att_x: TensorGpu, - pub att_xx: TensorGpu, + pub att_x: TensorGpu, + pub att_xx: TensorGpu, /// Token shifted time decay input, `[C, T]`. - pub att_wx: TensorGpu, - pub att_kx: TensorGpu, - pub att_vx: TensorGpu, - pub att_rx: TensorGpu, - pub att_gx: TensorGpu, + pub att_wx: TensorGpu, + pub att_kx: TensorGpu, + pub att_vx: TensorGpu, + pub att_rx: TensorGpu, + pub att_gx: TensorGpu, /// Time decay LoRA intermediate, `[64, T]`. - pub att_w: TensorGpu, + pub att_w: TensorGpu, pub att_k: TensorGpu, pub att_v: TensorGpu, pub att_r: TensorGpu, - pub att_g: TensorGpu, - pub att_o: TensorGpu, + pub att_g: TensorGpu, + pub att_o: TensorGpu, /// Token shift LoRA intermediate, `[32, 5, T]`. - pub time_mix_x: TensorGpu, + pub time_mix_x: TensorGpu, /// Token shift LoRA intermediate transposed, `[32, T, 5]`. - pub time_mix_t: TensorGpu, + pub time_mix_t: TensorGpu, /// Token shift LoRA output, `[C, T, 5]`. - pub time_mix: TensorGpu, + pub time_mix: TensorGpu, pub time_decay: TensorGpu, - pub ffn_x: TensorGpu, - pub ffn_kx: TensorGpu, - pub ffn_rx: TensorGpu, - pub ffn_k: TensorGpu, - pub ffn_v: TensorGpu, - pub ffn_r: TensorGpu, + pub ffn_x: TensorGpu, + pub ffn_kx: TensorGpu, + pub ffn_rx: TensorGpu, + pub ffn_k: TensorGpu, + pub ffn_v: TensorGpu, + pub ffn_r: TensorGpu, pub aux_x: TensorGpu, } -impl Runtime { +impl Runtime { pub fn new(context: &Context, info: &ModelInfo, num_token: usize, max_token: usize) -> Self { let shape = Shape::new(info.num_emb, num_token, 1, 1); let tokens_shape = Shape::new(num_token, 1, 1, 1); let cursors_shape = Shape::new(max_token, 1, 1, 1); let hidden_shape = Shape::new(info.num_hidden, num_token, 1, 1); let time_mix_shape = Shape::new(info.num_emb, num_token, 5, 1); - let time_mix_x_shape = Shape::new(Model::TIME_MIX_ADAPTER_SIZE, 5, num_token, 1); - let time_mix_t_shape = Shape::new(Model::TIME_MIX_ADAPTER_SIZE, num_token, 5, 1); - let time_decay_shape = Shape::new(Model::TIME_DECAY_ADAPTER_SIZE, num_token, 1, 1); + let time_mix_x_shape = Shape::new(Model::::TIME_MIX_ADAPTER_SIZE, 5, num_token, 1); + let time_mix_t_shape = Shape::new(Model::::TIME_MIX_ADAPTER_SIZE, num_token, 5, 1); + let time_decay_shape = Shape::new(Model::::TIME_DECAY_ADAPTER_SIZE, num_token, 1, 1); Self { tokens: context.tensor_init(tokens_shape), @@ -504,7 +505,7 @@ impl super::BackedState for BackedState { } } -impl<'a> Model<'a> { +impl<'a, F: Float> Model<'a, F> { const TIME_MIX_ADAPTER_SIZE: usize = 32; const TIME_DECAY_ADAPTER_SIZE: usize = 64; @@ -512,7 +513,7 @@ impl<'a> Model<'a> { pub const GN_EPS: f32 = 64.0e-5; } -impl<'a> FromBuilder for Model<'a> { +impl<'a, F: Float> FromBuilder for Model<'a, F> { type Builder<'b> = ModelBuilder<'b>; type Error = anyhow::Error; @@ -707,9 +708,7 @@ impl<'a> FromBuilder for Model<'a> { } } -impl<'a> ModelBase for Model<'a> { - type ModelTensor = ModelTensor<'a>; - +impl<'a, F: Float> ModelBase for Model<'a, F> { #[inline] fn context(&self) -> &Context { &self.context @@ -719,14 +718,9 @@ impl<'a> ModelBase for Model<'a> { fn info(&self) -> &ModelInfo { &self.info } - - #[inline] - fn tensor(&self) -> &Self::ModelTensor { - &self.tensor - } } -impl ModelSoftmaxInternal for Model<'_> { +impl ModelSoftmaxInternal for Model<'_, F> { #[inline] fn checkout_softmax(&self, num_batch: usize) -> Arc { self.softmax_cache.checkout(num_batch, || { @@ -735,20 +729,27 @@ impl ModelSoftmaxInternal for Model<'_> { } } -impl ModelRunInternal for Model<'_> { +impl<'a, F: Float> ModelRunInternal for Model<'a, F> { type Hook = Hook; - type Runtime = Runtime; type State = ModelState; + type Tensor = ModelTensor<'a>; + type Runtime = Runtime; + type Header = Header; + + #[inline] + fn tensor(&self) -> &Self::Tensor { + &self.tensor + } #[inline] - fn checkout_runtime(&self, num_token: usize) -> Arc { + fn checkout_runtime(&self, num_token: usize) -> Arc { self.runtime_cache.checkout(num_token, || { Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size) }) } #[inline] - fn checkout_header(&self, num_batch: usize) -> Arc
{ + fn checkout_header(&self, num_batch: usize) -> Arc { self.header_cache.checkout(num_batch, || { Header::new(&self.context, &self.info, num_batch) }) @@ -769,7 +770,7 @@ impl ModelRunInternal for Model<'_> { tokens: Vec>, state: &ModelState, outputs: Vec>, - hooks: &HookMap, + hooks: &HookMap, ) -> Result<(TensorGpu, Vec>), TensorError> { let context = &self.context; let tensor = &self.tensor; @@ -812,7 +813,7 @@ impl ModelRunInternal for Model<'_> { let hook_op = |hook: Hook| -> Result { hooks .get(&hook) - .map(|f| f(self, state, &buffer)) + .map(|f| f(&self.tensor, state, &buffer, &header)) .unwrap_or_else(|| Ok(TensorOp::List(vec![]))) }; diff --git a/src/num.rs b/src/num.rs index 7921ae9..916a902 100644 --- a/src/num.rs +++ b/src/num.rs @@ -95,7 +95,7 @@ impl Scalar for u32 { const DATA_TYPE: Dtype = Dtype::U32; } -pub trait Float: Scalar { +pub trait Float: Scalar + Hom + Hom { const DEF: &'static str; } @@ -107,6 +107,51 @@ impl Float for f16 { const DEF: &'static str = "FP16"; } +pub trait Hom { + fn hom(self) -> T; + fn co_hom(value: T) -> Self; +} + +impl Hom for f32 { + fn hom(self) -> f32 { + self + } + + fn co_hom(value: f32) -> Self { + value + } +} + +impl Hom for f32 { + fn hom(self) -> f16 { + f16::from_f32(self) + } + + fn co_hom(value: f16) -> Self { + value.to_f32() + } +} + +impl Hom for f16 { + fn hom(self) -> f32 { + self.to_f32() + } + + fn co_hom(value: f32) -> Self { + Self::from_f32(value) + } +} + +impl Hom for f16 { + fn hom(self) -> f16 { + self + } + + fn co_hom(value: f16) -> Self { + value + } +} + mod sealed { use half::f16; diff --git a/src/shaders/matmul_mat_fp16.wgsl b/src/shaders/matmul_mat_fp16.wgsl index 464d559..62c2866 100644 --- a/src/shaders/matmul_mat_fp16.wgsl +++ b/src/shaders/matmul_mat_fp16.wgsl @@ -16,7 +16,11 @@ struct Input { @group(0) @binding(2) var destination: View; // [M, N, B] @group(0) @binding(3) var xa: array>; // (B, M, K) +#ifdef IN_FP16 @group(0) @binding(4) var xb: array>; // (B, N, K) +#else +@group(0) @binding(4) var xb: array>; // (B, N, K) +#endif #ifdef OUT_FP16 @group(0) @binding(5) var output: array>; // (B, N, M) #else @@ -26,7 +30,11 @@ struct Input { const TILE_SIZE: u32 = BLOCK_SIZE * 4u; var sa: array, BLOCK_SIZE>, TILE_SIZE>; +#ifdef IN_FP16 var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#else +var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#endif fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 { let stride = view.stride.x >> 2u; @@ -73,7 +81,11 @@ fn matmul(in: Input) { if all(vec2(x, y) < rb) { sb[j][i] = xb[compute_index(vb, in.uid.z, y, x)]; } else { +#ifdef IN_FP16 sb[j][i] = vec2(0u); +#else + sb[j][i] = vec4(0.0); +#endif } } workgroupBarrier(); @@ -90,12 +102,21 @@ fn matmul(in: Input) { unpack4x16float(sa[t.x + 2u][x]), unpack4x16float(sa[t.x + 3u][x]), ); +#ifdef IN_FP16 let bb = mat4x4( unpack4x16float(sb[t.y][x]), unpack4x16float(sb[t.y + 1u][x]), unpack4x16float(sb[t.y + 2u][x]), unpack4x16float(sb[t.y + 3u][x]), ); +#else + let bb = mat4x4( + sb[t.y][x], + sb[t.y + 1u][x], + sb[t.y + 2u][x], + sb[t.y + 3u][x], + ); +#endif local_sum += transpose(aa) * bb; } } diff --git a/src/shaders/matmul_mat_int8.wgsl b/src/shaders/matmul_mat_int8.wgsl index 6851520..d2099ec 100644 --- a/src/shaders/matmul_mat_int8.wgsl +++ b/src/shaders/matmul_mat_int8.wgsl @@ -17,7 +17,11 @@ struct Input { @group(0) @binding(3) var minmax: array; @group(0) @binding(4) var xa: array; // (B, M, K) +#ifdef IN_FP16 @group(0) @binding(5) var xb: array>; // (B, N, K) +#else +@group(0) @binding(5) var xb: array>; // (B, N, K) +#endif #ifdef OUT_FP16 @group(0) @binding(6) var output: array>; // (B, N, M) #else @@ -28,7 +32,11 @@ const TILE_SIZE: u32 = BLOCK_SIZE * 4u; const INT8_BLOCK_STEP: u32 = INT8_BLOCK_SIZE / 4u; var sa: array, TILE_SIZE>; +#ifdef IN_FP16 var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#else +var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#endif fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 { let stride = view.stride.x >> 2u; @@ -80,7 +88,11 @@ fn matmul(in: Input) { if all(vec2(x, y) < rb) { sb[j][i] = xb[compute_index(vb, in.uid.z, y, x)]; } else { +#ifdef IN_FP16 sb[j][i] = vec2(0u); +#else + sb[j][i] = vec4(0.0); +#endif } } workgroupBarrier(); @@ -104,12 +116,21 @@ fn matmul(in: Input) { fma(unpack4x8unorm(sa[t.x + 2u][x]), vec4(b[2][1] - b[2][0]), vec4(b[2][0])), fma(unpack4x8unorm(sa[t.x + 3u][x]), vec4(b[3][1] - b[3][0]), vec4(b[3][0])), ); +#ifdef IN_FP16 let bb = mat4x4( unpack4x16float(sb[t.y][x]), unpack4x16float(sb[t.y + 1u][x]), unpack4x16float(sb[t.y + 2u][x]), unpack4x16float(sb[t.y + 3u][x]), ); +#else + let bb = mat4x4( + sb[t.y][x], + sb[t.y + 1u][x], + sb[t.y + 2u][x], + sb[t.y + 3u][x], + ); +#endif local_sum += transpose(aa) * bb; } } diff --git a/src/shaders/matmul_mat_nf4.wgsl b/src/shaders/matmul_mat_nf4.wgsl index ac7e411..ea6747c 100644 --- a/src/shaders/matmul_mat_nf4.wgsl +++ b/src/shaders/matmul_mat_nf4.wgsl @@ -18,7 +18,11 @@ struct Input { @group(0) @binding(4) var absmax: array; @group(0) @binding(5) var xa: array; // (B, M, K) +#ifdef IN_FP16 @group(0) @binding(6) var xb: array>; // (B, N, K) +#else +@group(0) @binding(6) var xb: array>; // (B, N, K) +#endif #ifdef OUT_FP16 @group(0) @binding(7) var output: array>; // (B, N, M) #else @@ -29,7 +33,11 @@ const TILE_SIZE: u32 = BLOCK_SIZE * 4u; const NF4_BLOCK_STEP: u32 = NF4_BLOCK_SIZE / 8u; var sa: array, TILE_SIZE>; +#ifdef IN_FP16 var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#else +var sb: array, BLOCK_SIZE>, TILE_SIZE>; +#endif var q: array, 4u>; fn compute_index(view: View, batch: u32, token: u32, index: u32, step: u32) -> u32 { @@ -117,7 +125,11 @@ fn matmul(in: Input) { if all(vec2(x, y) < rb) { sb[j][i] = xb[compute_index(vb, in.uid.z, y, x, 8u)]; } else { +#ifdef IN_FP16 sb[j][i] = vec4(0u); +#else + sb[j][i] = mat2x4(); +#endif } } workgroupBarrier(); @@ -148,12 +160,21 @@ fn matmul(in: Input) { a[2] * unpack_matrix_0(la[2]), a[3] * unpack_matrix_0(la[3]), ); +#ifdef IN_FP16 var bb = mat4x4( unpack4x16float(sb[t.y][x].xy), unpack4x16float(sb[t.y + 1u][x].xy), unpack4x16float(sb[t.y + 2u][x].xy), unpack4x16float(sb[t.y + 3u][x].xy), ); +#else + var bb = mat4x4( + sb[t.y][x][0], + sb[t.y + 1u][x][0], + sb[t.y + 2u][x][0], + sb[t.y + 3u][x][0], + ); +#endif local_sum += transpose(aa) * bb; aa = mat4x4( @@ -162,12 +183,21 @@ fn matmul(in: Input) { a[2] * unpack_matrix_1(la[2]), a[3] * unpack_matrix_1(la[3]), ); +#ifdef IN_FP16 bb = mat4x4( unpack4x16float(sb[t.y][x].zw), unpack4x16float(sb[t.y + 1u][x].zw), unpack4x16float(sb[t.y + 2u][x].zw), unpack4x16float(sb[t.y + 3u][x].zw), ); +#else + bb = mat4x4( + sb[t.y][x][1], + sb[t.y + 1u][x][1], + sb[t.y + 2u][x][1], + sb[t.y + 3u][x][1], + ); +#endif local_sum += transpose(aa) * bb; } } diff --git a/src/shaders/matmul_vec_int8.wgsl b/src/shaders/matmul_vec_int8.wgsl index 6079953..fe6c4e1 100644 --- a/src/shaders/matmul_vec_int8.wgsl +++ b/src/shaders/matmul_vec_int8.wgsl @@ -11,7 +11,11 @@ struct View { @group(0) @binding(3) var matrix: array; // (B, R, C) @group(0) @binding(4) var minmax: array; +#ifdef IN_FP16 @group(0) @binding(5) var input: array>; // (B, T, C) +#else +@group(0) @binding(5) var input: array>; // (B, T, C) +#endif #ifdef OUT_FP16 @group(0) @binding(6) var output: array>; // (B, T, R) #else diff --git a/src/shaders/matmul_vec_nf4.wgsl b/src/shaders/matmul_vec_nf4.wgsl index d0eb5b3..d483e40 100644 --- a/src/shaders/matmul_vec_nf4.wgsl +++ b/src/shaders/matmul_vec_nf4.wgsl @@ -12,7 +12,11 @@ struct View { @group(0) @binding(4) var matrix: array; // (B, R, C) @group(0) @binding(5) var absmax: array; +#ifdef IN_FP16 @group(0) @binding(6) var input: array>; // (B, T, C) +#else +@group(0) @binding(6) var input: array>; // (B, T, C) +#endif #ifdef OUT_FP16 @group(0) @binding(7) var output: array>; // (B, T, R) #else @@ -120,18 +124,22 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3) { m[2] = unpack_matrix_0(v[2]); m[3] = unpack_matrix_0(v[3]); m = transpose(m); - // var s = transpose(m) * unpack4x16float(x.xy); +#ifdef IN_FP16 local_sum = fma(m * unpack4x16float(x.xy), a, local_sum); +#else + local_sum = fma(m * x[0], a, local_sum); +#endif m[0] = unpack_matrix_1(v[0]); m[1] = unpack_matrix_1(v[1]); m[2] = unpack_matrix_1(v[2]); m[3] = unpack_matrix_1(v[3]); m = transpose(m); - // s += transpose(m) * unpack4x16float(x.zw); +#ifdef IN_FP16 local_sum = fma(m * unpack4x16float(x.zw), a, local_sum); - - // local_sum = fma(s, a, local_sum); +#else + local_sum = fma(m * x[1], a, local_sum); +#endif } sketch[index] = local_sum; workgroupBarrier(); diff --git a/src/tensor/matrix.rs b/src/tensor/matrix.rs index 0d2f30e..b7a8bcc 100644 --- a/src/tensor/matrix.rs +++ b/src/tensor/matrix.rs @@ -28,7 +28,7 @@ pub enum Matrix { impl Matrix { pub fn matmul_vec_op( &self, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result { @@ -41,7 +41,7 @@ impl Matrix { pub fn matmul_mat_op( &self, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result { @@ -60,7 +60,7 @@ impl Matrix { pub fn matmul_op( &self, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, turbo: bool, diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index e87ea0c..587f54e 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -540,7 +540,7 @@ impl TensorOp { matrix: &TensorGpu, quant: &TensorGpu, absmax: &TensorGpu, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result { @@ -622,7 +622,7 @@ impl TensorOp { /// Note: `K` must be multiples of 128; `M` and `N` must be multiples of 4. pub fn matmul_mat_fp16( matrix: TensorView, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result { @@ -695,7 +695,7 @@ impl TensorOp { pub fn matmul_mat_int8( matrix: TensorView, minmax: &TensorGpu, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result { @@ -779,7 +779,7 @@ impl TensorOp { matrix: TensorView, quant: &TensorGpu, absmax: &TensorGpu, - input: TensorView, + input: TensorView, output: TensorView, active: Activation, ) -> Result {