Skip to content

Commit

Permalink
Separate head weight into chunks.
Browse files Browse the repository at this point in the history
  • Loading branch information
cedruszhang committed Jul 9, 2023
1 parent 0dd4aa1 commit a7407d7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.1.1"
version = "0.1.2"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
114 changes: 76 additions & 38 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use safetensors::SafeTensors;
use std::{borrow::Cow, sync::Arc};
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
BindGroup, BindGroupDescriptor, BindGroupEntry, Buffer, BufferDescriptor, BufferUsages,
CommandEncoderDescriptor, ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor,
ShaderModuleDescriptor, ShaderSource,
BindGroup, BindGroupDescriptor, BindGroupEntry, Buffer, BufferBinding, BufferDescriptor,
BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipeline,
ComputePipelineDescriptor, ShaderModuleDescriptor, ShaderSource,
};

use crate::Environment;
Expand All @@ -27,6 +27,10 @@ pub struct ModelInfo {
pub num_vocab: usize,
}

impl ModelInfo {
pub const HEAD_CHUNK_SIZE: usize = 16384;
}

pub struct ModelTensor {
pub dim: Buffer,
pub embed: Embed,
Expand Down Expand Up @@ -84,7 +88,7 @@ pub struct Head {
pub layer_norm: LayerNorm,

pub dims: Buffer,
pub w: Buffer,
pub w: Vec<Buffer>,
}

pub struct ModelPipeline {
Expand Down Expand Up @@ -149,7 +153,7 @@ pub struct EmbedBindGroup {

pub struct HeadBindGroup {
pub layer_norm: BindGroup,
pub matmul: BindGroup,
pub matmul: Vec<BindGroup>,
}

pub struct LayerBindGroup {
Expand Down Expand Up @@ -255,13 +259,29 @@ impl Model {
},
w: pod_collect_to_vec(model.tensor("emb.weight")?.data()),
};
let head = Head {
layer_norm: LayerNorm {
w: load_tensor_f32("ln_out.weight".into())?,
b: load_tensor_f32("ln_out.bias".into())?,
},
dims: create_uniform_u32(&[num_emb as u32, num_vocab as u32]),
w: load_tensor_f16("head.weight".into())?,
let head = {
let chunk_size = ModelInfo::HEAD_CHUNK_SIZE;
let w: Vec<f16> = pod_collect_to_vec(model.tensor("head.weight")?.data());
let w = (0..num_vocab / chunk_size)
.map(|chunk| {
let begin = chunk_size * chunk * num_emb;
let end = begin + chunk_size * num_emb;
device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: cast_slice(&w[begin..end]),
usage: BufferUsages::STORAGE,
})
})
.collect();

Head {
layer_norm: LayerNorm {
w: load_tensor_f32("ln_out.weight".into())?,
b: load_tensor_f32("ln_out.bias".into())?,
},
dims: create_uniform_u32(&[num_emb as u32, num_vocab as u32]),
w,
}
};

let mut layers = vec![];
Expand Down Expand Up @@ -546,28 +566,44 @@ impl Model {
},
],
});
let matmul = device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &matmul_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: self.tensor.head.dims.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: self.tensor.head.w.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: buffer.head_r.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: buffer.head_o.as_entire_binding(),
},
],
});
let matmul = self
.tensor
.head
.w
.iter()
.enumerate()
.map(|(chunk, w)| {
let chunk_size = ModelInfo::HEAD_CHUNK_SIZE as u64;
let offset = 4 * chunk as u64 * chunk_size;

device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &matmul_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: self.tensor.head.dims.as_entire_binding(),
},
BindGroupEntry {
binding: 1,
resource: w.as_entire_binding(),
},
BindGroupEntry {
binding: 2,
resource: buffer.head_r.as_entire_binding(),
},
BindGroupEntry {
binding: 3,
resource: wgpu::BindingResource::Buffer(BufferBinding {
buffer: &buffer.head_o,
offset,
size: None,
}),
},
],
})
})
.collect();
HeadBindGroup { layer_norm, matmul }
};

Expand Down Expand Up @@ -1092,7 +1128,7 @@ impl Model {

let num_tokens = buffer.tokens.len() as u32;
let num_emb_vec4 = num_emb as u32 / 4;
let num_vocab_vec4 = num_vocab as u32 / 4;
let chunk_size_vec4 = ModelInfo::HEAD_CHUNK_SIZE as u32 / 4;
const BLOCK_SIZE: u32 = 256;

let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor::default());
Expand Down Expand Up @@ -1204,9 +1240,11 @@ impl Model {
pass.set_bind_group(0, &bind_group.head.layer_norm, &[]);
pass.dispatch_workgroups(1, 1, 1);

pass.set_pipeline(&pipeline.matmul);
pass.set_bind_group(0, &bind_group.head.matmul, &[]);
pass.dispatch_workgroups(1, num_vocab_vec4, 1);
for matmul in &bind_group.head.matmul {
pass.set_pipeline(&pipeline.matmul);
pass.set_bind_group(0, matmul, &[]);
pass.dispatch_workgroups(1, chunk_size_vec4, 1);
}
}

encoder.copy_buffer_to_buffer(&buffer.head_o, 0, &buffer.map, 0, 4 * num_vocab as u64);
Expand Down

0 comments on commit a7407d7

Please sign in to comment.