From 62885249f9b731f1a163aed03e4cca21dfe647f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=A0=94=E7=A9=B6=E7=A4=BE=E4=BA=A4?=
Date: Mon, 25 Mar 2024 20:26:56 +0800
Subject: [PATCH] Fix tensor init cache (#22)

* Implement init for `TensorGpu`.

* Re-implement resource cache.

* Simplify `TensorInitContext` trait.

* Bump version to v0.6.35
---
 Cargo.toml          |   2 +-
 src/context.rs      | 151 ++++++++++++++++++--------------------------
 src/tensor/cache.rs |  54 ++--------------
 src/tensor/mod.rs   |  73 +++++++++++++++++----
 4 files changed, 129 insertions(+), 151 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 110db37..bcfa326 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
 license = "MIT OR Apache-2.0"
 name = "web-rwkv"
 repository = "https://github.com/cryscan/web-rwkv"
-version = "0.6.34"
+version = "0.6.35"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
diff --git a/src/context.rs b/src/context.rs
index f20afed..19b936e 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -1,8 +1,4 @@
-use std::{
-    borrow::Cow,
-    collections::HashMap,
-    sync::{Arc, Mutex},
-};
+use std::{borrow::Cow, sync::Arc};
 
 #[cfg(not(target_arch = "wasm32"))]
 use flume::Sender;
@@ -85,8 +81,8 @@ pub struct ContextInternal {
     pub device: Device,
     pub queue: Queue,
 
-    pipeline_cache: Mutex<HashMap<PipelineKey, Arc<CachedPipeline>>>,
-    buffer_cache: ResourceCache<(usize, BufferUsages), Buffer>,
+    pipeline_cache: ResourceCache<PipelineKey, CachedPipeline>,
+    view_cache: ResourceCache<View, Buffer>,
 
     #[cfg(not(target_arch = "wasm32"))]
     event: Sender<ContextEvent>,
@@ -156,7 +152,7 @@
             device,
             queue,
             pipeline_cache: Default::default(),
-            buffer_cache: Default::default(),
+            view_cache: Default::default(),
             #[cfg(not(target_arch = "wasm32"))]
             event: sender,
         });
@@ -279,46 +275,39 @@ impl ContextInternal {
         let mut context = Context::new();
         context.macros = macros.0.into_iter().collect();
 
-        let mut cache = self.pipeline_cache.lock().unwrap();
-        match cache.get(&key) {
-            Some(pipeline) => pipeline.clone(),
-            None => {
-                let shader = process_str(source.as_ref(), &mut context).unwrap();
-                let module = &self.device.create_shader_module(ShaderModuleDescriptor {
-                    label: Some(name),
-                    source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
-                });
-
-                let layout = layout.map(|entries| {
-                    let layout = self
-                        .device
-                        .create_bind_group_layout(&BindGroupLayoutDescriptor {
-                            label: None,
-                            entries,
-                        });
-                    self.device
-                        .create_pipeline_layout(&PipelineLayoutDescriptor {
-                            label: None,
-                            bind_group_layouts: &[&layout],
-                            push_constant_ranges: &[],
-                        })
-                });
+        self.pipeline_cache.checkout(key, || {
+            let shader = process_str(source.as_ref(), &mut context).unwrap();
+            let module = &self.device.create_shader_module(ShaderModuleDescriptor {
+                label: Some(name),
+                source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
+            });
 
-                let pipeline = self
+            let layout = layout.map(|entries| {
+                let layout = self
                     .device
-                    .create_compute_pipeline(&ComputePipelineDescriptor {
-                        label: Some(name),
-                        layout: layout.as_ref(),
-                        module,
-                        entry_point,
-                    });
-                let layout = pipeline.get_bind_group_layout(0);
-                let pipeline = Arc::new(CachedPipeline { pipeline, layout });
+                    .create_bind_group_layout(&BindGroupLayoutDescriptor {
+                        label: None,
+                        entries,
+                    });
+                self.device
+                    .create_pipeline_layout(&PipelineLayoutDescriptor {
+                        label: None,
+                        bind_group_layouts: &[&layout],
+                        push_constant_ranges: &[],
+                    })
+            });
 
-                cache.insert(key, pipeline.clone());
-                pipeline
-            }
-        }
+            let pipeline = self
+                .device
+                .create_compute_pipeline(&ComputePipelineDescriptor {
+                    label: Some(name),
+                    layout: layout.as_ref(),
+                    module,
+                    entry_point,
+                });
+            let layout = pipeline.get_bind_group_layout(0);
+            CachedPipeline { pipeline, layout }
+        })
     }
 
     pub(crate) fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
@@ -327,60 +316,44 @@ impl ContextInternal {
             stride: shape,
             offset: Shape::new(0, 0, 0, 0),
         };
-        self.buffer_cache.checkout(
-            (view.into_bytes().len(), BufferUsages::UNIFORM),
-            |buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
-            || {
-                self.device.create_buffer_init(&BufferInitDescriptor {
-                    label: None,
-                    contents: &view.into_bytes(),
-                    usage: BufferUsages::UNIFORM,
-                })
-            },
-        )
+        self.view_cache.checkout(view, || {
+            self.device.create_buffer_init(&BufferInitDescriptor {
+                label: None,
+                contents: &view.into_bytes(),
+                usage: BufferUsages::UNIFORM,
+            })
+        })
     }
 
     pub(crate) fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
-        self.buffer_cache.checkout(
-            (view.into_bytes().len(), BufferUsages::UNIFORM),
-            |buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
-            || {
-                self.device.create_buffer_init(&BufferInitDescriptor {
-                    label: None,
-                    contents: &view.into_bytes(),
-                    usage: BufferUsages::UNIFORM,
-                })
-            },
-        )
+        self.view_cache.checkout(view, || {
+            self.device.create_buffer_init(&BufferInitDescriptor {
+                label: None,
+                contents: &view.into_bytes(),
+                usage: BufferUsages::UNIFORM,
+            })
+        })
     }
 
     pub(crate) fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
-        self.buffer_cache.checkout(
-            (contents.len(), usage),
-            |buffer| self.queue.write_buffer(buffer, 0, contents),
-            || {
-                self.device.create_buffer_init(&BufferInitDescriptor {
-                    label: None,
-                    contents,
-                    usage,
-                })
-            },
-        )
+        self.device
+            .create_buffer_init(&BufferInitDescriptor {
+                label: None,
+                contents,
+                usage,
+            })
+            .into()
     }
 
     pub(crate) fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
-        self.buffer_cache.checkout(
-            (size, usage),
-            |_| (),
-            || {
-                self.device.create_buffer(&BufferDescriptor {
-                    label: None,
-                    size: size as u64,
-                    usage,
-                    mapped_at_creation: false,
-                })
-            },
-        )
+        self.device
+            .create_buffer(&BufferDescriptor {
+                label: None,
+                size: size as u64,
+                usage,
+                mapped_at_creation: false,
+            })
+            .into()
     }
 
     #[cfg(not(target_arch = "wasm32"))]
diff --git a/src/tensor/cache.rs b/src/tensor/cache.rs
index a49120c..d20bf14 100644
--- a/src/tensor/cache.rs
+++ b/src/tensor/cache.rs
@@ -4,37 +4,14 @@ use std::{
     sync::{Arc, Mutex},
 };
 
-use instant::{Duration, Instant};
-
-#[derive(Debug)]
-struct CacheItem<V> {
-    value: Arc<V>,
-    instant: Instant,
-}
-
-impl<V> CacheItem<V> {
-    fn new(value: Arc<V>) -> Self {
-        Self {
-            value,
-            instant: Instant::now(),
-        }
-    }
-
-    fn count(&self) -> usize {
-        Arc::strong_count(&self.value)
-    }
-}
-
 #[derive(Debug)]
 pub struct ResourceCache<K, V> {
-    duration: Duration,
-    map: Mutex<HashMap<K, Vec<CacheItem<V>>>>,
+    map: Mutex<HashMap<K, Arc<V>>>,
 }
 
 impl<K, V> Default for ResourceCache<K, V> {
     fn default() -> Self {
         Self {
-            duration: Duration::from_secs(1),
             map: Default::default(),
         }
     }
@@ -44,35 +21,14 @@ impl<K, V> ResourceCache<K, V>
 where
     K: PartialEq + Eq + Hash,
 {
-    /// Note: If `duration` is 0, the cache won't evict any items.
-    pub fn new(duration: Duration) -> Self {
-        Self {
-            duration,
-            map: Default::default(),
-        }
-    }
-
     /// Checkout the item with the given key. If the item doesn't exist, `miss` is called to construct it.
-    pub fn checkout(&self, key: K, hit: impl FnOnce(&V), miss: impl FnOnce() -> V) -> Arc<V> {
+    pub fn checkout(&self, key: K, miss: impl FnOnce() -> V) -> Arc<V> {
         let mut map = self.map.lock().unwrap();
-        if !self.duration.is_zero() {
-            for (_, items) in map.iter_mut() {
-                items.retain(|item| item.count() > 1 && item.instant.elapsed() < self.duration);
-            }
-        }
-
-        match map
-            .get_mut(&key)
-            .and_then(|items| items.iter_mut().find(|item| item.count() == 1))
-        {
-            Some(item) => {
-                hit(&item.value);
-                item.instant = Instant::now();
-                item.value.clone()
-            }
+        match map.get(&key) {
+            Some(value) => value.clone(),
             None => {
                 let value = Arc::new(miss());
-                map.insert(key, vec![CacheItem::new(value.clone())]);
+                map.insert(key, value.clone());
                 value
             }
         }
diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs
index f463a75..5e515ab 100644
--- a/src/tensor/mod.rs
+++ b/src/tensor/mod.rs
@@ -184,6 +184,17 @@ pub trait TensorScalar {
     type T: Scalar;
 }
 
+pub trait TensorInitContext<'a, T: Scalar>: Sized {
+    /// Init the tensor with given shape and contents.
+    fn from_data(
+        context: &Context,
+        shape: Shape,
+        data: impl Into<Cow<'a, [T]>>,
+    ) -> Result<Self, TensorError>;
+    /// Init the tensor with given shape.
+    fn init(context: &Context, shape: Shape) -> Self;
+}
+
 pub trait TensorInit<'a, T: Scalar>: Sized {
     /// Init the tensor with given shape and contents.
     fn from_data(shape: Shape, data: impl Into<Cow<'a, [T]>>) -> Result<Self, TensorError>;
@@ -195,7 +206,6 @@ pub trait TensorInit<'a, T: Scalar>: Sized {
         if T::DATA_TYPE != dt {
             return Err(TensorError::Type);
         }
-
         let shape = Shape::from_slice_rev(&shape)?;
         match data {
             Cow::Borrowed(data) => Self::from_data(shape, bytemuck::cast_slice(data)),
@@ -339,6 +349,20 @@ impl<'a, T: Scalar> TensorInit<'a, T> for TensorCpu<'a, T> {
     }
 }
 
+impl<'a, T: Scalar> TensorInitContext<'a, T> for TensorCpu<'a, T> {
+    fn from_data(
+        _context: &Context,
+        shape: Shape,
+        data: impl Into<Cow<'a, [T]>>,
+    ) -> Result<Self, TensorError> {
+        TensorInit::from_data(shape, data)
+    }
+
+    fn init(_context: &Context, shape: Shape) -> Self {
+        TensorInit::init(shape)
+    }
+}
+
 impl<'a, T: Scalar> TensorFrom<TensorCpu<'a, T>> for TensorCpu<'a, T> {
     fn transfer_from(_context: &Context, value: TensorCpu<'a, T>) -> Self {
         value
@@ -375,6 +399,35 @@ impl<T: Scalar> TensorReshape for TensorCpu<'_, T> {
     }
 }
 
+impl<'a, T: Scalar, K: Kind> TensorInitContext<'a, T> for TensorGpu<T, K> {
+    fn from_data(
+        context: &Context,
+        shape: Shape,
+        data: impl Into<Cow<'a, [T]>>,
+    ) -> Result<Self, TensorError> {
+        let tensor: TensorCpu<T> = TensorInit::from_data(shape, data)?;
+        Ok(tensor.transfer_into(context))
+    }
+
+    fn init(context: &Context, shape: Shape) -> Self {
+        let context = context.clone();
+        let meta = context.checkout_shape_uniform(shape);
+
+        let size = shape.len() * std::mem::size_of::<T>();
+        let buffer = context.checkout_buffer(size, K::buffer_usages());
+
+        Self {
+            shape,
+            data: TensorGpuData {
+                context,
+                meta,
+                buffer,
+            },
+            phantom: PhantomData,
+        }
+    }
+}
+
 impl<'a, T: Scalar, K: Kind> TensorFrom<TensorCpu<'a, T>> for TensorGpu<T, K> {
     fn transfer_from(context: &Context, value: TensorCpu<'a, T>) -> Self {
         let Tensor { shape, data, .. } = value;
@@ -608,7 +661,7 @@ impl<'a, T: Scalar> TensorCpu<'a, T> {
     pub fn map<U: Scalar>(self, f: impl FnMut(&T) -> U) -> TensorCpu<'a, U> {
         let Self { shape, data, .. } = self;
         let data = data.iter().map(f).collect_vec();
-        TensorCpu::from_data(shape, data).expect("this never happens")
+        TensorInit::from_data(shape, data).expect("this never happens")
     }
 
     /// Repeat the tensor along a given axis.
@@ -916,32 +969,28 @@ impl<T: Scalar> TryFrom<Vec<TensorCpu<'_, T>>> for TensorStack<'_, T> {
 
 impl<'a> Context {
     #[inline]
     pub fn zeros<T: Scalar, Tensor: TensorFrom<TensorCpu<'a, T>>>(&self, shape: Shape) -> Tensor {
-        Tensor::transfer_from(self, TensorCpu::init(shape))
+        Tensor::transfer_from(self, TensorInit::init(shape))
     }
 
     #[inline]
     pub fn ones<T: Scalar, Tensor: TensorFrom<TensorCpu<'a, T>>>(&self, shape: Shape) -> Tensor {
         let data = vec![T::one(); shape.len()];
-        let tensor = TensorCpu::from_data(shape, data).unwrap();
+        let tensor = TensorInit::from_data(shape, data).unwrap();
         Tensor::transfer_from(self, tensor)
     }
 
     #[inline]
-    pub fn tensor_from_data<T: Scalar, Tensor: TensorFrom<TensorCpu<'a, T>>>(
+    pub fn tensor_from_data<T: Scalar, Tensor: TensorInitContext<'a, T>>(
         &self,
         shape: Shape,
         data: impl Into<Cow<'a, [T]>>,
     ) -> Result<Tensor, TensorError> {
-        let tensor = TensorCpu::from_data(shape, data)?;
-        Ok(Tensor::transfer_from(self, tensor))
+        TensorInitContext::from_data(self, shape, data)
     }
 
     #[inline]
-    pub fn tensor_init<T: Scalar, Tensor: TensorFrom<TensorCpu<'a, T>>>(
-        &self,
-        shape: Shape,
-    ) -> Tensor {
-        Tensor::transfer_from(self, TensorCpu::init(shape))
+    pub fn tensor_init<T: Scalar, Tensor: TensorInitContext<'a, T>>(&self, shape: Shape) -> Tensor {
+        TensorInitContext::init(self, shape)
     }
 }
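
Note on the re-implemented `ResourceCache`: `checkout` now keeps a single
`Arc<V>` per key and never evicts, so a repeated checkout of the same key
returns a clone of the same `Arc` and the `miss` closure runs only once. A
minimal sketch of that contract, written as a test one could append to
src/tensor/cache.rs (illustrative only, not part of this patch):

    #[cfg(test)]
    mod tests {
        use std::sync::Arc;

        use super::ResourceCache;

        #[test]
        fn checkout_reuses_cached_value() {
            let cache: ResourceCache<u32, String> = Default::default();

            // First checkout misses: the closure runs and the value is stored.
            let a = cache.checkout(42, || "hello".to_string());

            // Second checkout with the same key hits: the closure must not run.
            let b = cache.checkout(42, || unreachable!("expected a cache hit"));

            // Both handles point to the same cached allocation.
            assert!(Arc::ptr_eq(&a, &b));
        }
    }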
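
Note on the new init path: with `TensorInitContext` implemented for both
`TensorCpu` and `TensorGpu`, `Context::tensor_init` can build a GPU tensor
directly through `checkout_buffer` instead of staging a zeroed `TensorCpu`
and transferring it, while `tensor_from_data` still uploads through a CPU
tensor. A caller-side sketch, assuming the import paths and the `ReadWrite`
kind below (neither is shown in this diff):

    use web_rwkv::{
        context::Context,
        tensor::{kind::ReadWrite, Shape, TensorError, TensorGpu},
    };

    fn example(context: &Context) -> Result<(), TensorError> {
        let shape = Shape::new(256, 64, 1, 1);

        // Zero-initialized GPU tensor, allocated straight on the device.
        let zeros: TensorGpu<f32, ReadWrite> = context.tensor_init(shape);

        // Initialized from host data, staged through a `TensorCpu`.
        let data = vec![1.0f32; shape.len()];
        let filled: TensorGpu<f32, ReadWrite> = context.tensor_from_data(shape, data)?;

        let _ = (zeros, filled);
        Ok(())
    }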