Allow specifying rescale layers when loading.
cryscan committed Jul 3, 2024
1 parent 2b60c4c commit 7428d28
Showing 5 changed files with 51 additions and 21 deletions.
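The new option slots into the existing builder chain. Below is a minimal usage sketch, assuming the usual setup: `ModelBuilder::new`, the `context`, `model`, and `quant` values, and the final build step are taken from surrounding crate code, not from this diff.

    // Hypothetical call site. `context` is an existing Context and `model`
    // a Reader over the model file, created as elsewhere in web-rwkv.
    let builder = ModelBuilder::new(&context, model)
        // New in this commit: halve the activation every 3 layers instead of
        // the previously hard-coded Model::RESCALE_LAYER = 6; 0 is clamped to 1.
        .rescale(3)
        // Pre-existing builder options keep working unchanged.
        .quant(quant);
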
Cargo.toml: 2 changes (1 addition & 1 deletion)
@@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.8.16"
version = "0.8.17"

[dependencies]
ahash = "0.8"
src/runtime/model.rs: 18 changes (13 additions & 5 deletions)
@@ -123,6 +123,7 @@ pub trait Build<T> {
pub struct ModelBuilder<R: Reader> {
pub context: Context,
pub model: R,
+pub rescale: usize,
pub lora: Vec<Lora<R>>,
pub quant: HashMap<usize, Quant>,
pub embed_device: EmbedDevice,
@@ -133,12 +134,24 @@ impl<R: Reader> ModelBuilder<R> {
Self {
context: context.clone(),
model,
+rescale: 6,
lora: vec![],
quant: Default::default(),
embed_device: Default::default(),
}
}

+/// Halve the layer and activation every `value` layers.
+pub fn rescale(mut self, value: usize) -> Self {
+self.rescale = value.max(1);
+self
+}
+
+pub fn lora(mut self, value: Lora<R>) -> Self {
+self.lora.push(value);
+self
+}
+
pub fn quant(mut self, value: HashMap<usize, Quant>) -> Self {
self.quant = value;
self
@@ -148,11 +161,6 @@ impl<R: Reader> ModelBuilder<R> {
self.embed_device = value;
self
}
-
-pub fn lora(mut self, value: Lora<R>) -> Self {
-self.lora.push(value);
-self
-}
}

pub trait ContextAutoLimits {
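A note on the clamp in the builder above: `rescale` is stored as `value.max(1)`, since the runtime later evaluates `(index + 1) % rescale` and `layer / rescale`, and a value of 0 would panic with a division by zero. A standalone sketch of that guard (the function name is illustrative, not part of the crate):

    /// Illustrative stand-in for the clamp in ModelBuilder::rescale.
    fn clamped_rescale(requested: usize) -> usize {
        requested.max(1)
    }

    fn main() {
        assert_eq!(clamped_rescale(0), 1); // 0 would otherwise divide by zero
        assert_eq!(clamped_rescale(6), 6); // the default stays untouched
    }
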
src/runtime/v4.rs: 12 changes (7 additions & 5 deletions)
@@ -31,12 +31,11 @@ use crate::{
pub struct Model {
pub context: Context,
pub info: ModelInfo,
+pub rescale: usize,
pub tensor: ModelTensor,
}

impl Model {
-pub const RESCALE_LAYER: usize = 6;
-
pub const LN_EPS: f32 = 1.0e-5;
pub const GN_EPS: f32 = 64.0e-5;
}
@@ -598,7 +597,7 @@ impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
let frame = frame.clone();
let layer = layer.clone();

-let op = build_layer(hooks, frame, layer, index, num_token)?;
+let op = build_layer(hooks, frame, layer, index, num_token, model.rescale)?;
ops.push(op);

if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 {
@@ -644,6 +643,7 @@ fn build_layer<F: Float>(
layer: Layer,
index: usize,
num_token: usize,
+rescale: usize,
) -> Result<TensorOp> {
let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
let Frame { state, buffer, .. } = &frame;
@@ -812,7 +812,7 @@ fn build_layer<F: Float>(
hook_op(Hook::PostFfn(index))?,
]);

-if (index + 1) % Model::RESCALE_LAYER == 0 {
+if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
}

@@ -857,6 +857,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let ModelBuilder {
context,
model,
+rescale,
lora,
quant,
embed_device,
@@ -900,7 +901,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let mut layers = vec![];
for layer in 0..info.num_layer {
let quant = quant.get(&layer).copied().unwrap_or_default();
-let discount = 2.0_f32.powi(-((layer / Model::RESCALE_LAYER) as i32));
+let discount = 2.0_f32.powi(-((layer / rescale) as i32));

let att_layer_norm = LayerNorm {
w: loader
@@ -978,6 +979,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
Model {
context,
info,
+rescale,
tensor,
}
};
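The v5 and v6 backends below receive the same change as v4 above: the per-model `rescale` field replaces the removed `Model::RESCALE_LAYER` constant in both the load-time weight discount and the run-time activation halving. A self-contained sketch of that arithmetic, with illustrative helper names, shows how the two sides stay consistent:

    // Load-time discount applied to layer weights: 2^-(layer / rescale),
    // exactly as in the builder loop of the diff.
    fn layer_discount(layer: usize, rescale: usize) -> f32 {
        2.0_f32.powi(-((layer / rescale) as i32))
    }

    // Cumulative run-time scaling: the activation is halved whenever
    // (index + 1) % rescale == 0, i.e. num_layer / rescale times in total.
    fn activation_scale(num_layer: usize, rescale: usize) -> f32 {
        0.5_f32.powi((num_layer / rescale) as i32)
    }

    fn main() {
        let rescale = 6; // the default, matching the removed constant
        for layer in [0, 5, 6, 12, 23] {
            println!("layer {layer}: weight discount {}", layer_discount(layer, rescale));
        }
        // A 24-layer model halves its activation 4 times with the default.
        assert_eq!(activation_scale(24, rescale), 0.0625);
    }
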
src/runtime/v5.rs: 20 changes (15 additions & 5 deletions)
@@ -31,12 +31,11 @@ use crate::{
pub struct Model {
pub context: Context,
pub info: ModelInfo,
+pub rescale: usize,
pub tensor: ModelTensor,
}

impl Model {
-pub const RESCALE_LAYER: usize = 6;
-
pub const LN_EPS: f32 = 1.0e-5;
pub const GN_EPS: f32 = 64.0e-5;
}
@@ -610,7 +609,15 @@ impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
let frame = frame.clone();
let layer = layer.clone();

-let op = build_layer(hooks, frame, layer, index, num_token, head_size)?;
+let op = build_layer(
+hooks,
+frame,
+layer,
+index,
+num_token,
+head_size,
+model.rescale,
+)?;
ops.push(op);

if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 {
@@ -657,6 +664,7 @@ fn build_layer<F: Float>(
index: usize,
num_token: usize,
head_size: usize,
+rescale: usize,
) -> Result<TensorOp> {
let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
let Frame { state, buffer, .. } = &frame;
@@ -884,7 +892,7 @@ fn build_layer<F: Float>(
hook_op(Hook::PostFfn(index))?,
]);

-if (index + 1) % Model::RESCALE_LAYER == 0 {
+if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
}

@@ -929,6 +937,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let ModelBuilder {
context,
model,
+rescale,
lora,
quant,
embed_device,
@@ -972,7 +981,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let mut layers = vec![];
for layer in 0..info.num_layer {
let quant = quant.get(&layer).copied().unwrap_or_default();
-let discount = 2.0_f32.powi(-((layer / Model::RESCALE_LAYER) as i32));
+let discount = 2.0_f32.powi(-((layer / rescale) as i32));

let att_layer_norm = LayerNorm {
w: loader
@@ -1075,6 +1084,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
Model {
context,
info,
+rescale,
tensor,
}
};
src/runtime/v6.rs: 20 changes (15 additions & 5 deletions)
@@ -31,12 +31,11 @@ use crate::{
pub struct Model {
pub context: Context,
pub info: ModelInfo,
+pub rescale: usize,
pub tensor: ModelTensor,
}

impl Model {
-pub const RESCALE_LAYER: usize = 6;
-
pub const LN_EPS: f32 = 1.0e-5;
pub const GN_EPS: f32 = 64.0e-5;
}
@@ -640,7 +639,15 @@ impl<F: Float> JobBuilder<InferJob> for ModelRuntime<F> {
let frame = frame.clone();
let layer = layer.clone();

-let op = build_layer(hooks, frame, layer, index, num_token, head_size)?;
+let op = build_layer(
+hooks,
+frame,
+layer,
+index,
+num_token,
+head_size,
+model.rescale,
+)?;
ops.push(op);

if (index + 1) % (info.num_layer / super::infer::NUM_LAYER_CHUNK) == 0 {
@@ -687,6 +694,7 @@ fn build_layer<F: Float>(
index: usize,
num_token: usize,
head_size: usize,
+rescale: usize,
) -> Result<TensorOp> {
let hook_op = |hook: Hook| hook_op(&hooks, &hook, &frame);
let Frame { state, buffer, .. } = &frame;
@@ -950,7 +958,7 @@ fn build_layer<F: Float>(
hook_op(Hook::PostFfn(index))?,
]);

-if (index + 1) % Model::RESCALE_LAYER == 0 {
+if (index + 1) % rescale == 0 {
ops.push(TensorOp::discount(&buffer.x, 0.5, 0.0)?);
}

@@ -995,6 +1003,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let ModelBuilder {
context,
model,
+rescale,
lora,
quant,
embed_device,
@@ -1038,7 +1047,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
let mut layers = vec![];
for layer in 0..info.num_layer {
let quant = quant.get(&layer).copied().unwrap_or_default();
-let discount = 2.0_f32.powi(-((layer / Model::RESCALE_LAYER) as i32));
+let discount = 2.0_f32.powi(-((layer / rescale) as i32));

let att_layer_norm = LayerNorm {
w: loader
@@ -1181,6 +1190,7 @@ impl<R: Reader> Build<Model> for ModelBuilder<R> {
Model {
context,
info,
+rescale,
tensor,
}
};
