Skip to content

Commit

Permalink
Fix tensor init cache (#22)
Browse files Browse the repository at this point in the history
* Implement init for `TensorGpu`.

* Re-implement resource cache.

* Simplify `TensorInitContext` trait.

* Bump version to v0.6.35
  • Loading branch information
cryscan authored Mar 25, 2024
1 parent f22f405 commit 6288524
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 151 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.6.34"
version = "0.6.35"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
151 changes: 62 additions & 89 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use std::{
borrow::Cow,
collections::HashMap,
sync::{Arc, Mutex},
};
use std::{borrow::Cow, sync::Arc};

#[cfg(not(target_arch = "wasm32"))]
use flume::Sender;
Expand Down Expand Up @@ -85,8 +81,8 @@ pub struct ContextInternal {
pub device: Device,
pub queue: Queue,

pipeline_cache: Mutex<HashMap<PipelineKey, Arc<CachedPipeline>>>,
buffer_cache: ResourceCache<(usize, BufferUsages), Buffer>,
pipeline_cache: ResourceCache<PipelineKey, CachedPipeline>,
view_cache: ResourceCache<View, Buffer>,

#[cfg(not(target_arch = "wasm32"))]
event: Sender<ContextEvent>,
Expand Down Expand Up @@ -156,7 +152,7 @@ impl<'a> ContextBuilder {
device,
queue,
pipeline_cache: Default::default(),
buffer_cache: Default::default(),
view_cache: Default::default(),
#[cfg(not(target_arch = "wasm32"))]
event: sender,
});
Expand Down Expand Up @@ -279,46 +275,39 @@ impl ContextInternal {
let mut context = Context::new();
context.macros = macros.0.into_iter().collect();

let mut cache = self.pipeline_cache.lock().unwrap();
match cache.get(&key) {
Some(pipeline) => pipeline.clone(),
None => {
let shader = process_str(source.as_ref(), &mut context).unwrap();
let module = &self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
});

let layout = layout.map(|entries| {
let layout = self
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries,
});
self.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&layout],
push_constant_ranges: &[],
})
});
self.pipeline_cache.checkout(key, || {
let shader = process_str(source.as_ref(), &mut context).unwrap();
let module = &self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
});

let pipeline = self
let layout = layout.map(|entries| {
let layout = self
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(name),
layout: layout.as_ref(),
module,
entry_point,
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries,
});
let layout = pipeline.get_bind_group_layout(0);
let pipeline = Arc::new(CachedPipeline { pipeline, layout });
self.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&layout],
push_constant_ranges: &[],
})
});

cache.insert(key, pipeline.clone());
pipeline
}
}
let pipeline = self
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(name),
layout: layout.as_ref(),
module,
entry_point,
});
let layout = pipeline.get_bind_group_layout(0);
CachedPipeline { pipeline, layout }
})
}

pub(crate) fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
Expand All @@ -327,60 +316,44 @@ impl ContextInternal {
stride: shape,
offset: Shape::new(0, 0, 0, 0),
};
self.buffer_cache.checkout(
(view.into_bytes().len(), BufferUsages::UNIFORM),
|buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
},
)
self.view_cache.checkout(view, || {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
})
}

pub(crate) fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
self.buffer_cache.checkout(
(view.into_bytes().len(), BufferUsages::UNIFORM),
|buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
},
)
self.view_cache.checkout(view, || {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
})
}

pub(crate) fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
self.buffer_cache.checkout(
(contents.len(), usage),
|buffer| self.queue.write_buffer(buffer, 0, contents),
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents,
usage,
})
},
)
self.device
.create_buffer_init(&BufferInitDescriptor {
label: None,
contents,
usage,
})
.into()
}

pub(crate) fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
self.buffer_cache.checkout(
(size, usage),
|_| (),
|| {
self.device.create_buffer(&BufferDescriptor {
label: None,
size: size as u64,
usage,
mapped_at_creation: false,
})
},
)
self.device
.create_buffer(&BufferDescriptor {
label: None,
size: size as u64,
usage,
mapped_at_creation: false,
})
.into()
}

#[cfg(not(target_arch = "wasm32"))]
Expand Down
54 changes: 5 additions & 49 deletions src/tensor/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,14 @@ use std::{
sync::{Arc, Mutex},
};

use instant::{Duration, Instant};

#[derive(Debug)]
struct CacheItem<V> {
value: Arc<V>,
instant: Instant,
}

impl<V> CacheItem<V> {
fn new(value: Arc<V>) -> Self {
Self {
value,
instant: Instant::now(),
}
}

fn count(&self) -> usize {
Arc::strong_count(&self.value)
}
}

#[derive(Debug)]
pub struct ResourceCache<K, V> {
duration: Duration,
map: Mutex<HashMap<K, Vec<CacheItem<V>>>>,
map: Mutex<HashMap<K, Arc<V>>>,
}

impl<K, V> Default for ResourceCache<K, V> {
fn default() -> Self {
Self {
duration: Duration::from_secs(1),
map: Default::default(),
}
}
Expand All @@ -44,35 +21,14 @@ impl<K, V> ResourceCache<K, V>
where
K: PartialEq + Eq + Hash,
{
/// Note: If `duration` is 0, the cache won't evict any items.
pub fn new(duration: Duration) -> Self {
Self {
duration,
map: Default::default(),
}
}

/// Checkout the item with the given key. If the item doesn't exist, `f` is called to construct it.
pub fn checkout(&self, key: K, hit: impl FnOnce(&V), miss: impl FnOnce() -> V) -> Arc<V> {
pub fn checkout(&self, key: K, miss: impl FnOnce() -> V) -> Arc<V> {
let mut map = self.map.lock().unwrap();
if !self.duration.is_zero() {
for (_, items) in map.iter_mut() {
items.retain(|item| item.count() > 1 && item.instant.elapsed() < self.duration);
}
}

match map
.get_mut(&key)
.and_then(|items| items.iter_mut().find(|item| item.count() == 1))
{
Some(item) => {
hit(&item.value);
item.instant = Instant::now();
item.value.clone()
}
match map.get(&key) {
Some(value) => value.clone(),
None => {
let value = Arc::new(miss());
map.insert(key, vec![CacheItem::new(value.clone())]);
map.insert(key, value.clone());
value
}
}
Expand Down
Loading

0 comments on commit 6288524

Please sign in to comment.