Skip to content

Commit

Permalink
Implement block-wise int8.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 11, 2024
1 parent f2fca20 commit 5f9dca1
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 720 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.6.5"
version = "0.6.6"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
6 changes: 3 additions & 3 deletions examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ fn restore_terminal(terminal: &mut Terminal<CrosstermBackend<std::io::Stdout>>)

async fn run(cli: Cli) -> Result<()> {
let tokenizer = load_tokenizer()?;
let model = cli.model.unwrap_or(
let model = cli.model.unwrap_or_else(|| {
std::fs::read_dir("assets/models")
.unwrap()
.filter_map(|x| x.ok())
.find(|x| x.path().extension().is_some_and(|x| x == "st"))
.unwrap()
.path(),
);
.path()
});

let file = File::open(model)?;
let map = unsafe { Mmap::map(&file)? };
Expand Down
6 changes: 3 additions & 3 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ fn load_prompt(path: Option<PathBuf>) -> Result<Prompt> {

async fn run(cli: Cli) -> Result<()> {
let tokenizer = load_tokenizer()?;
let model = cli.model.unwrap_or(
let model = cli.model.unwrap_or_else(|| {
std::fs::read_dir("assets/models")
.unwrap()
.filter_map(|x| x.ok())
.find(|x| x.path().extension().is_some_and(|x| x == "st"))
.unwrap()
.path(),
);
.path()
});
let prompt = load_prompt(cli.prompt)?;
let sampler = cli.sampler;

Expand Down
6 changes: 3 additions & 3 deletions examples/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ fn load_model<M: Model>(

async fn run(cli: Cli) -> Result<()> {
let tokenizer = load_tokenizer()?;
let model = cli.model.unwrap_or(
let model = cli.model.unwrap_or_else(|| {
std::fs::read_dir("assets/models")
.unwrap()
.filter_map(|x| x.ok())
.find(|x| x.path().extension().is_some_and(|x| x == "st"))
.unwrap()
.path(),
);
.path()
});

let file = File::open(model)?;
let map = unsafe { Mmap::map(&file)? };
Expand Down
6 changes: 3 additions & 3 deletions examples/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ fn load_model<M: Model>(

async fn run(cli: Cli) -> Result<()> {
let tokenizer = load_tokenizer()?;
let model = cli.model.unwrap_or(
let model = cli.model.unwrap_or_else(|| {
std::fs::read_dir("assets/models")
.unwrap()
.filter_map(|x| x.ok())
.find(|x| x.path().extension().is_some_and(|x| x == "st"))
.unwrap()
.path(),
);
.path()
});

let file = File::open(model)?;
let map = unsafe { Mmap::map(&file)? };
Expand Down
58 changes: 24 additions & 34 deletions src/shaders/matmul_mat_int8.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,17 @@ struct Input {
@group(0) @binding(1) var<uniform> vb: View; // [K, N, B]
@group(0) @binding(2) var<uniform> destination: View; // [M, N, B]

@group(0) @binding(3) var<storage, read> mx: array<vec4<f32>>; // (B, K)
@group(0) @binding(4) var<storage, read> rx: array<vec4<f32>>; // (B, K)
@group(0) @binding(5) var<storage, read> my: array<vec4<f32>>; // (B, M)
@group(0) @binding(6) var<storage, read> ry: array<vec4<f32>>; // (B, M)

@group(0) @binding(7) var<storage, read> xa: array<u32>; // (B, M, K)
@group(0) @binding(8) var<storage, read> xb: array<vec2<u32>>; // (B, N, K)
@group(0) @binding(3) var<storage, read> minmax: array<u32>;
@group(0) @binding(4) var<storage, read> xa: array<u32>; // (B, M, K)
@group(0) @binding(5) var<storage, read> xb: array<vec2<u32>>; // (B, N, K)
#ifdef OUT_FP16
@group(0) @binding(9) var<storage, read_write> output: array<vec2<u32>>; // (B, N, M)
@group(0) @binding(6) var<storage, read_write> output: array<vec2<u32>>; // (B, N, M)
#else
@group(0) @binding(9) var<storage, read_write> output: array<vec4<f32>>; // (B, N, M)
@group(0) @binding(6) var<storage, read_write> output: array<vec4<f32>>; // (B, N, M)
#endif

const TILE_SIZE: u32 = BLOCK_SIZE * 4u;

var<workgroup> smx: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> srx: array<vec4<f32>, BLOCK_SIZE>;
const INT8_BLOCK_STEP: u32 = INT8_BLOCK_SIZE / 4u;

var<workgroup> sa: array<array<u32, BLOCK_SIZE>, TILE_SIZE>;
var<workgroup> sb: array<array<vec2<u32>, BLOCK_SIZE>, TILE_SIZE>;
Expand All @@ -42,12 +36,17 @@ fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 {
return dot(vec3<u32>(batch, token, index) + offset, vec3<u32>(view.stride.y * stride, stride, 1u));
}

fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
return vec2<u32>(pack2x16float(x.xy), pack2x16float(x.zw));
}

fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
}

fn pack4x16float(x: vec4<f32>) -> vec2<u32> {
return vec2<u32>(pack2x16float(x.xy), pack2x16float(x.zw));
fn unpack_minmax(index: u32) -> vec2<f32> {
let i = index / INT8_BLOCK_STEP;
return unpack2x16float(minmax[i]);
}

fn squared_relu(x: vec4<f32>) -> vec4<f32> {
Expand All @@ -64,22 +63,8 @@ fn matmul(in: Input) {
let rb = vec2<u32>(vb.shape.x / 4u, vb.shape.y);
let stride = min(ra.x, rb.x);

let myy = my[in.uid.z * ra.y + in.uid.x];
let ryy = ry[in.uid.z * ra.y + in.uid.x];

var local_sum: mat4x4<f32>;
for (var k = 0u; k < stride; k += BLOCK_SIZE) {
// load 8x4 mx and rx
let i = in.tid.x;
let x = k + i;
if x < ra.x {
smx[i] = mx[in.uid.z * ra.x + x];
srx[i] = rx[in.uid.z * ra.x + x];
} else {
smx[i] = vec4<f32>(0.0);
srx[i] = vec4<f32>(0.0);
}

// load 8x4 rows from each of the matrix, each with 8x4 columns
for (var j = in.tid.y; j < TILE_SIZE; j += BLOCK_SIZE) {
let i = in.tid.x;
Expand All @@ -102,17 +87,22 @@ fn matmul(in: Input) {

// each thread multiplies and sums up 4x4 blocks along the reduced dimension
if all(u < vec2<u32>(ra.y, rb.y)) {
var i = compute_index(va, in.uid.z, u.x, k);
var b: array<vec2<f32>, 4>;
b[0] = unpack_minmax(i); i += stride;
b[1] = unpack_minmax(i); i += stride;
b[2] = unpack_minmax(i); i += stride;
b[3] = unpack_minmax(i);

for (var x = 0u; x < BLOCK_SIZE; x += 1u) {
if k + x >= stride {
break;
}
let mxx = smx[x];
let rxx = srx[x];
let aa = mat4x4<f32>(
fma(unpack4x8unorm(sa[t.x][x]), ryy[0] * rxx, myy[0] + mxx),
fma(unpack4x8unorm(sa[t.x + 1u][x]), ryy[1] * rxx, myy[1] + mxx),
fma(unpack4x8unorm(sa[t.x + 2u][x]), ryy[2] * rxx, myy[2] + mxx),
fma(unpack4x8unorm(sa[t.x + 3u][x]), ryy[3] * rxx, myy[3] + mxx),
fma(unpack4x8unorm(sa[t.x][x]), vec4<f32>(b[0][1] - b[0][0]), vec4<f32>(b[0][0])),
fma(unpack4x8unorm(sa[t.x + 1u][x]), vec4<f32>(b[1][1] - b[1][0]), vec4<f32>(b[1][0])),
fma(unpack4x8unorm(sa[t.x + 2u][x]), vec4<f32>(b[2][1] - b[2][0]), vec4<f32>(b[2][0])),
fma(unpack4x8unorm(sa[t.x + 3u][x]), vec4<f32>(b[3][1] - b[3][0]), vec4<f32>(b[3][0])),
);
let bb = mat4x4<f32>(
unpack4x16float(sb[t.y][x]),
Expand Down
41 changes: 21 additions & 20 deletions src/shaders/matmul_mat_nf4.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct Input {
#endif

const TILE_SIZE: u32 = BLOCK_SIZE * 4u;
const NF4_BLOCK_STEP: u32 = 2u * NF4_BLOCK_SIZE / 8u;
const NF4_BLOCK_STEP: u32 = NF4_BLOCK_SIZE / 8u;

var<workgroup> sa: array<array<u32, BLOCK_SIZE>, TILE_SIZE>;
var<workgroup> sb: array<array<vec4<u32>, BLOCK_SIZE>, TILE_SIZE>;
Expand All @@ -46,6 +46,11 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
}

fn unpack_absmax(index: u32) -> f32 {
let i = index / NF4_BLOCK_STEP; // 1 block of absmax: NF4_BLOCK_SIZE / 8u entries in matrix
return unpack2x16float(absmax[i >> 1u])[i & 1u];
}

fn unpack_matrix_0(v: u32) -> vec4<f32> {
let i = vec4<u32>(
(v & 0x0000000fu),
Expand Down Expand Up @@ -93,7 +98,6 @@ fn matmul(in: Input) {
if in.index == 0u {
q = quant;
}
// workgroupBarrier();

var local_sum: mat4x4<f32>;
for (var k = 0u; k < stride; k += BLOCK_SIZE) {
Expand All @@ -120,32 +124,29 @@ fn matmul(in: Input) {

// each thread multiplies and sums up 4x4 blocks along the reduced dimension
if all(u < vec2<u32>(ra.y, rb.y)) {
let i = compute_index(va, in.uid.z, u.x, k, 4u);
let j = (k / BLOCK_SIZE) % 2u;
let a = vec4<f32>(
unpack2x16float(absmax[i / NF4_BLOCK_STEP])[j],
unpack2x16float(absmax[(i + stride) / NF4_BLOCK_STEP])[j],
unpack2x16float(absmax[(i + 2u * stride) / NF4_BLOCK_STEP])[j],
unpack2x16float(absmax[(i + 3u * stride) / NF4_BLOCK_STEP])[j],
);
var i = compute_index(va, in.uid.z, u.x, k, 4u);
var a: vec4<f32>;
a[0] = unpack_absmax(i); i += stride;
a[1] = unpack_absmax(i); i += stride;
a[2] = unpack_absmax(i); i += stride;
a[3] = unpack_absmax(i);

for (var x = 0u; x < BLOCK_SIZE; x += 1u) {
if k + x >= stride {
break;
}

let ssa = vec4<u32>(
let la = vec4<u32>(
sa[t.x][x],
sa[t.x + 1u][x],
sa[t.x + 2u][x],
sa[t.x + 3u][x],
);

var aa = mat4x4<f32>(
a[0] * unpack_matrix_0(ssa[0]),
a[1] * unpack_matrix_0(ssa[1]),
a[2] * unpack_matrix_0(ssa[2]),
a[3] * unpack_matrix_0(ssa[3]),
a[0] * unpack_matrix_0(la[0]),
a[1] * unpack_matrix_0(la[1]),
a[2] * unpack_matrix_0(la[2]),
a[3] * unpack_matrix_0(la[3]),
);
var bb = mat4x4<f32>(
unpack4x16float(sb[t.y][x].xy),
Expand All @@ -156,10 +157,10 @@ fn matmul(in: Input) {
local_sum += transpose(aa) * bb;

aa = mat4x4<f32>(
a[0] * unpack_matrix_1(ssa[0]),
a[1] * unpack_matrix_1(ssa[1]),
a[2] * unpack_matrix_1(ssa[2]),
a[3] * unpack_matrix_1(ssa[3]),
a[0] * unpack_matrix_1(la[0]),
a[1] * unpack_matrix_1(la[1]),
a[2] * unpack_matrix_1(la[2]),
a[3] * unpack_matrix_1(la[3]),
);
bb = mat4x4<f32>(
unpack4x16float(sb[t.y][x].zw),
Expand Down
44 changes: 17 additions & 27 deletions src/shaders/matmul_vec_int8.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,17 @@ struct View {
@group(0) @binding(2) var<uniform> destination: View; // [R, T, B]

@group(0) @binding(3) var<storage, read> matrix: array<u32>; // (B, R, C)
@group(0) @binding(4) var<storage, read> mx: array<vec4<f32>>; // (B, C)
@group(0) @binding(5) var<storage, read> rx: array<vec4<f32>>; // (B, C)
@group(0) @binding(6) var<storage, read> my: array<vec4<f32>>; // (B, R)
@group(0) @binding(7) var<storage, read> ry: array<vec4<f32>>; // (B, R)
@group(0) @binding(4) var<storage, read> minmax: array<u32>;

#ifdef IN_FP16
@group(0) @binding(8) var<storage, read> input: array<vec2<u32>>; // (B, T, C)
#else
@group(0) @binding(8) var<storage, read> input: array<vec4<f32>>; // (B, T, C)
#endif
@group(0) @binding(5) var<storage, read> input: array<vec2<u32>>; // (B, T, C)
#ifdef OUT_FP16
@group(0) @binding(9) var<storage, read_write> output: array<vec2<u32>>; // (B, T, R)
@group(0) @binding(6) var<storage, read_write> output: array<vec2<u32>>; // (B, T, R)
#else
@group(0) @binding(9) var<storage, read_write> output: array<vec4<f32>>; // (B, T, R)
@group(0) @binding(6) var<storage, read_write> output: array<vec4<f32>>; // (B, T, R)
#endif

const INT8_BLOCK_STEP: u32 = INT8_BLOCK_SIZE / 4u;

var<workgroup> sketch: array<vec4<f32>, BLOCK_SIZE>;

fn compute_index(view: View, batch: u32, token: u32, index: u32) -> u32 {
Expand All @@ -41,6 +36,11 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
return vec4<f32>(unpack2x16float(x.x), unpack2x16float(x.y));
}

fn unpack_minmax(index: u32) -> vec2<f32> {
let i = index / INT8_BLOCK_STEP;
return unpack2x16float(minmax[i]);
}

fn squared_relu(x: vec4<f32>) -> vec4<f32> {
let p = max(x, vec4<f32>(0.0));
return p * p;
Expand All @@ -64,11 +64,6 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let bb = compute_index(source, batch, token, 0u);
let cb = batch * shape.y * stride + channel * 4u * stride;

// let myc = unpack4x16float(my[channel]);
// let ryc = unpack4x16float(ry[channel]);
let myc = my[batch * shape.y + channel];
let ryc = ry[batch * shape.y + channel];

var local_sum = vec4<f32>(0.0);
for (var i = index; i < stride; i += BLOCK_SIZE) {
let bti = bb + i;
Expand All @@ -81,18 +76,13 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
let x = input[bti];
#endif

// let mxi = unpack4x16float(mx[i]);
// let rxi = unpack4x16float(rx[i]);
let mxi = mx[batch * stride + i];
let rxi = rx[batch * stride + i];

// read 4 rows from the matrix, each with 4 unpacked floats, forming a 4x4 sub-block
var m: mat4x4<f32>;

m[0] = fma(unpack4x8unorm(matrix[ci]), ryc[0] * rxi, myc[0] + mxi); ci += stride;
m[1] = fma(unpack4x8unorm(matrix[ci]), ryc[1] * rxi, myc[1] + mxi); ci += stride;
m[2] = fma(unpack4x8unorm(matrix[ci]), ryc[2] * rxi, myc[2] + mxi); ci += stride;
m[3] = fma(unpack4x8unorm(matrix[ci]), ryc[3] * rxi, myc[3] + mxi);
var b: vec2<f32>;
b = unpack_minmax(ci); m[0] = fma(unpack4x8unorm(matrix[ci]), vec4<f32>(b[1] - b[0]), vec4<f32>(b[0])); ci += stride;
b = unpack_minmax(ci); m[1] = fma(unpack4x8unorm(matrix[ci]), vec4<f32>(b[1] - b[0]), vec4<f32>(b[0])); ci += stride;
b = unpack_minmax(ci); m[2] = fma(unpack4x8unorm(matrix[ci]), vec4<f32>(b[1] - b[0]), vec4<f32>(b[0])); ci += stride;
b = unpack_minmax(ci); m[3] = fma(unpack4x8unorm(matrix[ci]), vec4<f32>(b[1] - b[0]), vec4<f32>(b[0]));
local_sum += transpose(m) * x;
}
sketch[index] = local_sum;
Expand Down Expand Up @@ -123,4 +113,4 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
output[btc] = out;
#endif
}
}
}
5 changes: 3 additions & 2 deletions src/shaders/matmul_vec_nf4.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct View {
@group(0) @binding(7) var<storage, read_write> output: array<vec4<f32>>; // (B, T, R)
#endif

const NF4_BLOCK_STEP: u32 = NF4_BLOCK_SIZE / 8u;

var<workgroup> sketch: array<vec4<f32>, BLOCK_SIZE>;
var<workgroup> q: array<vec4<f32>, 4u>;

Expand All @@ -37,7 +39,7 @@ fn unpack4x16float(x: vec2<u32>) -> vec4<f32> {
}

fn unpack_absmax(index: u32) -> f32 {
let i = index / (NF4_BLOCK_SIZE / 8u); // 1 block of absmax: NF4_BLOCK_SIZE / 8u entries in matrix
let i = index / NF4_BLOCK_STEP; // 1 block of absmax: NF4_BLOCK_SIZE / 8u entries in matrix
return unpack2x16float(absmax[i >> 1u])[i & 1u];
}

Expand Down Expand Up @@ -97,7 +99,6 @@ fn matmul(@builtin(global_invocation_id) invocation_id: vec3<u32>) {
if index == 0u {
q = quant;
}
workgroupBarrier();

var local_sum = vec4<f32>(0.0);
for (var i = index; i < stride; i += BLOCK_SIZE) {
Expand Down
Loading

0 comments on commit 5f9dca1

Please sign in to comment.