Skip to content

Commit

Permalink
Use Welford's algorithm in LN and GN.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jan 15, 2024
1 parent 143bc30 commit ce7572f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.5.0"
version = "0.5.1"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
53 changes: 41 additions & 12 deletions src/shaders/layer_norm.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
@group(0) @binding(3) var<storage, read_write> x: array<vec4<f32>>; // (B, T, C)
#endif

var<workgroup> sum: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> sum_squared: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> mu: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> m2: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> count: array<vec4<u32>, BLOCK_SIZE>;

var<workgroup> mean: f32;
var<workgroup> deviation: f32;

Expand All @@ -23,8 +25,17 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {

// One step of a tree reduction over the per-thread Welford accumulators
// held in the workgroup arrays `mu` (running mean), `m2` (sum of squared
// deviations) and `count` (sample count), all vec4-lane-wise.
//
// Merges the partial statistics at `index + stride` into those at `index`
// using the Chan et al. parallel-variance combination:
//   total = n1 + n2
//   mu'   = (n1*mu1 + n2*mu2) / total
//   m2'   = m2_1 + m2_2 + delta^2 * n1*n2 / total,  delta = mu2 - mu1
//
// Call repeatedly with halving strides; every invocation in the workgroup
// must reach the trailing barrier, so the barrier sits outside the branch.
fn reduce_step(index: u32, stride: u32) {
    if index < stride {
        let mu_1 = mu[index];
        let mu_2 = mu[index + stride];
        let count_1 = count[index];
        let count_2 = count[index + stride];

        let delta = mu_2 - mu_1;
        let total = count_1 + count_2;
        count[index] = total;

        // `select` guards the division: lanes where both partitions are
        // empty (total == 0) yield 0 instead of NaN.
        mu[index] = select(vec4<f32>(0.0), (mu_1 * vec4<f32>(count_1) + mu_2 * vec4<f32>(count_2)) / vec4<f32>(total), total > vec4<u32>(0u));
        m2[index] = select(vec4<f32>(0.0), m2[index] + m2[index + stride] + delta * delta * vec4<f32>(count_1 * count_2) / vec4<f32>(total), total > vec4<u32>(0u));
    }
    workgroupBarrier();
}
Expand All @@ -44,8 +55,13 @@ fn layer_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
#else
let value = x[bb + i];
#endif
sum[index] += value;
sum_squared[index] += value * value;
let delta = value - mu[index];
let _count = count[index] + 1u;
let _mu = mu[index] + delta / vec4<f32>(_count);

count[index] = _count;
mu[index] = _mu;
m2[index] += delta * (value - _mu);
}
workgroupBarrier();

Expand All @@ -58,8 +74,12 @@ fn layer_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
reduce_step(index, 1u);

if index == 0u {
mean = dot(sum[0], vec4<f32>(1.0)) / f32(shape[0]);
deviation = inverseSqrt(dot(sum_squared[0], vec4<f32>(1.0)) / f32(shape[0]) - mean * mean + EPS);
let _count = vec4<f32>(count[0]);
mean = dot(mu[0], _count / f32(shape[0]));

let _delta = mu[0] - mean;
let _m2 = dot(m2[0], vec4<f32>(1.0)) + dot(_delta * _delta, _count);
deviation = inverseSqrt(_m2 / f32(shape[0]) + EPS);
}
workgroupBarrier();

Expand Down Expand Up @@ -90,8 +110,13 @@ fn group_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
#else
let value = x[th + i];
#endif
sum[index] += value;
sum_squared[index] += value * value;
let delta = value - mu[index];
let _count = count[index] + 1u;
let _mu = mu[index] + delta / vec4<f32>(_count);

count[index] = _count;
mu[index] = _mu;
m2[index] += delta * (value - _mu);
}
workgroupBarrier();

Expand All @@ -102,8 +127,12 @@ fn group_norm(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
reduce_step(index, 1u);

if index == 0u {
mean = dot(sum[0], vec4<f32>(1.0)) / f32(shape[0]);
deviation = inverseSqrt(dot(sum_squared[0], vec4<f32>(1.0)) / f32(shape[0]) - mean * mean + EPS);
let _count = vec4<f32>(count[0]);
mean = dot(mu[0], _count / f32(shape[0]));

let _delta = mu[0] - mean;
let _m2 = dot(m2[0], vec4<f32>(1.0)) + dot(_delta * _delta, _count);
deviation = inverseSqrt(_m2 / f32(shape[0]) + EPS);
}
workgroupBarrier();

Expand Down
23 changes: 16 additions & 7 deletions src/tensor/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2150,6 +2150,7 @@ mod tests {
const C: usize = 1000;
const T: usize = 3;
const B: usize = 2;
const EPS: f32 = 1.0e-5;

let x = [(); C * T * B]
.map(|_| 10.0 * (fastrand::f32() - 0.5))
Expand All @@ -2171,7 +2172,7 @@ mod tests {
let w_dev = TensorGpu::from_data(&context, shape, &w[..1000])?;
let b_dev = TensorGpu::from_data(&context, shape, &b[..1000])?;

let layer_norm = TensorOp::layer_norm(&w_dev, &b_dev, &x_dev, 1.0e-5)?;
let layer_norm = TensorOp::layer_norm(&w_dev, &b_dev, &x_dev, EPS)?;

let mut encoder = context.device.create_command_encoder(&Default::default());

Expand All @@ -2194,11 +2195,19 @@ mod tests {
{
let chunk = chunk.collect_vec();
let x = chunk.iter().map(|((x, _), _)| x).copied();
let sum: f32 = x.clone().sum();
let squared_sum: f32 = x.clone().map(|x| x.powi(2)).sum();

let mean = sum / C as f32;
let deviation = ((squared_sum / C as f32) - mean.powi(2)).sqrt();
// let sum: f32 = x.clone().sum();
// let squared_sum: f32 = x.clone().map(|x| x.powi(2)).sum();

// let mean = sum / C as f32;
// let deviation = ((squared_sum / C as f32) - mean.powi(2)).sqrt();
let (mean, m2, count) = x.fold((0.0f32, 0.0f32, 0u32), |(mean, m2, count), x| {
let count = count + 1;
let delta = x - mean;
let mean = mean + delta / count as f32;
let m2 = m2 + delta * (x - mean);
(mean, m2, count)
});
let deviation = (m2 / count as f32 + EPS).sqrt();

let mut x: Vec<_> = chunk
.into_iter()
Expand All @@ -2209,7 +2218,7 @@ mod tests {

for (index, (a, b)) in Iterator::zip(x_host.into_iter(), ans.into_iter()).enumerate() {
assert!(
is_approx_eps(a, b, 1.0e-3),
is_approx_eps(a, b, 0.01),
"Failed at index {index}, computed: {a} vs. answer: {b}"
);
}
Expand Down

0 comments on commit ce7572f

Please sign in to comment.