Neo tensor (#20)
* Implement ref count cache (a minimal sketch of the checkout pattern appears just before the file diffs below).

* `Tensor`'s metadata is also a `View`.

* When creating a tensor, use the ref count cache.

* Fully leverage the ref count buffer.

* Do not use temporary cache when loading the model.

* Ref count cache correctly initializes resources.

* Unify cache types.

* Reformat code.

* Bump version to v0.6.25
cryscan authored Mar 13, 2024
1 parent a0b2740 commit d8d13a2
Showing 27 changed files with 266 additions and 306 deletions.
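
As a bridge into the file diffs, here is a minimal, self-contained sketch of the ref count cache checkout pattern that the changes below rely on. It is not the crate's actual implementation: the free-entry test via Arc::strong_count and the exact shape of checkout(key, on_hit, on_miss) are inferred from the call sites in this diff and should be read as assumptions.

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};

/// Sketch of a ref count cache: an entry whose `Arc` is held only by the cache
/// (strong count of 1) is considered free and may be handed out again.
pub struct ResourceCache<K, V> {
    map: Mutex<HashMap<K, Vec<Arc<V>>>>,
}

impl<K, V> Default for ResourceCache<K, V> {
    fn default() -> Self {
        Self { map: Mutex::new(HashMap::new()) }
    }
}

impl<K: Hash + Eq, V> ResourceCache<K, V> {
    /// Reuse a free entry for `key` (running `on_hit` to refresh it), or build
    /// a new one with `on_miss` and remember it for later reuse.
    pub fn checkout(
        &self,
        key: K,
        on_hit: impl FnOnce(&V),
        on_miss: impl FnOnce() -> V,
    ) -> Arc<V> {
        let mut map = self.map.lock().unwrap();
        let entries = map.entry(key).or_default();
        if let Some(free) = entries.iter().find(|entry| Arc::strong_count(entry) == 1) {
            on_hit(free.as_ref());
            return free.clone();
        }
        let value = Arc::new(on_miss());
        entries.push(value.clone());
        value
    }
}

fn main() {
    let cache: ResourceCache<usize, Vec<u8>> = Default::default();
    let first = cache.checkout(16, |_| (), || vec![0u8; 16]); // miss: allocates
    drop(first); // only the cache still holds the entry
    let _again = cache.checkout(16, |v| println!("reusing {} bytes", v.len()), || vec![0u8; 16]); // hit: reuses
}

Dropping every external Arc is what frees an entry for reuse, which is why the Context methods in src/context.rs below key buffers by byte size and refresh their contents in the on-hit closure rather than allocating anew.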
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.6.24"
version = "0.6.25"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

164 changes: 109 additions & 55 deletions src/context.rs
@@ -1,14 +1,18 @@
use std::{borrow::Cow, sync::Arc};
use std::{
borrow::Cow,
collections::HashMap,
sync::{Arc, Mutex},
};

use thiserror::Error;
use wasm_bindgen::prelude::wasm_bindgen;
use web_rwkv_derive::{Deref, DerefMut};
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
Adapter, Backends, BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer,
BufferUsages, ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, Features,
Limits, PipelineLayoutDescriptor, PowerPreference, Queue, RequestAdapterOptions,
ShaderModuleDescriptor,
BufferDescriptor, BufferUsages, ComputePipeline, ComputePipelineDescriptor, Device,
DeviceDescriptor, Features, Limits, PipelineLayoutDescriptor, PowerPreference, Queue,
RequestAdapterOptions, ShaderModuleDescriptor,
};

use crate::{
@@ -76,10 +80,8 @@ pub struct ContextInternal {
pub device: Device,
pub queue: Queue,

pipeline_cache: ResourceCache<PipelineKey, CachedPipeline>,

shape_cache: ResourceCache<Shape, Buffer>,
view_cache: ResourceCache<View, Buffer>,
pipeline_cache: Mutex<HashMap<PipelineKey, Arc<CachedPipeline>>>,
buffer_cache: ResourceCache<usize, Buffer>,
}

#[derive(Debug, Clone, Deref, DerefMut)]
@@ -134,9 +136,8 @@ impl<'a> ContextBuilder {
adapter,
device,
queue,
pipeline_cache: ResourceCache::new(0),
shape_cache: Default::default(),
view_cache: Default::default(),
pipeline_cache: Default::default(),
buffer_cache: Default::default(),
}
.into(),
))
@@ -234,58 +235,111 @@ impl Context {
let mut context = Context::new();
context.macros = macros.0.into_iter().collect();

self.pipeline_cache.checkout(key, move || {
let shader = process_str(source.as_ref(), &mut context).unwrap();
let module = &self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
});
let mut cache = self.pipeline_cache.lock().unwrap();
match cache.get(&key) {
Some(pipeline) => pipeline.clone(),
None => {
let shader = process_str(source.as_ref(), &mut context).unwrap();
let module = &self.device.create_shader_module(ShaderModuleDescriptor {
label: Some(name),
source: wgpu::ShaderSource::Wgsl(Cow::from(shader)),
});

let layout = layout.map(|entries| {
let layout = self
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries,
});
self.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&layout],
push_constant_ranges: &[],
})
});

let layout = layout.map(|entries| {
let layout = self
let pipeline = self
.device
.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries,
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(name),
layout: layout.as_ref(),
module,
entry_point,
});
self.device
.create_pipeline_layout(&PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&layout],
push_constant_ranges: &[],
})
});

let pipeline = self
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: Some(name),
layout: layout.as_ref(),
module,
entry_point,
});
let layout = pipeline.get_bind_group_layout(0);
CachedPipeline { pipeline, layout }
})
let layout = pipeline.get_bind_group_layout(0);
let pipeline = Arc::new(CachedPipeline { pipeline, layout });

cache.insert(key, pipeline.clone());
pipeline
}
}
}

pub fn checkout_shape_uniform(&self, shape: Shape) -> Arc<Buffer> {
self.shape_cache.checkout(shape, || {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &shape.into_bytes(),
usage: BufferUsages::UNIFORM,
})
})
let view = View {
shape,
stride: shape,
offset: Shape::new(0, 0, 0, 0),
};
self.buffer_cache.checkout(
shape.into_bytes().len(),
|buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
},
)
}

pub fn checkout_view_uniform(&self, view: View) -> Arc<Buffer> {
self.view_cache.checkout(view, || {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
})
self.buffer_cache.checkout(
view.into_bytes().len(),
|buffer| self.queue.write_buffer(buffer, 0, &view.into_bytes()),
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &view.into_bytes(),
usage: BufferUsages::UNIFORM,
})
},
)
}

pub fn checkout_buffer_init(&self, contents: &[u8], usage: BufferUsages) -> Arc<Buffer> {
self.buffer_cache.checkout(
contents.len(),
|buffer| {
if usage.contains(BufferUsages::STORAGE) {
self.queue.write_buffer(buffer, 0, contents);
}
},
|| {
self.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents,
usage,
})
},
)
}

pub fn checkout_buffer(&self, size: usize, usage: BufferUsages) -> Arc<Buffer> {
self.buffer_cache.checkout(
size,
|_| (),
|| {
self.device.create_buffer(&BufferDescriptor {
label: None,
size: size as u64,
usage,
mapped_at_creation: false,
})
},
)
}
}
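
For orientation, a hedged usage sketch of the unified buffer cache above. The method names and signatures are taken from this diff, but the module path web_rwkv::context::Context is an assumption, and constructing a Context (adapter, device, queue) is not shown.

use web_rwkv::context::Context;
use wgpu::BufferUsages;

/// Illustrative only: request two buffers through the size-keyed cache.
fn allocate(context: &Context) {
    // Initialized buffer; per the diff, a cache hit rewrites its contents with
    // `queue.write_buffer`, but only when the usage includes STORAGE.
    let contents = vec![0u8; 1024];
    let _weights = context.checkout_buffer_init(
        &contents,
        BufferUsages::STORAGE | BufferUsages::COPY_DST,
    );

    // Uninitialized buffer keyed only by byte size; dropping the returned
    // `Arc<Buffer>` lets a later request of the same size reuse the allocation.
    let scratch = context.checkout_buffer(1024, BufferUsages::STORAGE | BufferUsages::COPY_SRC);
    drop(scratch);
}

Because shape/view uniforms and plain buffers now share the same size-keyed buffer_cache, a hit must rewrite the uniform contents (the write_buffer call in the on-hit closures above), since a buffer of the right size may previously have held a different View.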
17 changes: 5 additions & 12 deletions src/model/loader.rs
@@ -11,7 +11,6 @@ use super::{ModelError, ModelInfo, ModelVersion, Quant};
use crate::{
context::Context,
tensor::{
cache::ResourceCache,
kind::ReadWrite,
matrix::Matrix,
ops::{TensorCommand, TensorOp, TensorPass},
@@ -575,24 +574,19 @@ impl<R: Reader> Loader<R> {
Ok(head)
}

pub async fn load_matrix(
&self,
cache: &ResourceCache<Shape, TensorGpu<f16, ReadWrite>>,
name: String,
quant: Quant,
) -> Result<Matrix> {
pub async fn load_matrix(&self, name: String, quant: Quant) -> Result<Matrix> {
let context = &self.context;
match quant {
Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16(name).await?)),
Quant::Int8 => {
let shape = self.tensor_shape(&name)?;
let buffer = cache.checkout(shape, || context.tensor_init(shape));
let buffer = context.tensor_init(shape);
self.load_in_place_matrix_f16(&buffer, &name).await?;
Ok(Matrix::quant_u8(&buffer)?)
}
Quant::NF4 => {
let shape = self.tensor_shape(&name)?;
let buffer = cache.checkout(shape, || context.tensor_init(shape));
let buffer = context.tensor_init(shape);
self.load_in_place_matrix_f16(&buffer, &name).await?;
Ok(Matrix::quant_nf4(&buffer)?)
}
@@ -601,7 +595,6 @@ impl<R: Reader> Loader<R> {

pub async fn load_matrix_discount(
&self,
cache: &ResourceCache<Shape, TensorGpu<f16, ReadWrite>>,
name: String,
quant: Quant,
discount: f32,
@@ -613,14 +606,14 @@
)),
Quant::Int8 => {
let shape = self.tensor_shape(&name)?;
let buffer = cache.checkout(shape, || context.tensor_init(shape));
let buffer = context.tensor_init(shape);
self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
.await?;
Ok(Matrix::quant_u8(&buffer)?)
}
Quant::NF4 => {
let shape = self.tensor_shape(&name)?;
let buffer = cache.checkout(shape, || context.tensor_init(shape));
let buffer = context.tensor_init(shape);
self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
.await?;
Ok(Matrix::quant_nf4(&buffer)?)
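
A hedged sketch of the simplified loader call site after this change: the ResourceCache argument is gone and quantized weights are staged through context.tensor_init internally. The module paths, the location of the Reader bound, and the tensor name are assumptions for illustration, not taken from this diff.

use anyhow::Result;
use web_rwkv::model::{loader::{Loader, Reader}, Quant};
use web_rwkv::tensor::matrix::Matrix;

/// Illustrative only: load one quantized matrix without threading a cache through.
async fn load_att_key<R: Reader>(loader: &Loader<R>) -> Result<Matrix> {
    loader
        .load_matrix("blocks.0.att.key.weight".to_string(), Quant::Int8)
        .await
}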
6 changes: 3 additions & 3 deletions src/model/run.rs
@@ -1,4 +1,4 @@
use std::{collections::HashMap, future::Future, hash::Hash, sync::Arc};
use std::{collections::HashMap, future::Future, hash::Hash};

use anyhow::Result;
use half::f16;
@@ -50,8 +50,8 @@ pub(crate) trait ModelRunInternal: ModelBase {

fn tensor(&self) -> &Self::Tensor;

fn checkout_runtime(&self, num_batch: usize) -> Arc<Self::Runtime>;
fn checkout_header(&self, num_batch: usize) -> Arc<Self::Header>;
fn checkout_runtime(&self, num_batch: usize) -> Self::Runtime;
fn checkout_header(&self, num_batch: usize) -> Self::Header;

/// To prevent the GPU device from lost, this limits the maximum batch-token it processes one time.
fn token_chunk_size(&self) -> usize;
10 changes: 3 additions & 7 deletions src/model/softmax.rs
@@ -1,4 +1,4 @@
use std::{future::Future, sync::Arc};
use std::future::Future;

use anyhow::Result;
use itertools::Itertools;
@@ -30,10 +30,6 @@ impl Softmax {
}
}

pub(crate) trait ModelSoftmaxInternal: ModelBase {
fn checkout_softmax(&self, num_batch: usize) -> Arc<Softmax>;
}

pub trait ModelSoftmax {
/// Softmax of the input tensors.
fn softmax(
@@ -42,7 +38,7 @@ pub trait ModelSoftmax {
) -> impl Future<Output = Result<Vec<ModelOutput>, TensorError>>;
}

impl<Model: ModelSoftmaxInternal> ModelSoftmax for Model {
impl<M: ModelBase> ModelSoftmax for M {
async fn softmax(&self, input: Vec<ModelOutput>) -> Result<Vec<ModelOutput>, TensorError> {
let context = self.context();
let info = self.info();
@@ -78,7 +74,7 @@ impl<Model: ModelSoftmaxInternal> ModelSoftmax for Model {
)?;

let num_batch = input.shape()[2];
let softmax = self.checkout_softmax(num_batch);
let softmax = Softmax::new(context, info, num_batch);
softmax.buffer.load(&input)?;

let op = TensorOp::softmax(&softmax.buffer)?;