Smaller memory footprint when loading.
cryscan committed Feb 25, 2024
1 parent dc8df31 commit 44872e5
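
This commit makes metadata inspection cheap: the `Reader` trait gains a `shape` accessor that reads only the safetensors header, `ReaderTensor` now carries its `Dtype`, and `Loader::info` and `Loader::tensor_shape` become synchronous because they no longer materialize tensor data. For callers the visible change is a single dropped `.await`; a minimal before/after sketch, assembled from the example diffs below:

    // 0.6.20: probing model info went through `Reader::tensor` and was async
    let info = Loader::info(&model).await?;
    // 0.6.21: shapes come straight from the parsed header, so no await
    let info = Loader::info(&model)?;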
Showing 9 changed files with 76 additions and 73 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "web-rwkv"
-version = "0.6.20"
+version = "0.6.21"
 edition = "2021"
 authors = ["Zhenyuan Zhang <[email protected]>"]
 license = "MIT OR Apache-2.0"
2 changes: 1 addition & 1 deletion examples/batch.rs
@@ -162,7 +162,7 @@ async fn run(cli: Cli) -> Result<()> {
     let data = unsafe { Mmap::map(&file)? };

     let model = SafeTensors::deserialize(&data)?;
-    let info = Loader::info(&model).await?;
+    let info = Loader::info(&model)?;
     println!("{:#?}", info);

     let lora = match cli.lora {
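
The identical one-line change lands in chat.rs, gen.rs, and inspector.rs below. In context, the examples' load sequence now reads roughly as follows (a sketch: the `File::open` line sits outside the diff context and is an assumption):

    let file = std::fs::File::open(&cli.model)?;  // assumed: opened earlier in run()
    let data = unsafe { Mmap::map(&file)? };      // memory-map the checkpoint
    let model = SafeTensors::deserialize(&data)?; // parse the header; no tensor copies
    let info = Loader::info(&model)?;             // now sync: shape metadata only
    println!("{:#?}", info);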
2 changes: 1 addition & 1 deletion examples/chat.rs
@@ -197,7 +197,7 @@ async fn run(cli: Cli) -> Result<()> {
     let data = unsafe { Mmap::map(&file)? };

     let model = SafeTensors::deserialize(&data)?;
-    let info = Loader::info(&model).await?;
+    let info = Loader::info(&model)?;
     println!("{:#?}", info);

     let lora = match cli.lora {
2 changes: 1 addition & 1 deletion examples/gen.rs
@@ -131,7 +131,7 @@ async fn run(cli: Cli) -> Result<()> {
     let data = unsafe { Mmap::map(&file)? };

     let model = SafeTensors::deserialize(&data)?;
-    let info = Loader::info(&model).await?;
+    let info = Loader::info(&model)?;
     println!("{:#?}", info);

     let lora = match cli.lora {
2 changes: 1 addition & 1 deletion examples/inspector.rs
@@ -154,7 +154,7 @@ async fn run(cli: Cli) -> Result<()> {
     let data = unsafe { Mmap::map(&file)? };

     let model = SafeTensors::deserialize(&data)?;
-    let info = Loader::info(&model).await?;
+    let info = Loader::info(&model)?;
     if info.version != ModelVersion::V5 {
         bail!("this demo only supports v5");
     }
120 changes: 60 additions & 60 deletions src/model/loader.rs
@@ -4,7 +4,7 @@ use anyhow::Result;
 use half::f16;
 use itertools::Itertools;
 use regex::Regex;
-use safetensors::{SafeTensorError, SafeTensors};
+use safetensors::{Dtype, SafeTensorError, SafeTensors};
 use web_rwkv_derive::{Deref, DerefMut};

 use super::{ModelError, ModelInfo, ModelVersion, Quant};
@@ -20,13 +20,13 @@ use crate::{
     },
 };

-pub type ReaderTensor<'a> = (Vec<usize>, Cow<'a, [u8]>);
+pub type ReaderTensor<'a> = (Dtype, Vec<usize>, Cow<'a, [u8]>);

 /// Interface accessing a safetensors data blob.
 pub trait Reader {
     fn names(&self) -> Vec<&str>;
     fn contains(&self, name: &str) -> bool;
-
+    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError>;
     fn tensor(&self, name: &str) -> impl Future<Output = Result<ReaderTensor, SafeTensorError>>;
 }

@@ -41,13 +41,17 @@ impl Reader for SafeTensors<'_> {
         self.names().contains(&&name.to_string())
     }

+    #[inline]
+    fn shape(&self, name: &str) -> Result<Vec<usize>, SafeTensorError> {
+        Ok(self.tensor(name)?.shape().to_vec())
+    }
+
     #[inline]
     async fn tensor(&self, name: &str) -> Result<ReaderTensor, SafeTensorError> {
-        let name = name.to_string();
-        let tensor = self.tensor(&name)?;
+        let tensor = self.tensor(name)?;
         let shape = tensor.shape().to_vec();
         let data = Cow::from(tensor.data());
-        Ok((shape, data))
+        Ok((tensor.dtype(), shape, data))
     }
 }
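
With `shape` in place, a metadata query never touches the mapped tensor bytes: `SafeTensors::tensor` hands back a borrowed view, and `.shape()` comes from the already-parsed header. A minimal usage sketch, assuming a deserialized `model: SafeTensors` with the `Reader` trait in scope:

    // No tensor payload is read or copied here.
    let shape = model.shape("emb.weight")?; // e.g. [num_vocab, num_emb]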

@@ -110,8 +114,8 @@ struct LoraVector {
 }

 struct LoraMatrix {
-    a: TensorGpu<f16, ReadWrite>,
-    b: TensorGpu<f16, ReadWrite>,
+    x: TensorGpu<f16, ReadWrite>,
+    y: TensorGpu<f16, ReadWrite>,
     rank: usize,
     alpha: f32,
 }
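
The `a`/`b` fields are renamed to `x`/`y`; they are still the two low-rank LoRA factors, with `rank` read off `x.shape()[0]` at load time. Assuming the standard LoRA formulation (the exact scaling is applied inside `TensorOp::blend_lora` via the `factor` tensor built from `alpha`, which this diff does not show), the blend computes roughly

    W' = W + scale * (y x)    // scale derived from alpha and rank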
@@ -124,7 +128,7 @@ pub struct Loader<R: Reader> {
 }

 impl<R: Reader> Loader<R> {
-    pub async fn info(model: &R) -> Result<ModelInfo> {
+    pub fn info(model: &R) -> Result<ModelInfo> {
         let num_layer = {
             let mut r: usize = 0;
             for i in model.names() {
@@ -137,9 +141,9 @@ impl<R: Reader> Loader<R> {
             r + 1
         };

-        let embed = model.tensor("emb.weight").await?.0;
-        let ffn = model.tensor("blocks.0.ffn.key.weight").await?.0;
-        let time_first = model.tensor("blocks.0.att.time_first").await?.0;
+        let embed = model.shape("emb.weight")?;
+        let ffn = model.shape("blocks.0.ffn.key.weight")?;
+        let time_first = model.shape("blocks.0.att.time_first")?;

         let v5 = [
             "blocks.0.att.gate.weight",
@@ -203,10 +207,10 @@ impl<R: Reader> Loader<R> {
                 continue;
             };

-            let Ok((shape, data)) = lora.data.tensor(name).await else {
+            let Ok(tensor) = lora.data.tensor(name).await else {
                 continue;
             };
-            let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, data)?
+            let tensor = TensorCpu::<f16>::from_reader(&self.context, tensor)?
                 .map(|x| x.to_f32())
                 .into();
             let alpha = blend.alpha;
@@ -233,36 +237,36 @@ impl<R: Reader> Loader<R> {
                 continue;
             };

-            let Ok(a) = lora.data.tensor(&format!("{name}.lora.0")).await else {
+            let Ok(x) = lora.data.tensor(&format!("{name}.lora.0")).await else {
                 continue;
             };
-            let Ok(b) = lora.data.tensor(&format!("{name}.lora.1")).await else {
+            let Ok(y) = lora.data.tensor(&format!("{name}.lora.1")).await else {
                 continue;
             };

-            let a = TensorGpu::from_safetensors(&self.context, a.0, a.1)?;
-            let b = TensorGpu::from_safetensors(&self.context, b.0, b.1)?;
-            let rank = a.shape()[0];
+            let x = TensorGpu::from_reader(&self.context, x)?;
+            let y = TensorGpu::from_reader(&self.context, y)?;
+            let rank = x.shape()[0];
             let alpha = blend.alpha;
-            matrices.push(LoraMatrix { a, b, rank, alpha });
+            matrices.push(LoraMatrix { x, y, rank, alpha });

             log::info!("loaded LoRA {name}, alpha: {alpha}");
         }
         Ok(matrices)
     }

-    pub async fn tensor_shape(&self, name: impl AsRef<str>) -> Result<Shape> {
-        let (shape, _) = self.model.tensor(name.as_ref()).await?;
-        Ok(Shape::from_safetensors(&shape)?)
+    pub fn tensor_shape(&self, name: impl AsRef<str>) -> Result<Shape> {
+        let shape = self.model.shape(name.as_ref())?;
+        Ok(Shape::from_slice_rev(&shape)?)
     }

     pub async fn load_vector_f32(
         &self,
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_reader(&self.context, tensor)?
             .map(|x| x.to_f32())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
            .into();
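
This is the pattern repeated through the rest of the file: `Reader::tensor` yields the full `ReaderTensor` triple `(Dtype, Vec<usize>, Cow<[u8]>)`, and `from_reader` consumes it whole instead of taking shape and data separately, presumably so the declared dtype can be checked at conversion time. The two-line shape of it, as a sketch inside any `Loader` method:

    let tensor = self.model.tensor(name.as_ref()).await?;               // (dtype, shape, bytes)
    let tensor = TensorCpu::<f16>::from_reader(&self.context, tensor)?; // typed, validated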
@@ -289,8 +293,8 @@ impl<R: Reader> Loader<R> {
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_reader(&self.context, tensor)?
             .map(|x| -x.to_f32().exp())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
             .into();
@@ -317,8 +321,8 @@ impl<R: Reader> Loader<R> {
         name: impl AsRef<str>,
     ) -> Result<TensorGpu<f32, ReadWrite>> {
         use TensorDimension::{Auto, Dimension};
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(&self.context, shape, tensor)?
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_reader(&self.context, tensor)?
             .map(|x| -x.to_f32().exp())
             .map(|x| x.exp())
             .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?
@@ -348,16 +352,16 @@ impl<R: Reader> Loader<R> {
         use TensorDimension::{Auto, Dimension};
         let context = &self.context;
         let lora = self.lora_vectors(name.as_ref()).await?;
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
+        let tensor = self.model.tensor(name.as_ref()).await?;
         let tensor = if lora.is_empty() {
-            TensorGpu::from_safetensors(context, shape, tensor)?.reshape(
+            TensorGpu::from_reader(context, tensor)?.reshape(
                 Auto,
                 Dimension(1),
                 Dimension(1),
                 Dimension(1),
             )?
         } else {
-            let tensor_f32 = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
+            let tensor_f32 = TensorCpu::<f16>::from_reader(context, tensor)?
                 .map(|x| x.to_f32())
                 .reshape(Auto, Dimension(1), Dimension(1), Dimension(1))?;
             let tensor_f32 = TensorGpu::from(tensor_f32);
@@ -393,8 +397,8 @@ impl<R: Reader> Loader<R> {
     ) -> Result<TensorGpu<f16, ReadWrite>> {
         let context = &self.context;
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorGpu::from_safetensors(&self.context, shape, tensor)?;
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorGpu::from_reader(&self.context, tensor)?;

         if !lora.is_empty() {
             let mut encoder = context.device.create_command_encoder(&Default::default());
@@ -403,8 +407,8 @@ impl<R: Reader> Loader<R> {
             let factor = TensorGpu::from_data(context, Shape::new(4, 1, 1, 1), &factor)?;
             let op = TensorOp::blend_lora(
                 &factor,
-                lora.b.view(.., .., .., ..)?,
-                lora.a.view(.., .., .., ..)?,
+                lora.y.view(.., .., .., ..)?,
+                lora.x.view(.., .., .., ..)?,
                 tensor.view(.., .., .., ..)?,
             )?;
             let mut pass = encoder.begin_compute_pass(&Default::default());
@@ -423,8 +427,8 @@ impl<R: Reader> Loader<R> {
         let context = &self.context;

         let lora = self.lora_matrices(name.as_ref()).await?;
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_reader(context, tensor)?
             .map(|x| f16::from_f32(discount * x.to_f32()));
         let tensor = TensorGpu::from(tensor);

@@ -435,8 +439,8 @@ impl<R: Reader> Loader<R> {
             let factor = TensorGpu::from_data(context, Shape::new(4, 1, 1, 1), &factor)?;
             let op = TensorOp::blend_lora(
                 &factor,
-                lora.b.view(.., .., .., ..)?,
-                lora.a.view(.., .., .., ..)?,
+                lora.y.view(.., .., .., ..)?,
+                lora.x.view(.., .., .., ..)?,
                 tensor.view(.., .., .., ..)?,
             )?;
             let mut pass = encoder.begin_compute_pass(&Default::default());
@@ -455,8 +459,8 @@ impl<R: Reader> Loader<R> {
     ) -> Result<()> {
         let context = &self.context;
         let lora = self.lora_matrices(name.as_ref()).await?;
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::from_safetensors(context, shape, tensor)?;
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::from_reader(context, tensor)?;
         matrix.load(&tensor)?;

         if !lora.is_empty() {
@@ -466,8 +470,8 @@ impl<R: Reader> Loader<R> {
             let factor = TensorGpu::from_data(context, Shape::new(4, 1, 1, 1), &factor)?;
             let op = TensorOp::blend_lora(
                 &factor,
-                lora.b.view(.., .., .., ..)?,
-                lora.a.view(.., .., .., ..)?,
+                lora.y.view(.., .., .., ..)?,
+                lora.x.view(.., .., .., ..)?,
                 matrix.view(.., .., .., ..)?,
             )?;
             let mut pass = encoder.begin_compute_pass(&Default::default());
@@ -489,8 +493,8 @@ impl<R: Reader> Loader<R> {
         let context = &self.context;

         let lora = self.lora_matrices(name.as_ref()).await?;
-        let (shape, tensor) = self.model.tensor(name.as_ref()).await?;
-        let tensor = TensorCpu::<f16>::from_safetensors(context, shape, tensor)?
+        let tensor = self.model.tensor(name.as_ref()).await?;
+        let tensor = TensorCpu::<f16>::from_reader(context, tensor)?
             .map(|x| f16::from_f32(discount * x.to_f32()))
             .reshape(Full, Full, Dimension(1), Dimension(1))?;
         matrix.load(&tensor)?;
@@ -502,8 +506,8 @@ impl<R: Reader> Loader<R> {
             let factor = TensorGpu::from_data(context, Shape::new(4, 1, 1, 1), &factor)?;
             let op = TensorOp::blend_lora(
                 &factor,
-                lora.b.view(.., .., .., ..)?,
-                lora.a.view(.., .., .., ..)?,
+                lora.y.view(.., .., .., ..)?,
+                lora.x.view(.., .., .., ..)?,
                 matrix.view(.., .., .., ..)?,
             )?;
             let mut pass = encoder.begin_compute_pass(&Default::default());
@@ -516,19 +520,15 @@ impl<R: Reader> Loader<R> {
     }

     pub async fn load_embed<'b>(&self) -> Result<TensorCpu<'b, f16>> {
-        let (shape, tensor) = self.model.tensor("emb.weight").await?;
-        let num_emb = shape[1];
-        let num_vocab = shape[0];
-        let tensor = self.context.tensor_from_data(
-            Shape::new(num_emb, num_vocab, 1, 1),
-            bytemuck::pod_collect_to_vec(&tensor),
-        )?;
+        let (dt, shape, tensor) = self.model.tensor("emb.weight").await?;
+        let tensor = tensor.to_vec();
+        let tensor = TensorCpu::<f16>::from_reader(&self.context, (dt, shape, tensor.into()))?;
         Ok(tensor)
     }

     pub async fn load_head(&self, chunk_size: usize) -> Result<Vec<TensorGpu<f16, ReadWrite>>> {
         let context = &self.context;
-        let (shape, tensor) = self.model.tensor("head.weight").await?;
+        let (_, shape, tensor) = self.model.tensor("head.weight").await?;
         let shape = Shape::new(shape[1], shape[0], 1, 1);
         let chunks = (shape[1] + chunk_size - 1) / chunk_size;
         let data = bytemuck::cast_slice(&tensor);
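
`load_embed` now funnels through the same `from_reader` path: a single explicit `to_vec()` copy replaces the old `tensor_from_data` plus `pod_collect_to_vec` route, and the dtype travels with the data. `load_head` destructures the triple as `(_, shape, tensor)`, ignoring the dtype because the bytes are cast in place.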
@@ -557,13 +557,13 @@ impl<R: Reader> Loader<R> {
         match quant {
             Quant::None => Ok(Matrix::Fp16(self.load_matrix_f16(name).await?)),
             Quant::Int8 => {
-                let shape = self.tensor_shape(&name).await?;
+                let shape = self.tensor_shape(&name)?;
                 let buffer = cache.checkout(shape, || context.tensor_init(shape));
                 self.load_in_place_matrix_f16(&buffer, &name).await?;
                 Ok(Matrix::quant_u8(&buffer)?)
             }
             Quant::NF4 => {
-                let shape = self.tensor_shape(&name).await?;
+                let shape = self.tensor_shape(&name)?;
                 let buffer = cache.checkout(shape, || context.tensor_init(shape));
                 self.load_in_place_matrix_f16(&buffer, &name).await?;
                 Ok(Matrix::quant_nf4(&buffer)?)
@@ -584,14 +584,14 @@ impl<R: Reader> Loader<R> {
                 self.load_matrix_f16_discount(name, discount).await?,
             )),
             Quant::Int8 => {
-                let shape = self.tensor_shape(&name).await?;
+                let shape = self.tensor_shape(&name)?;
                 let buffer = cache.checkout(shape, || context.tensor_init(shape));
                 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
                     .await?;
                 Ok(Matrix::quant_u8(&buffer)?)
             }
             Quant::NF4 => {
-                let shape = self.tensor_shape(&name).await?;
+                let shape = self.tensor_shape(&name)?;
                 let buffer = cache.checkout(shape, || context.tensor_init(shape));
                 self.load_in_place_matrix_f16_discount(&buffer, &name, discount)
                     .await?;
2 changes: 1 addition & 1 deletion src/model/mod.rs
@@ -248,7 +248,7 @@ impl<R: Reader> ModelBuilder<R> {
             token_chunk_size,
         } = self;

-        let info = Loader::info(&model).await?;
+        let info = Loader::info(&model)?;
         let loader = Loader {
             context: context.clone(),
             model,
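
End to end, a downstream build path changes only at the `info` call; a sketch, with builder details elided since they sit outside this diff:

    let model = SafeTensors::deserialize(&data)?;
    let info = Loader::info(&model)?; // the whole migration: drop `.await`
    // ...hand `model` to ModelBuilder and build as before...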