From a7407d77f3b67291f68e026ead7fcf5d268fe394 Mon Sep 17 00:00:00 2001 From: cedruszhang Date: Sun, 9 Jul 2023 20:43:44 +0800 Subject: [PATCH] Separate head weight into chunks. --- Cargo.toml | 2 +- src/model.rs | 114 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 096942a..880c761 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "web-rwkv" -version = "0.1.1" +version = "0.1.2" edition = "2021" authors = ["Zhenyuan Zhang "] license = "MIT OR Apache-2.0" diff --git a/src/model.rs b/src/model.rs index 7308104..3cdef1a 100644 --- a/src/model.rs +++ b/src/model.rs @@ -5,9 +5,9 @@ use safetensors::SafeTensors; use std::{borrow::Cow, sync::Arc}; use wgpu::{ util::{BufferInitDescriptor, DeviceExt}, - BindGroup, BindGroupDescriptor, BindGroupEntry, Buffer, BufferDescriptor, BufferUsages, - CommandEncoderDescriptor, ComputePassDescriptor, ComputePipeline, ComputePipelineDescriptor, - ShaderModuleDescriptor, ShaderSource, + BindGroup, BindGroupDescriptor, BindGroupEntry, Buffer, BufferBinding, BufferDescriptor, + BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, ComputePipeline, + ComputePipelineDescriptor, ShaderModuleDescriptor, ShaderSource, }; use crate::Environment; @@ -27,6 +27,10 @@ pub struct ModelInfo { pub num_vocab: usize, } +impl ModelInfo { + pub const HEAD_CHUNK_SIZE: usize = 16384; +} + pub struct ModelTensor { pub dim: Buffer, pub embed: Embed, @@ -84,7 +88,7 @@ pub struct Head { pub layer_norm: LayerNorm, pub dims: Buffer, - pub w: Buffer, + pub w: Vec, } pub struct ModelPipeline { @@ -149,7 +153,7 @@ pub struct EmbedBindGroup { pub struct HeadBindGroup { pub layer_norm: BindGroup, - pub matmul: BindGroup, + pub matmul: Vec, } pub struct LayerBindGroup { @@ -255,13 +259,29 @@ impl Model { }, w: pod_collect_to_vec(model.tensor("emb.weight")?.data()), }; - let head = Head { - layer_norm: LayerNorm { - w: load_tensor_f32("ln_out.weight".into())?, - b: load_tensor_f32("ln_out.bias".into())?, - }, - dims: create_uniform_u32(&[num_emb as u32, num_vocab as u32]), - w: load_tensor_f16("head.weight".into())?, + let head = { + let chunk_size = ModelInfo::HEAD_CHUNK_SIZE; + let w: Vec = pod_collect_to_vec(model.tensor("head.weight")?.data()); + let w = (0..num_vocab / chunk_size) + .map(|chunk| { + let begin = chunk_size * chunk * num_emb; + let end = begin + chunk_size * num_emb; + device.create_buffer_init(&BufferInitDescriptor { + label: None, + contents: cast_slice(&w[begin..end]), + usage: BufferUsages::STORAGE, + }) + }) + .collect(); + + Head { + layer_norm: LayerNorm { + w: load_tensor_f32("ln_out.weight".into())?, + b: load_tensor_f32("ln_out.bias".into())?, + }, + dims: create_uniform_u32(&[num_emb as u32, num_vocab as u32]), + w, + } }; let mut layers = vec![]; @@ -546,28 +566,44 @@ impl Model { }, ], }); - let matmul = device.create_bind_group(&BindGroupDescriptor { - label: None, - layout: &matmul_layout, - entries: &[ - BindGroupEntry { - binding: 0, - resource: self.tensor.head.dims.as_entire_binding(), - }, - BindGroupEntry { - binding: 1, - resource: self.tensor.head.w.as_entire_binding(), - }, - BindGroupEntry { - binding: 2, - resource: buffer.head_r.as_entire_binding(), - }, - BindGroupEntry { - binding: 3, - resource: buffer.head_o.as_entire_binding(), - }, - ], - }); + let matmul = self + .tensor + .head + .w + .iter() + .enumerate() + .map(|(chunk, w)| { + let chunk_size = ModelInfo::HEAD_CHUNK_SIZE as u64; + let offset = 4 * chunk as u64 * chunk_size; + + device.create_bind_group(&BindGroupDescriptor { + label: None, + layout: &matmul_layout, + entries: &[ + BindGroupEntry { + binding: 0, + resource: self.tensor.head.dims.as_entire_binding(), + }, + BindGroupEntry { + binding: 1, + resource: w.as_entire_binding(), + }, + BindGroupEntry { + binding: 2, + resource: buffer.head_r.as_entire_binding(), + }, + BindGroupEntry { + binding: 3, + resource: wgpu::BindingResource::Buffer(BufferBinding { + buffer: &buffer.head_o, + offset, + size: None, + }), + }, + ], + }) + }) + .collect(); HeadBindGroup { layer_norm, matmul } }; @@ -1092,7 +1128,7 @@ impl Model { let num_tokens = buffer.tokens.len() as u32; let num_emb_vec4 = num_emb as u32 / 4; - let num_vocab_vec4 = num_vocab as u32 / 4; + let chunk_size_vec4 = ModelInfo::HEAD_CHUNK_SIZE as u32 / 4; const BLOCK_SIZE: u32 = 256; let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor::default()); @@ -1204,9 +1240,11 @@ impl Model { pass.set_bind_group(0, &bind_group.head.layer_norm, &[]); pass.dispatch_workgroups(1, 1, 1); - pass.set_pipeline(&pipeline.matmul); - pass.set_bind_group(0, &bind_group.head.matmul, &[]); - pass.dispatch_workgroups(1, num_vocab_vec4, 1); + for matmul in &bind_group.head.matmul { + pass.set_pipeline(&pipeline.matmul); + pass.set_bind_group(0, matmul, &[]); + pass.dispatch_workgroups(1, chunk_size_vec4, 1); + } } encoder.copy_buffer_to_buffer(&buffer.head_o, 0, &buffer.map, 0, 4 * num_vocab as u64);