diff --git a/Cargo.toml b/Cargo.toml
index d2fae57..36e7853 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
 license = "MIT OR Apache-2.0"
 name = "web-rwkv"
 repository = "https://github.com/cryscan/web-rwkv"
-version = "0.6.24"
+version = "0.6.25"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
diff --git a/src/context.rs b/src/context.rs
index 7968843..340461a 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -1,4 +1,8 @@
-use std::{borrow::Cow, sync::Arc};
+use std::{
+    borrow::Cow,
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
 
 use thiserror::Error;
 use wasm_bindgen::prelude::wasm_bindgen;
@@ -6,9 +10,9 @@ use web_rwkv_derive::{Deref, DerefMut};
 use wgpu::{
     util::{BufferInitDescriptor, DeviceExt},
     Adapter, Backends, BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer,
-    BufferUsages, ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, Features,
-    Limits, PipelineLayoutDescriptor, PowerPreference, Queue, RequestAdapterOptions,
-    ShaderModuleDescriptor,
+    BufferDescriptor, BufferUsages, ComputePipeline, ComputePipelineDescriptor, Device,
+    DeviceDescriptor, Features, Limits, PipelineLayoutDescriptor, PowerPreference, Queue,
+    RequestAdapterOptions, ShaderModuleDescriptor,
 };
 
 use crate::{
@@ -76,10 +80,8 @@ pub struct ContextInternal {
     pub device: Device,
     pub queue: Queue,
 
-    pipeline_cache: ResourceCache<PipelineKey, CachedPipeline>,
-
-    shape_cache: ResourceCache<Shape, Buffer>,
-    view_cache: ResourceCache<View, Buffer>,
+    pipeline_cache: Mutex<HashMap<PipelineKey, Arc<CachedPipeline>>>,
+    buffer_cache: ResourceCache<usize, Buffer>,
 }
 
 #[derive(Debug, Clone, Deref, DerefMut)]
@@ -134,9 +136,8 @@ impl<'a> ContextBuilder {
             adapter,
             device,
             queue,
-            pipeline_cache: ResourceCache::new(0),
-            shape_cache: Default::default(),
-            view_cache: Default::default(),
+            pipeline_cache: Default::default(),
+            buffer_cache: Default::default(),
         }
         .into(),
         ))
@@ -234,58 +235,111 @@ impl Context {
         let mut context = Context::new();
         context.macros = macros.0.into_iter().collect();
 
-        self.pipeline_cache.checkout(key, move || {
-            let shader = process_str(source.as_ref(), &mut context).unwrap();
-            let module = &self.device.create_shader_module(ShaderModuleDescriptor {
-                label: Some(name),
-                source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
-            });
+        let mut cache = self.pipeline_cache.lock().unwrap();
+        match cache.get(&key) {
+            Some(pipeline) => pipeline.clone(),
+            None => {
+                let shader = process_str(source.as_ref(), &mut context).unwrap();
+                let module = &self.device.create_shader_module(ShaderModuleDescriptor {
+                    label: Some(name),
+                    source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
+                });
+
+                let layout = layout.map(|entries| {
+                    let layout = self
+                        .device
+                        .create_bind_group_layout(&BindGroupLayoutDescriptor {
+                            label: None,
+                            entries,
+                        });
+                    self.device
+                        .create_pipeline_layout(&PipelineLayoutDescriptor {
+                            label: None,
+                            bind_group_layouts: &[&layout],
+                            push_constant_ranges: &[],
+                        })
+                });
 
-        let layout = layout.map(|entries| {
-            let layout = self
+                let pipeline = self
                     .device
-                .create_bind_group_layout(&BindGroupLayoutDescriptor {
-                    label: None,
-                    entries,
+                    .create_compute_pipeline(&ComputePipelineDescriptor {
+                        label: Some(name),
+                        layout: layout.as_ref(),
+                        module,
+                        entry_point,
                     });
-            self.device
-                .create_pipeline_layout(&PipelineLayoutDescriptor {
-                    label: None,
-                    bind_group_layouts: &[&layout],
-                    push_constant_ranges: &[],
-                })
-        });
-
-        let pipeline = self
-            .device
-            .create_compute_pipeline(&ComputePipelineDescriptor {
-                label: Some(name),
-                layout: layout.as_ref(),
-                module,
-                entry_point,
-            });
-        let layout = pipeline.get_bind_group_layout(0);
-        CachedPipeline { pipeline, layout }
-        })
+                let layout = pipeline.get_bind_group_layout(0);
+                let pipeline = Arc::new(CachedPipeline { pipeline, layout });
+
+                cache.insert(key, pipeline.clone());
+                pipeline
+            }
+        }
     }
 
     pub fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
-        self.shape_cache.checkout(shape, || {
-            self.device.create_buffer_init(&BufferInitDescriptor {
-                label: None,
-                contents: &shape.into_bytes(),
-                usage: BufferUsages::UNIFORM,
-            })
-        })
+        let view = View {
+            shape,
+            stride: shape,
+            offset: Shape::new(0, 0, 0, 0),
+        };
+        self.buffer_cache.checkout(
+            shape.into_bytes().len(),
+            |buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
+            || {
+                self.device.create_buffer_init(&BufferInitDescriptor {
+                    label: None,
+                    contents: &view.into_bytes(),
+                    usage: BufferUsages::UNIFORM,
+                })
+            },
+        )
     }
 
     pub fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
-        self.view_cache.checkout(view, || {
-            self.device.create_buffer_init(&BufferInitDescriptor {
-                label: None,
-                contents: &view.into_bytes(),
-                usage: BufferUsages::UNIFORM,
-            })
-        })
+        self.buffer_cache.checkout(
+            view.into_bytes().len(),
+            |buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
+            || {
+                self.device.create_buffer_init(&BufferInitDescriptor {
+                    label: None,
+                    contents: &view.into_bytes(),
+                    usage: BufferUsages::UNIFORM,
+                })
+            },
+        )
+    }
+
+    pub fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
+        self.buffer_cache.checkout(
+            contents.len(),
+            |buffer| {
+                if usage.contains(BufferUsages::STORAGE) {
+                    self.queue.write_buffer(buffer, 0, contents);
+                }
+            },
+            || {
+                self.device.create_buffer_init(&BufferInitDescriptor {
+                    label: None,
+                    contents,
+                    usage,
+                })
+            },
+        )
+    }
+
+    pub fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
+        self.buffer_cache.checkout(
+            size,
+            |_| (),
+            || {
+                self.device.create_buffer(&BufferDescriptor {
+                    label: None,
+                    size: size as u64,
+                    usage,
+                    mapped_at_creation: false,
+                })
+            },
+        )
     }
 }
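Note: with this change every uniform and storage buffer goes through the shared, size-keyed `buffer_cache`; a cache hit rewrites the reused buffer's contents via `queue.write_buffer` instead of allocating. A minimal call-site sketch, assuming an already-built `Context` (the shapes, sizes, and usages below are illustrative, not part of this patch):

use web_rwkv::{context::Context, tensor::shape::Shape};
use wgpu::BufferUsages;

fn checkout_examples(context: &Context) {
    // Uniform buffer carrying a shape; on a pool hit the existing buffer
    // is rewritten in place instead of being reallocated.
    let shape_uniform = context.checkout_shape_uniform(Shape::new(1024, 1, 1, 1));

    // Buffers are pooled by byte size; on a hit, contents are rewritten
    // only for STORAGE usage (see `checkout_buffer_init` above).
    let init = context.checkout_buffer_init(&[0u8; 256], BufferUsages::STORAGE);

    // Raw allocation without initial contents.
    let raw = context.checkout_buffer(256, BufferUsages::STORAGE | BufferUsages::COPY_DST);

    drop((shape_uniform, init, raw)); // the Arcs return to the pool for reuse
}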
diff --git a/src/model/loader.rs b/src/model/loader.rs
index ba69e1f..45e4ce3 100644
--- a/src/model/loader.rs
+++ b/src/model/loader.rs
@@ -11,7 +11,6 @@ use super::{ModelError, ModelInfo, ModelVersion, Quant};
 use crate::{
     context::Context,
     tensor::{
-        cache::ResourceCache,
         kind::ReadWrite,
         matrix::Matrix,
         ops::{TensorCommand, TensorOp, TensorPass},
@@ -575,24 +574,19 @@ impl<R: Reader> Loader<R> {
         Ok(head)
     }
 
-    pub async fn load_matrix(
-        &self,
-        cache: &ResourceCache<Shape, TensorGpu<f16, ReadWrite>>,
-        name: String,
-        quant: Quant,
-    ) -> Result<Matrix> {
+    pub async fn load_matrix(&self, name: String, quant: Quant) -> Result<Matrix> {
         let context = &self.context;
         match quant {
             Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16(name).await?)),
             Quant::Int8 => {
                 let shape = self.tensor_shape(&name)?;
-                let buffer = cache.checkout(shape, || context.tensor_init(shape));
+                let buffer = context.tensor_init(shape);
                 self.load_in_place_matrix_f16(&buffer, &name).await?;
                 Ok(Matrix::quant_u8(&buffer)?)
             }
             Quant::NF4 => {
                 let shape = self.tensor_shape(&name)?;
-                let buffer = cache.checkout(shape, || context.tensor_init(shape));
+                let buffer = context.tensor_init(shape);
                 self.load_in_place_matrix_f16(&buffer, &name).await?;
                 Ok(Matrix::quant_nf4(&buffer)?)
             }
@@ -601,7 +595,6 @@ impl<R: Reader> Loader<R> {
 
     pub async fn load_matrix_discount(
         &self,
-        cache: &ResourceCache<Shape, TensorGpu<f16, ReadWrite>>,
         name: String,
         quant: Quant,
         discount: f32,
@@ -613,14 +606,14 @@ impl<R: Reader> Loader<R> {
         )),
             Quant::Int8 => {
                 let shape = self.tensor_shape(&name)?;
-                let buffer = cache.checkout(shape, || context.tensor_init(shape));
+                let buffer = context.tensor_init(shape);
                 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
                     .await?;
                 Ok(Matrix::quant_u8(&buffer)?)
             }
             Quant::NF4 => {
                 let shape = self.tensor_shape(&name)?;
-                let buffer = cache.checkout(shape, || context.tensor_init(shape));
+                let buffer = context.tensor_init(shape);
                 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
                     .await?;
                 Ok(Matrix::quant_nf4(&buffer)?)
diff --git a/src/model/run.rs b/src/model/run.rs
index 36f509b..7d9be59 100644
--- a/src/model/run.rs
+++ b/src/model/run.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};
+use std::{collections::HashMap, future::Future, hash::Hash};
 
 use anyhow::Result;
 use half::f16;
@@ -50,8 +50,8 @@ pub(crate) trait ModelRunInternal: ModelBase {
 
     fn tensor(&self) -> &Self::Tensor;
 
-    fn checkout_runtime(&self, num_batch: usize) -> Arc<Self::Runtime>;
-    fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header>;
+    fn checkout_runtime(&self, num_batch: usize) -> Self::Runtime;
+    fn checkout_header(&self, num_batch: usize) -> Self::Header;
 
     /// To prevent the GPU device from being lost, this limits the maximum number of batch-tokens processed at one time.
     fn token_chunk_size(&self) -> usize;
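Since `load_matrix` and `load_matrix_discount` no longer thread a shape-keyed `ResourceCache` through every call, loader call sites shrink accordingly; `tensor_init` now draws from the context's buffer pool instead. A hedged sketch of the new call shape (the tensor name is illustrative and `loader` construction is elided):

use anyhow::Result;
use web_rwkv::model::{
    loader::{Loader, Reader},
    Quant,
};

async fn load_one(loader: &Loader<impl Reader>) -> Result<()> {
    // Before this patch the call was `loader.load_matrix(&cache, name, quant)`;
    // the cache parameter is gone, reuse happens inside the Context.
    let _matrix = loader
        .load_matrix("blocks.0.att.key.weight".into(), Quant::Int8)
        .await?;
    Ok(())
}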
diff --git a/src/model/softmax.rs b/src/model/softmax.rs
index a9455ab..9914bb1 100644
--- a/src/model/softmax.rs
+++ b/src/model/softmax.rs
@@ -1,4 +1,4 @@
-use std::{future::Future, sync::Arc};
+use std::future::Future;
 
 use anyhow::Result;
 use itertools::Itertools;
@@ -30,10 +30,6 @@ impl Softmax {
     }
 }
 
-pub(crate) trait ModelSoftmaxInternal: ModelBase {
-    fn checkout_softmax(&self, num_batch: usize) -> Arc<Softmax>;
-}
-
 pub trait ModelSoftmax {
     /// Softmax of the input tensors.
     fn softmax(
@@ -42,7 +38,7 @@ pub trait ModelSoftmax {
     ) -> impl Future<Output = Result<Vec<TensorCpu<f32>>, TensorError>>;
 }
 
-impl<M: ModelSoftmaxInternal> ModelSoftmax for M {
+impl<M: ModelBase> ModelSoftmax for M {
     async fn softmax(&self, input: Vec<TensorCpu<'_, f32>>) -> Result<Vec<TensorCpu<'_, f32>>, TensorError> {
         let context = self.context();
         let info = self.info();
@@ -78,7 +74,7 @@ impl<M: ModelSoftmaxInternal> ModelSoftmax for M {
         )?;
 
         let num_batch = input.shape()[2];
-        let softmax = self.checkout_softmax(num_batch);
+        let softmax = Softmax::new(context, info, num_batch);
         softmax.buffer.load(&input)?;
 
         let op = TensorOp::softmax(&softmax.buffer)?;
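Replacing the `ModelSoftmaxInternal` plumbing with a blanket `impl<M: ModelBase> ModelSoftmax for M` means softmax needs neither a per-version impl nor a cached `Softmax` buffer on the model; the per-call `Softmax::new` allocates through the pooled context buffers. A usage sketch (bounds as in this patch; `model` and `logits` come from the caller):

use web_rwkv::{model::softmax::ModelSoftmax, tensor::TensorCpu};

async fn print_softmax<M: ModelSoftmax>(model: &M, logits: Vec<TensorCpu<'_, f32>>) {
    // Works for any model version and float type, no extra bound needed.
    match model.softmax(logits).await {
        Ok(probs) => println!("softmaxed {} batches", probs.len()),
        Err(err) => eprintln!("softmax failed: {err}"),
    }
}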
diff --git a/src/model/v4.rs b/src/model/v4.rs
index 41f047b..0897f4c 100644
--- a/src/model/v4.rs
+++ b/src/model/v4.rs
@@ -1,4 +1,4 @@
-use std::{convert::Infallible, sync::Arc};
+use std::{convert::Infallible, marker::PhantomData};
 
 use anyhow::Result;
 use half::f16;
@@ -9,7 +9,6 @@ use web_rwkv_derive::{Deref, DerefMut};
 use super::{
     loader::Reader,
     run::{Header, HookMap, ModelRunInternal},
-    softmax::{ModelSoftmaxInternal, Softmax},
     Build, BuildFuture, ModelBase, ModelBuilder, ModelInfo, OutputType, PreparedModelBuilder,
     Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE,
 };
@@ -18,12 +17,12 @@ use crate::{
     model::RESCALE_LAYER,
     num::{Float, Hom},
     tensor::{
-        cache::ResourceCache,
         kind::{ReadBack, ReadWrite},
         matrix::Matrix,
         ops::{Activation, TensorCommand, TensorOp, TensorPass},
         shape::Shape,
-        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorShape, TensorView,
+        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView,
+        TensorShape,
     },
 };
 
@@ -38,9 +37,7 @@ pub struct Model<'a, F: Float> {
     token_chunk_size: usize,
     tensor: ModelTensor<'a>,
 
-    runtime_cache: ResourceCache<usize, Runtime<F>>,
-    header_cache: ResourceCache<usize, Header<F>>,
-    softmax_cache: ResourceCache<usize, Softmax>,
+    _phantom: PhantomData<F>,
 }
 
 #[derive(Debug)]
@@ -197,13 +194,13 @@ impl ModelState {
         self.0.context()
     }
 
-    fn att(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn att(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let start = 5 * layer;
         let end = start + 4;
         self.view(.., start..end, .., ..)
     }
 
-    fn ffn(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn ffn(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let start = 5 * layer + 4;
         self.view(.., start..=start, .., ..)
     }
@@ -441,19 +438,15 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
         context.queue.submit(None);
         context.device.poll(wgpu::MaintainBase::Wait);
 
-        let cache = ResourceCache::<Shape, TensorGpu<f16, ReadWrite>>::new(0);
-        let load_matrix = |name: String, quant: Quant| loader.load_matrix(&cache, name, quant);
+        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
         let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
-            loader.load_matrix_discount(&cache, name, quant, discount)
+            loader.load_matrix_discount(name, quant, discount)
         };
 
         let mut layers = vec![];
         for layer in 0..info.num_layer {
             let quant = quant.get(&layer).copied().unwrap_or_default();
             let discount = 2.0_f32.powi(-((layer / RESCALE_LAYER) as i32));
-            if matches!(quant, Quant::None) {
-                cache.clear();
-            }
 
             let att_layer_norm = LayerNorm {
                 w: loader
@@ -531,9 +524,7 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
             turbo,
             token_chunk_size,
             tensor,
-            runtime_cache: ResourceCache::new(1),
-            header_cache: ResourceCache::new(1),
-            softmax_cache: ResourceCache::new(1),
+            _phantom: PhantomData,
         })
     }
 }
@@ -550,15 +541,6 @@ impl<'a, F: Float> ModelBase for Model<'a, F> {
     }
 }
 
-impl<F: Float> ModelSoftmaxInternal for Model<'_, F> {
-    #[inline]
-    fn checkout_softmax(&self, num_batch: usize) -> Arc<Softmax> {
-        self.softmax_cache.checkout(num_batch, || {
-            Softmax::new(&self.context, &self.info, num_batch)
-        })
-    }
-}
-
 impl<'a, F: Float + Hom<f16>> ModelRunInternal for Model<'a, F> {
     type Hook = Hook;
     type State = ModelState;
@@ -572,17 +554,13 @@ impl<'a, F: Float + Hom<f16>> ModelRunInternal for Model<'a, F> {
     }
 
     #[inline]
-    fn checkout_runtime(&self, num_token: usize) -> Arc<Self::Runtime> {
-        self.runtime_cache.checkout(num_token, || {
-            Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
-        })
+    fn checkout_runtime(&self, num_token: usize) -> Self::Runtime {
+        Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
     }
 
     #[inline]
-    fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header> {
-        self.header_cache.checkout(num_batch, || {
-            Header::new(&self.context, &self.info, num_batch)
-        })
+    fn checkout_header(&self, num_batch: usize) -> Self::Header {
+        Header::new(&self.context, &self.info, num_batch)
     }
 
     #[inline]
diff --git a/src/model/v5.rs b/src/model/v5.rs
index dbbd226..63cce50 100644
--- a/src/model/v5.rs
+++ b/src/model/v5.rs
@@ -1,4 +1,4 @@
-use std::{convert::Infallible, sync::Arc};
+use std::{convert::Infallible, marker::PhantomData};
 
 use anyhow::Result;
 use half::f16;
@@ -8,7 +8,6 @@ use serde::{Deserialize, Serialize};
 use super::{
     loader::Reader,
     run::{Header, HookMap, ModelRunInternal},
-    softmax::{ModelSoftmaxInternal, Softmax},
     Build, BuildFuture, ModelBase, ModelBuilder, ModelInfo, PreparedModelBuilder, Quant,
     StateBuilder, MIN_TOKEN_CHUNK_SIZE,
 };
@@ -17,13 +16,12 @@ use crate::{
     model::{OutputType, RESCALE_LAYER},
     num::Float,
     tensor::{
-        cache::ResourceCache,
         kind::{ReadBack, ReadWrite},
         matrix::Matrix,
         ops::{Activation, TensorCommand, TensorOp, TensorPass},
         shape::{Shape, TensorDimension},
-        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorReshape,
-        TensorShape, TensorView,
+        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView,
+        TensorReshape, TensorShape,
     },
 };
 
@@ -38,9 +36,7 @@ pub struct Model<'a, F: Float> {
     token_chunk_size: usize,
     tensor: ModelTensor<'a>,
 
-    runtime_cache: ResourceCache<usize, Runtime<F>>,
-    header_cache: ResourceCache<usize, Header<F>>,
-    softmax_cache: ResourceCache<usize, Softmax>,
+    _phantom: PhantomData<F>,
 }
 
 #[derive(Debug)]
@@ -208,7 +204,7 @@ pub struct ModelState {
 }
 
 impl ModelState {
-    fn att(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn att(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let chunk = layer / self.chunk_size;
         let offset = layer % self.chunk_size;
         let head_size = self.info.num_emb / self.info.num_head;
@@ -218,7 +214,7 @@ impl ModelState {
         self.state[chunk].view(.., start..end, .., ..)
     }
 
-    fn ffn(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn ffn(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let chunk = layer / self.chunk_size;
         let offset = layer % self.chunk_size;
         let head_size = self.info.num_emb / self.info.num_head;
@@ -505,19 +501,15 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
         context.queue.submit(None);
         context.device.poll(wgpu::MaintainBase::Wait);
 
-        let cache = ResourceCache::<Shape, TensorGpu<f16, ReadWrite>>::new(0);
-        let load_matrix = |name: String, quant: Quant| loader.load_matrix(&cache, name, quant);
+        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
         let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
-            loader.load_matrix_discount(&cache, name, quant, discount)
+            loader.load_matrix_discount(name, quant, discount)
         };
 
         let mut layers = vec![];
         for layer in 0..info.num_layer {
             let quant = quant.get(&layer).copied().unwrap_or_default();
             let discount = 2.0_f32.powi(-((layer / RESCALE_LAYER) as i32));
-            if matches!(quant, Quant::None) {
-                cache.clear();
-            }
 
             let att_layer_norm = LayerNorm {
                 w: loader
@@ -620,9 +612,7 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
             turbo,
             token_chunk_size,
             tensor,
-            runtime_cache: ResourceCache::new(1),
-            header_cache: ResourceCache::new(1),
-            softmax_cache: ResourceCache::new(1),
+            _phantom: PhantomData,
         })
     }
 }
@@ -639,15 +629,6 @@ impl<'a, F: Float> ModelBase for Model<'a, F> {
     }
 }
 
-impl<F: Float> ModelSoftmaxInternal for Model<'_, F> {
-    #[inline]
-    fn checkout_softmax(&self, num_batch: usize) -> Arc<Softmax> {
-        self.softmax_cache.checkout(num_batch, || {
-            Softmax::new(&self.context, &self.info, num_batch)
-        })
-    }
-}
-
 impl<'a, F: Float> ModelRunInternal for Model<'a, F> {
     type Hook = Hook;
     type State = ModelState;
@@ -661,17 +642,13 @@ impl<'a, F: Float> ModelRunInternal for Model<'a, F> {
     }
 
     #[inline]
-    fn checkout_runtime(&self, num_token: usize) -> Arc<Self::Runtime> {
-        self.runtime_cache.checkout(num_token, || {
-            Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
-        })
+    fn checkout_runtime(&self, num_token: usize) -> Self::Runtime {
+        Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
     }
 
     #[inline]
-    fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header> {
-        self.header_cache.checkout(num_batch, || {
-            Header::new(&self.context, &self.info, num_batch)
-        })
+    fn checkout_header(&self, num_batch: usize) -> Self::Header {
+        Header::new(&self.context, &self.info, num_batch)
     }
 
     #[inline]
diff --git a/src/model/v6.rs b/src/model/v6.rs
index 6ba1324..586cf2b 100644
--- a/src/model/v6.rs
+++ b/src/model/v6.rs
@@ -1,4 +1,4 @@
-use std::{convert::Infallible, sync::Arc};
+use std::{convert::Infallible, marker::PhantomData};
 
 use anyhow::Result;
 use half::f16;
@@ -8,7 +8,6 @@ use serde::{Deserialize, Serialize};
 use super::{
     loader::Reader,
     run::{Header, HookMap, ModelRunInternal},
-    softmax::{ModelSoftmaxInternal, Softmax},
     Build, BuildFuture, ModelBase, ModelBuilder, ModelInfo, OutputType, PreparedModelBuilder,
     Quant, StateBuilder, MIN_TOKEN_CHUNK_SIZE,
 };
@@ -17,13 +16,12 @@ use crate::{
     model::RESCALE_LAYER,
     num::Float,
     tensor::{
-        cache::ResourceCache,
         kind::{ReadBack, ReadWrite},
         matrix::Matrix,
         ops::{Activation, TensorCommand, TensorOp, TensorPass},
         shape::{Shape, TensorDimension},
-        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorReshape,
-        TensorShape, TensorView,
+        DeepClone, IntoPackedCursors, TensorCpu, TensorError, TensorGpu, TensorGpuView,
+        TensorReshape, TensorShape,
     },
 };
 
@@ -38,9 +36,7 @@ pub struct Model<'a, F: Float> {
     token_chunk_size: usize,
     tensor: ModelTensor<'a>,
 
-    runtime_cache: ResourceCache<usize, Runtime<F>>,
-    header_cache: ResourceCache<usize, Header<F>>,
-    softmax_cache: ResourceCache<usize, Softmax>,
+    _phantom: PhantomData<F>,
 }
 
 #[derive(Debug)]
@@ -249,7 +245,7 @@ pub struct ModelState {
 }
 
 impl ModelState {
-    fn att(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn att(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let chunk = layer / self.chunk_size;
         let offset = layer % self.chunk_size;
         let head_size = self.info.num_emb / self.info.num_head;
@@ -259,7 +255,7 @@ impl ModelState {
         self.state[chunk].view(.., start..end, .., ..)
     }
 
-    fn ffn(&self, layer: usize) -> Result<TensorView<f32>, TensorError> {
+    fn ffn(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
         let chunk = layer / self.chunk_size;
         let offset = layer % self.chunk_size;
         let head_size = self.info.num_emb / self.info.num_head;
@@ -549,19 +545,15 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
         context.queue.submit(None);
         context.device.poll(wgpu::MaintainBase::Wait);
 
-        let cache = ResourceCache::<Shape, TensorGpu<f16, ReadWrite>>::new(0);
-        let load_matrix = |name: String, quant: Quant| loader.load_matrix(&cache, name, quant);
+        let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant);
         let load_matrix_discount = |name: String, quant: Quant, discount: f32| {
-            loader.load_matrix_discount(&cache, name, quant, discount)
+            loader.load_matrix_discount(name, quant, discount)
         };
 
         let mut layers = vec![];
         for layer in 0..info.num_layer {
             let quant = quant.get(&layer).copied().unwrap_or_default();
             let discount = 2.0_f32.powi(-((layer / RESCALE_LAYER) as i32));
-            if matches!(quant, Quant::None) {
-                cache.clear();
-            }
 
             let att_layer_norm = LayerNorm {
                 w: loader
@@ -680,9 +672,7 @@ impl<'a, R: Reader, F: Float> BuildFuture<Model<'a, F>> for ModelBuilder<R> {
             turbo,
             token_chunk_size,
             tensor,
-            runtime_cache: ResourceCache::new(1),
-            header_cache: ResourceCache::new(1),
-            softmax_cache: ResourceCache::new(1),
+            _phantom: PhantomData,
         })
     }
 }
@@ -699,15 +689,6 @@ impl<'a, F: Float> ModelBase for Model<'a, F> {
     }
 }
 
-impl<F: Float> ModelSoftmaxInternal for Model<'_, F> {
-    #[inline]
-    fn checkout_softmax(&self, num_batch: usize) -> Arc<Softmax> {
-        self.softmax_cache.checkout(num_batch, || {
-            Softmax::new(&self.context, &self.info, num_batch)
-        })
-    }
-}
-
 impl<'a, F: Float> ModelRunInternal for Model<'a, F> {
     type Hook = Hook;
     type State = ModelState;
@@ -721,17 +702,13 @@ impl<'a, F: Float> ModelRunInternal for Model<'a, F> {
     }
 
     #[inline]
-    fn checkout_runtime(&self, num_token: usize) -> Arc<Self::Runtime> {
-        self.runtime_cache.checkout(num_token, || {
-            Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
-        })
+    fn checkout_runtime(&self, num_token: usize) -> Self::Runtime {
+        Runtime::new(&self.context, &self.info, num_token, self.token_chunk_size)
     }
 
     #[inline]
-    fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header> {
-        self.header_cache.checkout(num_batch, || {
-            Header::new(&self.context, &self.info, num_batch)
-        })
+    fn checkout_header(&self, num_batch: usize) -> Self::Header {
+        Header::new(&self.context, &self.info, num_batch)
     }
 
     #[inline]
diff --git a/src/shaders/binary.wgsl b/src/shaders/binary.wgsl
index d25ce7b..d5411af 100644
--- a/src/shaders/binary.wgsl
+++ b/src/shaders/binary.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> source: View;
diff --git a/src/shaders/blend.wgsl b/src/shaders/blend.wgsl
index 80d12f6..85b5c3a 100644
--- a/src/shaders/blend.wgsl
+++ b/src/shaders/blend.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> source: View;
diff --git a/src/shaders/blend_lora.wgsl b/src/shaders/blend_lora.wgsl
index 637a6d5..4067c7f 100644
--- a/src/shaders/blend_lora.wgsl
+++ b/src/shaders/blend_lora.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Input {
diff --git a/src/shaders/blit.wgsl b/src/shaders/blit.wgsl
index 00ba9a0..49d0af5 100644
--- a/src/shaders/blit.wgsl
+++ b/src/shaders/blit.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> source: View;
diff --git a/src/shaders/channel_mix.wgsl b/src/shaders/channel_mix.wgsl
index 0ebe746..dc56754 100644
--- a/src/shaders/channel_mix.wgsl
+++ b/src/shaders/channel_mix.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Cursor {
diff --git a/src/shaders/matmul_mat_fp16.wgsl b/src/shaders/matmul_mat_fp16.wgsl
index 62c2866..7e12c12 100644
--- a/src/shaders/matmul_mat_fp16.wgsl
+++ b/src/shaders/matmul_mat_fp16.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Input {
diff --git a/src/shaders/matmul_mat_int8.wgsl b/src/shaders/matmul_mat_int8.wgsl
index d2099ec..226485d 100644
--- a/src/shaders/matmul_mat_int8.wgsl
+++ b/src/shaders/matmul_mat_int8.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Input {
diff --git a/src/shaders/matmul_mat_nf4.wgsl b/src/shaders/matmul_mat_nf4.wgsl
index ea6747c..2af35d5 100644
--- a/src/shaders/matmul_mat_nf4.wgsl
+++ b/src/shaders/matmul_mat_nf4.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Input {
diff --git a/src/shaders/matmul_vec_fp16.wgsl b/src/shaders/matmul_vec_fp16.wgsl
index add15ef..1cc3ead 100644
--- a/src/shaders/matmul_vec_fp16.wgsl
+++ b/src/shaders/matmul_vec_fp16.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> shape: vec4<u32>;                        // [C, R, B]
diff --git a/src/shaders/matmul_vec_int8.wgsl b/src/shaders/matmul_vec_int8.wgsl
index fe6c4e1..5082e7d 100644
--- a/src/shaders/matmul_vec_int8.wgsl
+++ b/src/shaders/matmul_vec_int8.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> shape: vec4<u32>;                        // [C, R, B]
diff --git a/src/shaders/matmul_vec_nf4.wgsl b/src/shaders/matmul_vec_nf4.wgsl
index d483e40..d1455bc 100644
--- a/src/shaders/matmul_vec_nf4.wgsl
+++ b/src/shaders/matmul_vec_nf4.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 @group(0) @binding(0) var<uniform> shape: vec4<u32>;                        // [C, R, B]
diff --git a/src/shaders/time_mix_v4.wgsl b/src/shaders/time_mix_v4.wgsl
index 1d54b77..cf9b1f3 100644
--- a/src/shaders/time_mix_v4.wgsl
+++ b/src/shaders/time_mix_v4.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Cursor {
diff --git a/src/shaders/time_mix_v5.wgsl b/src/shaders/time_mix_v5.wgsl
index 04ba75e..33353aa 100644
--- a/src/shaders/time_mix_v5.wgsl
+++ b/src/shaders/time_mix_v5.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Cursor {
diff --git a/src/shaders/time_mix_v6.wgsl b/src/shaders/time_mix_v6.wgsl
index 7f224f7..6765eb0 100644
--- a/src/shaders/time_mix_v6.wgsl
+++ b/src/shaders/time_mix_v6.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Cursor {
diff --git a/src/shaders/token_shift.wgsl b/src/shaders/token_shift.wgsl
index 848e6a2..c91b39e 100644
--- a/src/shaders/token_shift.wgsl
+++ b/src/shaders/token_shift.wgsl
@@ -1,7 +1,7 @@
 struct View {
+    shape: vec4<u32>,
     stride: vec4<u32>,
     offset: vec4<u32>,
-    shape: vec4<u32>,
 };
 
 struct Cursor {
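All fifteen shaders move `shape` to the front of the WGSL `View` struct so the GPU-side layout keeps matching the byte order emitted by the reordered Rust `View::into_bytes` (see src/tensor/mod.rs below). A crate-internal test sketch of that invariant; the `IntoBytes` import path and little-endian `u32` packing are assumptions, not spelled out in this patch:

#[cfg(test)]
mod layout_tests {
    use super::{shape::Shape, IntoBytes, View};

    #[test]
    fn view_bytes_match_shader_layout() {
        let view = View {
            shape: Shape::new(1, 2, 3, 4),
            stride: Shape::new(5, 6, 7, 8),
            offset: Shape::new(0, 0, 0, 0),
        };
        let bytes = view.into_bytes();
        // Three vec4<u32> fields: shape, stride, offset -> 48 bytes, shape first.
        assert_eq!(bytes.len(), 48);
        assert_eq!(bytes[0..4], 1u32.to_le_bytes());
        assert_eq!(bytes[16..20], 5u32.to_le_bytes());
    }
}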
diff --git a/src/tensor/cache.rs b/src/tensor/cache.rs
index 7087f8c..797bcab 100644
--- a/src/tensor/cache.rs
+++ b/src/tensor/cache.rs
@@ -2,35 +2,38 @@ use std::{
     collections::HashMap,
     hash::Hash,
     sync::{Arc, Mutex},
+    time::{Duration, Instant},
 };
 
 #[derive(Debug)]
 struct CacheItem<V> {
     value: Arc<V>,
-    count: usize,
+    instant: Instant,
 }
 
 impl<V> CacheItem<V> {
-    fn make(f: impl FnOnce() -> V) -> Self {
+    fn new(value: Arc<V>) -> Self {
         Self {
-            value: f().into(),
-            count: 0,
+            value,
+            instant: Instant::now(),
         }
     }
+
+    fn count(&self) -> usize {
+        Arc::strong_count(&self.value)
+    }
 }
 
-/// An LRU cache.
-#[allow(clippy::type_complexity)]
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct ResourceCache<K, V> {
-    max_count: usize,
-    map: Arc<Mutex<HashMap<K, CacheItem<V>>>>,
+    duration: Duration,
+    map: Mutex<HashMap<K, Vec<CacheItem<V>>>>,
 }
 
 impl<K, V> Default for ResourceCache<K, V> {
     fn default() -> Self {
         Self {
-            max_count: 16,
+            duration: Duration::from_secs(1),
            map: Default::default(),
         }
     }
@@ -40,37 +43,37 @@ impl<K, V> ResourceCache<K, V>
 where
     K: PartialEq + Eq + Hash,
 {
-    /// Note: If `max_count` is 0, the cache won't evict any items.
-    pub fn new(max_count: usize) -> Self {
+    /// Note: If `duration` is 0, the cache won't evict any items.
+    pub fn new(duration: Duration) -> Self {
         Self {
-            max_count,
+            duration,
             map: Default::default(),
         }
     }
 
-    /// Checkout the item with the given key. If the item doesn't exist, `f` is called to construct it.
-    pub fn checkout(&self, key: K, f: impl FnOnce() -> V) -> Arc<V> {
+    /// Checkout the item with the given key. If no idle item exists, `miss` constructs one; on reuse, `hit` is called on the item first.
+    pub fn checkout(&self, key: K, hit: impl FnOnce(&V), miss: impl FnOnce() -> V) -> Arc<V> {
         let mut map = self.map.lock().unwrap();
 
-        if self.max_count > 0 {
-            map.retain(|_, item| {
-                item.count += 1;
-                item.count <= self.max_count
-            });
+        if !self.duration.is_zero() {
+            for (_, items) in map.iter_mut() {
+                items.retain(|item| item.count() > 1 || item.instant.elapsed() < self.duration);
+            }
         }
 
-        let CacheItem { value, .. } = map.remove(&key).unwrap_or_else(|| CacheItem::make(f));
-        map.insert(
-            key,
-            CacheItem {
-                value: value.clone(),
-                count: 0,
-            },
-        );
-        value
-    }
-
-    /// Empty the cache.
-    pub fn clear(&self) {
-        let mut map = self.map.lock().unwrap();
-        map.clear();
+        match map
+            .get_mut(&key)
+            .and_then(|items| items.iter_mut().find(|item| item.count() == 1))
+        {
+            Some(item) => {
+                hit(&item.value);
+                item.instant = Instant::now();
+                item.value.clone()
+            }
+            None => {
+                let value = Arc::new(miss());
+                map.insert(key, vec![CacheItem::new(value.clone())]);
+                value
+            }
+        }
     }
 }
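The new `ResourceCache` replaces LRU generation counting with two rules: an item can be handed out again only while the cache holds the sole `Arc` (`strong_count == 1`), and idle items are dropped once they have gone untouched for `duration`. A self-contained sketch of the same semantics, using standalone types rather than the crate's:

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};

struct Pool<V> {
    duration: Duration,
    map: Mutex<HashMap<usize, Vec<(Arc<V>, Instant)>>>,
}

impl<V> Pool<V> {
    fn new(duration: Duration) -> Self {
        Self { duration, map: Mutex::new(HashMap::new()) }
    }

    fn checkout(&self, key: usize, miss: impl FnOnce() -> V) -> Arc<V> {
        let mut map = self.map.lock().unwrap();
        // Keep items still in use elsewhere, or idle but fresh.
        for items in map.values_mut() {
            items.retain(|(v, t)| Arc::strong_count(v) > 1 || t.elapsed() < self.duration);
        }
        // Reuse an idle item: only the pool itself holds its Arc.
        if let Some((v, t)) = map
            .get_mut(&key)
            .and_then(|items| items.iter_mut().find(|(v, _)| Arc::strong_count(v) == 1))
        {
            *t = Instant::now();
            return v.clone();
        }
        let v = Arc::new(miss());
        map.entry(key).or_default().push((v.clone(), Instant::now()));
        v
    }
}

fn main() {
    let pool = Pool::new(Duration::from_secs(1));
    let a = pool.checkout(64, || vec![0u8; 64]);
    let b = pool.checkout(64, || vec![0u8; 64]); // `a` is busy: fresh allocation
    assert!(!Arc::ptr_eq(&a, &b));
    drop((a, b));
    let _c = pool.checkout(64, || vec![0u8; 64]); // reuses an idle allocation
}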
diff --git a/src/tensor/matrix.rs b/src/tensor/matrix.rs
index b7a8bcc..390bdc3 100644
--- a/src/tensor/matrix.rs
+++ b/src/tensor/matrix.rs
@@ -7,7 +7,7 @@ use crate::{
         kind::{ReadWrite, Uniform},
         ops::{TensorOp, TensorPass},
         shape::Shape,
-        TensorError, TensorGpu, TensorShape, TensorView,
+        TensorError, TensorGpu, TensorGpuView, TensorShape,
     },
 };
 
@@ -28,8 +28,8 @@ impl Matrix {
     pub fn matmul_vec_op(
         &self,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<TensorOp, TensorError> {
         match self {
@@ -41,8 +41,8 @@ impl Matrix {
 
     pub fn matmul_mat_op(
         &self,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<TensorOp, TensorError> {
         match self {
@@ -60,8 +60,8 @@ impl Matrix {
 
     pub fn matmul_op(
         &self,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
         turbo: bool,
     ) -> Result<TensorOp, TensorError> {
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index 7519132..5ffc683 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -3,10 +3,7 @@ use std::{borrow::Cow, marker::PhantomData, sync::Arc};
 use itertools::Itertools;
 use thiserror::Error;
 use web_rwkv_derive::JsError;
-use wgpu::{
-    util::{BufferInitDescriptor, DeviceExt},
-    BindingResource, Buffer, BufferBinding, BufferDescriptor, MapMode,
-};
+use wgpu::{BindingResource, Buffer, BufferBinding, MapMode};
 
 use self::{
     kind::{Kind, ReadBack, ReadWrite, Uniform},
@@ -124,17 +121,17 @@ pub enum TensorError {
 /// Data defining a tensor view in shader.
 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
 pub struct View {
+    pub shape: Shape,
     pub stride: Shape,
     pub offset: Shape,
-    pub shape: Shape,
 }
 
 impl IntoBytes for View {
     fn into_bytes(self) -> Vec<u8> {
         [
+            self.shape.into_bytes(),
             self.stride.into_bytes(),
             self.offset.into_bytes(),
-            self.shape.into_bytes(),
         ]
         .concat()
     }
@@ -395,16 +392,8 @@ impl<'a, T: Scalar, K: Kind> TensorInit<'a, T> for TensorGpu<T, K> {
 
     /// Initialize a GPU tensor with a given shape.
     fn init(context: &Context, shape: Shape) -> Self {
-        let size = shape.len() as u64 * T::size() as u64;
-        let buffer = context
-            .device
-            .create_buffer(&BufferDescriptor {
-                label: None,
-                size,
-                usage: K::buffer_usages(),
-                mapped_at_creation: false,
-            })
-            .into();
+        let size = shape.len() * T::size();
+        let buffer = context.checkout_buffer(size, K::buffer_usages());
 
         Self {
             context: context.clone(),
@@ -457,14 +446,7 @@ impl<T: Scalar, K: Kind> From<TensorCpu<'_, T>> for TensorGpu<T, K> {
         } = value;
         let meta = context.checkout_shape_uniform(shape);
         let contents = bytemuck::cast_slice(&data);
-        let buffer = context
-            .device
-            .create_buffer_init(&BufferInitDescriptor {
-                label: None,
-                contents,
-                usage: K::buffer_usages(),
-            })
-            .into();
+        let buffer = context.checkout_buffer_init(contents, K::buffer_usages());
 
         Self {
             context,
@@ -730,20 +712,20 @@ impl<'a, T: Scalar> TensorCpu<'a, T> {
 
 /// Like a reference to a tensor, but refers to a sub-chunk of it.
 #[derive(Debug, Clone)]
-pub struct TensorView<'a, T: Scalar> {
+pub struct TensorGpuView<'a, T: Scalar> {
     tensor: &'a TensorGpu<T, ReadWrite>,
     meta: Arc<Buffer>,
     view: View,
 }
 
-impl<T: Scalar> TensorShape for TensorView<'_, T> {
+impl<T: Scalar> TensorShape for TensorGpuView<'_, T> {
     #[inline]
     fn shape(&self) -> Shape {
         self.view.shape
     }
 }
 
-impl<T: Scalar> TensorView<'_, T> {
+impl<T: Scalar> TensorGpuView<'_, T> {
     #[inline]
     pub fn tensor(&self) -> &TensorGpu<T, ReadWrite> {
         self.tensor
@@ -774,11 +756,11 @@ impl<T: Scalar> TensorView<'_, T> {
     }
 }
 
-impl<T: Scalar> TensorScalar for TensorView<'_, T> {
+impl<T: Scalar> TensorScalar for TensorGpuView<'_, T> {
     type T = T;
 }
 
-impl<F: Float> TensorView<'_, F> {
+impl<F: Float> TensorGpuView<'_, F> {
     #[inline]
     pub const fn def(&self) -> &'static str {
         F::DEF
@@ -793,7 +775,7 @@ impl<T: Scalar> TensorGpu<T, ReadWrite> {
         y: impl TensorAxis,
         z: impl TensorAxis,
         w: impl TensorAxis,
-    ) -> Result<TensorView<T>, TensorError> {
+    ) -> Result<TensorGpuView<T>, TensorError> {
         let slice = (x, y, z, w);
         let (start, end) = slice.shape_bounds(self.shape)?;
         let view = View {
@@ -802,7 +784,7 @@ impl<T: Scalar> TensorGpu<T, ReadWrite> {
             shape: end - start,
         };
         let meta = self.context.checkout_view_uniform(view);
-        Ok(TensorView {
+        Ok(TensorGpuView {
             tensor: self,
             meta,
             view,
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index 587f54e..6ddf2e8 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -5,7 +5,7 @@ use wgpu::{BindGroup, BindGroupDescriptor, BindGroupEntry, CommandEncoder, Compu
 use super::{
     kind::{Kind, ReadWrite, Uniform},
-    Shape, TensorError, TensorGpu, TensorScalar, TensorShape, TensorView,
+    Shape, TensorError, TensorGpu, TensorGpuView, TensorScalar, TensorShape,
 };
 use crate::{
     context::{CachedPipeline, Macros},
@@ -395,8 +395,8 @@ impl TensorOp {
     /// - `output` shape: `[R, T, B]`.
     pub fn matmul_vec_fp16(
         matrix: &TensorGpu<f16, ReadWrite>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -462,8 +462,8 @@ impl TensorOp {
     pub fn matmul_vec_int8(
         matrix: &TensorGpu<u8, ReadWrite>,
         minmax: &TensorGpu<f16, ReadWrite>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -540,8 +540,8 @@ impl TensorOp {
         matrix: &TensorGpu<u8, ReadWrite>,
         quant: &TensorGpu<f32, Uniform>,
         absmax: &TensorGpu<f16, ReadWrite>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -621,9 +621,9 @@ impl TensorOp {
     ///
     /// Note: `K` must be a multiple of 128; `M` and `N` must be multiples of 4.
     pub fn matmul_mat_fp16(
-        matrix: TensorView<f16>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        matrix: TensorGpuView<f16>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 8;
@@ -693,10 +693,10 @@ impl TensorOp {
     /// Note: `K` must be a multiple of 128; `M` and `N` must be multiples of 4.
     #[allow(clippy::too_many_arguments)]
     pub fn matmul_mat_int8(
-        matrix: TensorView<u8>,
+        matrix: TensorGpuView<u8>,
         minmax: &TensorGpu<f16, ReadWrite>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 8;
@@ -776,11 +776,11 @@ impl TensorOp {
     ///
     /// Note: `K` must be a multiple of 256; `M` and `N` must be multiples of 8.
     pub fn matmul_mat_nf4(
-        matrix: TensorView<u8>,
+        matrix: TensorGpuView<u8>,
         quant: &TensorGpu<f32, Uniform>,
         absmax: &TensorGpu<f16, ReadWrite>,
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
         active: Activation,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 8;
@@ -852,8 +852,8 @@ impl TensorOp {
     }
 
     pub fn add(
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -907,8 +907,8 @@ impl TensorOp {
     }
 
     pub fn mul(
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -963,8 +963,8 @@ impl TensorOp {
     pub fn token_shift(
         cursors: &TensorGpu<u32, ReadWrite>,
-        time_mix: TensorView<impl Float>,
-        sx: TensorView<f32>,
+        time_mix: TensorGpuView<impl Float>,
+        sx: TensorGpuView<f32>,
         input: &TensorGpu<impl Float, ReadWrite>,
         output: &TensorGpu<impl Float, ReadWrite>,
         reversed: bool,
@@ -1041,7 +1041,7 @@ impl TensorOp {
         cursors: &TensorGpu<u32, ReadWrite>,
         time_decay: &TensorGpu<f32, ReadWrite>,
         time_first: &TensorGpu<f32, ReadWrite>,
-        state: TensorView<f32>,
+        state: TensorGpuView<f32>,
         k: &TensorGpu<impl Float, ReadWrite>,
         v: &TensorGpu<impl Float, ReadWrite>,
         r: &TensorGpu<impl Float, ReadWrite>,
@@ -1124,7 +1124,7 @@ impl TensorOp {
         cursors: &TensorGpu<u32, ReadWrite>,
         time_decay: &TensorGpu<f32, ReadWrite>,
         time_first: &TensorGpu<f32, ReadWrite>,
-        state: TensorView<f32>,
+        state: TensorGpuView<f32>,
         k: &TensorGpu<impl Float, ReadWrite>,
         v: &TensorGpu<impl Float, ReadWrite>,
         r: &TensorGpu<impl Float, ReadWrite>,
@@ -1209,7 +1209,7 @@ impl TensorOp {
         cursors: &TensorGpu<u32, ReadWrite>,
         time_decay: &TensorGpu<f32, ReadWrite>,
         time_first: &TensorGpu<f32, ReadWrite>,
-        state: TensorView<f32>,
+        state: TensorGpuView<f32>,
         k: &TensorGpu<impl Float, ReadWrite>,
         v: &TensorGpu<impl Float, ReadWrite>,
         r: &TensorGpu<impl Float, ReadWrite>,
@@ -1454,7 +1454,7 @@ impl TensorOp {
 
     pub fn channel_mix(
         cursors: &TensorGpu<u32, ReadWrite>,
-        state: TensorView<f32>,
+        state: TensorGpuView<f32>,
         r: &TensorGpu<impl Float, ReadWrite>,
         v: &TensorGpu<impl Float, ReadWrite>,
         x: &TensorGpu<impl Float, ReadWrite>,
@@ -1522,8 +1522,8 @@ impl TensorOp {
 
     /// Copy the content of `input` into `output` of the same shape.
     pub fn blit(
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -1576,8 +1576,8 @@ impl TensorOp {
 
     /// Repeat the content of `input` into `output` along the token and batch axes.
     pub fn broadcast(
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -1630,8 +1630,8 @@ impl TensorOp {
 
     /// Swap the `token` and `batch` axes.
     pub fn transpose(
-        input: TensorView<impl Float>,
-        output: TensorView<impl Float>,
+        input: TensorGpuView<impl Float>,
+        output: TensorGpuView<impl Float>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 128;
@@ -1743,9 +1743,9 @@ impl TensorOp {
 
     pub fn blend_lora(
         factor: &TensorGpu<f32, Uniform>,
-        xa: TensorView<f16>,
-        xb: TensorView<f16>,
-        output: TensorView<f16>,
+        xa: TensorGpuView<f16>,
+        xb: TensorGpuView<f16>,
+        output: TensorGpuView<f16>,
     ) -> Result<Self, TensorError> {
         const BLOCK_SIZE: u32 = 8;
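Finally, the public rename `TensorView` -> `TensorGpuView` (mirroring the `TensorCpu`/`TensorGpu` pair) ripples through every downstream signature. A migration sketch for dependent code; the function and its bound are illustrative, only the import rename comes from this patch:

// Before: use web_rwkv::tensor::TensorView;
use web_rwkv::{
    num::Float,
    tensor::{TensorGpuView, TensorShape},
};

fn num_batches<F: Float>(view: &TensorGpuView<F>) -> usize {
    // Shape layout is [C, T, B, 1]; index 2 is the batch axis.
    view.shape()[2]
}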