migrate from google style casing to webgpu camelcasing, qkv projection (wip)
austinvhuang committed Jun 28, 2024
1 parent fae185f commit 65ae859
Showing 6 changed files with 79 additions and 82 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -46,8 +46,8 @@ invoked from the host using this library.
#include <cstdio>
#include <future>
using namespace gpu; // CreateContext, CreateTensor, CreateKernel,
// CreateShader, DispatchKernel, Wait, ToCPU
using namespace gpu; // createContext, createTensor, createKernel,
// createShader, dispatchKernel, wait, toCPU
// Tensor, Kernel, Context, Shape, kf32
static const char *kGelu = R"(
@@ -69,22 +69,22 @@ fn main(
int main(int argc, char **argv) {
printf("\nHello gpu.cpp!\n\n");
Context ctx = CreateContext();
Context ctx = createContext();
static constexpr size_t N = 10000;
std::array<float, N> inputArr, outputArr;
for (int i = 0; i < N; ++i) {
inputArr[i] = static_cast<float>(i) / 10.0; // dummy input data
}
Tensor input = CreateTensor(ctx, Shape{N}, kf32, inputArr.data());
Tensor output = CreateTensor(ctx, Shape{N}, kf32);
Tensor input = createTensor(ctx, Shape{N}, kf32, inputArr.data());
Tensor output = createTensor(ctx, Shape{N}, kf32);
std::promise<void> promise;
std::future<void> future = promise.get_future();
Kernel op = CreateKernel(ctx, CreateShader(kGelu, 256, kf32),
Kernel op = createKernel(ctx, createShader(kGelu, 256, kf32),
TensorList{input, output},
/* nthreads */ {N, 1, 1});
DispatchKernel(ctx, op, promise);
Wait(ctx, future);
ToCPU(ctx, output, outputArr.data(), sizeof(outputArr));
dispatchKernel(ctx, op, promise);
wait(ctx, future);
toCPU(ctx, output, outputArr.data(), sizeof(outputArr));
for (int i = 0; i < 16; ++i) {
printf(" gelu(%.2f) = %.2f\n", inputArr[i], outputArr[i]);
}
@@ -99,11 +99,11 @@ in a separate file to be loaded at runtime. The WGSL code is compiled and runs
on the GPU.

The CPU code in `main()` sets up the host coordination for the GPU computation.
The ahead-of-time resource acquisition functions are prefaced with `Create`,
such as `CreateContext`, `CreateTensor`, `CreateKernel`, `CreateShader`.
The ahead-of-time resource acquisition functions are prefaced with `create`,
such as `createContext`, `createTensor`, `createKernel`, `createShader`.

The dispatch occurs asynchronously via the `DispatchKernel` invocation. `Wait`
blocks until the GPU computation is complete and `ToCPU` moves data from the
The dispatch occurs asynchronously via the `dispatchKernel` invocation. `wait`
blocks until the GPU computation is complete and `toCPU` moves data from the
GPU to CPU. This example is available in `examples/hello_world/run.cpp`.
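
For readers who want to sanity-check the printed values, here is a minimal CPU sketch of the tanh-approximation GELU the kernel is expected to compute. The function name `geluRef` and the standalone program are illustrative, not part of the library; the formula mirrors the reference CPU implementation used elsewhere in this repository.

```cpp
// CPU sketch of the tanh-approximation GELU (illustration only; the GPU
// version is the WGSL in the kGelu string above).
#include <cmath>
#include <cstdio>

float geluRef(float x) {
  const float kPi = 3.14159265358979323846f;
  const float s = std::sqrt(2.0f / kPi); // GELU scaling factor sqrt(2/pi)
  return 0.5f * x * (1.0f + std::tanh(s * (x + 0.044715f * x * x * x)));
}

int main() {
  for (int i = 0; i < 4; ++i) {
    float x = static_cast<float>(i) / 10.0f; // same dummy inputs as the example
    std::printf("  gelu(%.2f) = %.2f\n", x, geluRef(x));
  }
  return 0;
}
```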

## Quick Start
4 changes: 4 additions & 0 deletions experimental/transformer/reference_impls.h
@@ -33,6 +33,8 @@
#ifndef REFERENCE_IMPLS_H
#define REFERENCE_IMPLS_H

namespace ref {

#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

void gelu_forward_cpu(float* out, const float* inp, int N) {
@@ -230,4 +232,6 @@ void attention_forward_cpu(float* out, float* preatt, float* att,
}
}

}

#endif
103 changes: 47 additions & 56 deletions experimental/transformer/run.cpp
@@ -3,6 +3,8 @@
#include "utils/logging.h"
#include <array>

#include "reference_impls.h"

using namespace gpu;

static const char *kShaderGelu = R"(
@@ -43,51 +45,20 @@ static const char *kShaderMatmul0 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<f32>;
@group(0) @binding(1) var<storage, read_write> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1)
fn matmul(
@compute @workgroup_size({{workgroupSize}})
fn main(
@builtin(global_invocation_id) global_id : vec3<u32>) {
// row and column of C
let row = global_id.y;
let row = global_id.y; // row and column of C
let col = global_id.x;
for (var k = 0u; k < {{K}}; k = k + 1u) {
// B is stored as B^T, effectively column-major
C[row * {{N}} + col] += A[row * {{K}} + k] * B[k + col * {{N}}];
}
}
");
static const char *kShaderMatMul = R"(
@group(0) @binding(0) var<storage, read_write> A: array<f32>;
@group(0) @binding(1) var<storage, read_write> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
var<workgroup> tileA: array<f32, workgroupSizeY * workgroupSizeX>;
var<workgroup> tileB: array<f32, workgroupSizeY * workgroupSizeX>;
@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1)
fn matmul(
@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>,
@builtin(workgroup_id) workgroup_id : vec3<u32>
) {
let row = global_id.x;
let col = global_id.y;
if (row >= {{M}} || col >= {{N}}) {
return;
}
var result: f32 = 0.0;
for (var i = 0u; i < {{K}}; i = i + workgroupSizeX) {
// Load tiles into shared memory
tileA[local_id.y][local_id.x] = A[row][i + local_id.x];
tileB[local_id.y][local_id.x] = B[i + local_id.y][col];
// Synchronize to make sure the tile is loaded
workgroupBarrier();
// Perform partial dot product for the current tile
for (var k = 0u; k < workgroupSizeX; k = k + 1u) {
result = result + tileA[local_id.y][k] * tileB[k][local_id.x];
}
// Synchronize before loading the next tile
workgroupBarrier();
var total: f32 = 0.0;
for (var k = 0u; k < {{K}}; k = k + 1u) {
// B is stored as B^T, effectively column-major
total += A[row * {{K}} + k] * B[col * {{N}} + k];
}
C[row][col] = result;
C[row * {{N}} + col] = total;
}
)";
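
To make the "B is stored as B^T" comment concrete, here is a plain CPU sketch of a matmul where B is supplied transposed (Bᵀ stored row-major as N×K), so the inner loop walks contiguous memory in both operands. This is a layout illustration under that assumption, not a drop-in equivalent of the WGSL kernel above, and `matmulTransposedB` is a name used only for this sketch.

```cpp
// CPU sketch: C (MxN) = A (MxK) * B (KxN), with B supplied transposed as
// bT (NxK, row-major). Layout illustration only.
#include <cstddef>

void matmulTransposedB(float *c, const float *a, const float *bT,
                       size_t M, size_t K, size_t N) {
  for (size_t row = 0; row < M; ++row) {
    for (size_t col = 0; col < N; ++col) {
      float total = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        total += a[row * K + k] * bT[col * K + k]; // bT[col][k] == B[k][col]
      }
      c[row * N + col] = total;
    }
  }
}
```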

@@ -128,7 +99,10 @@ void initTransformer(Context &ctx, size_t modelDim, size_t qkvDim,
// Initialize values
std::unique_ptr<float[]> qkvInit(new float[modelDim * 3 * qkvDim]);
randn(qkvInit.get(), size(transformer.qkv.shape), gen);
printf("%s", show<float>(qkvInit.get(), transformer.qkv.shape[0], transformer.qkv.shape[1], "QKV Weights").c_str());
LOG(kDefLog, kInfo, "%s",
show<float>(qkvInit.get(), transformer.qkv.shape[0],
transformer.qkv.shape[1], "QKV Weights")
.c_str());
toGPU(ctx, qkvInit.get(), transformer.qkv);

activations = {
@@ -140,10 +114,10 @@ };
};
}

inline ShaderCode createMatmul(const char *shaderTemplate,
const size_t M, const size_t K, const size_t N,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
inline ShaderCode createMatmul(const char *shaderTemplate, const size_t M,
const size_t K, const size_t N,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
std::string codeString(shaderTemplate);
ReplaceAll(codeString, "{{workgroupSize}}", toString(workgroupSize));
ReplaceAll(codeString, "{{precision}}", toString(precision));
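
The `{{...}}` placeholders are filled in by plain string substitution before the shader source is compiled. Below is a minimal sketch of what such a substitution helper does; the standalone `replaceAll` here is hypothetical, while `createMatmul` calls the library's own `ReplaceAll`.

```cpp
// Sketch of {{placeholder}} substitution for shader templates
// (hypothetical standalone helper; the library provides ReplaceAll).
#include <string>

void replaceAll(std::string &s, const std::string &from, const std::string &to) {
  std::string::size_type pos = 0;
  while ((pos = s.find(from, pos)) != std::string::npos) {
    s.replace(pos, from.size(), to);
    pos += to.size(); // resume after the inserted text to avoid re-matching
  }
}

// Usage, mirroring createMatmul above:
//   std::string code(kShaderMatmul0);
//   replaceAll(code, "{{K}}", std::to_string(K));
//   replaceAll(code, "{{N}}", std::to_string(N));
```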
@@ -156,34 +130,51 @@ inline ShaderCode createMatmul(const char *shaderTemplate,
int main() {
printf("\033[2J\033[1;1H");
Context ctx = createContext();
// static constexpr N = 3072;
static constexpr size_t N = 128;
static constexpr size_t seqLen = 24;
static constexpr size_t batchSize = 1;
static constexpr size_t modelDim = 3072;
static constexpr size_t modelDim = 2; // 3072;
static constexpr size_t hiddenWidth = modelDim * 2;
static constexpr size_t qkvDim = 256;
static constexpr size_t qkvDim = 1; //256;
std::mt19937 gen(314);

Transformer transformer;
Activations activations;
KVCache kvcache;
printf("Initializing transformer, allocating GPU buffers ...\n");
LOG(kDefLog, kInfo, "Initializing transformer, allocating GPU buffers ...\n");
initTransformer(ctx, modelDim, qkvDim, batchSize, seqLen, hiddenWidth,
transformer, activations, kvcache);

std::array<float, modelDim> inputArr;
std::array<float, modelDim * 3 * qkvDim> weightsArr;
randn(inputArr, gen);
randn(weightsArr, gen);
LOG(kDefLog, kInfo, "%s",
show<float>(inputArr.data(), 1, modelDim, "Input").c_str());
Tensor input = createTensor(ctx, Shape{modelDim}, kf32, inputArr.data());
Tensor output = createTensor(ctx, Shape{3 * qkvDim}, kf32);

ShaderCode matmul = createMatmul(kShaderMatmul0, modelDim, 3 * qkvDim, modelDim);





printf("Done\n");
ShaderCode matmul = createMatmul(kShaderMatmul0, 1, modelDim, 3 * qkvDim);
Kernel qkv =
createKernel(ctx, matmul, TensorList{transformer.qkv, input, output},
/*nthreads*/ {modelDim, 1, 1});
std::promise<void> promise;
std::future<void> future = promise.get_future();
dispatchKernel(ctx, qkv, promise);
wait(ctx, future);
std::array<float, 3 * qkvDim> outputArr;
toCPU(ctx, output, outputArr.data(), sizeof(outputArr));
LOG(kDefLog, kInfo, "Output: %s",
show<float>(outputArr.data(), 1, 3 * qkvDim, "QKV Output").c_str());

std::array<float, 3 * qkvDim> outputRefArr;
ref::matmul_forward_cpu(
outputRefArr.data(), inputArr.data(), weightsArr.data(), NULL,
/* batch */ 1, /* T */ 1, /* C */ modelDim, /* OC */ 3 * qkvDim);
LOG(kDefLog, kInfo, "Reference Output: %s",
show<float>(outputRefArr.data(), 1, 3 * qkvDim, "QKV Output (Reference)")
.c_str());

LOG(kDefLog, kInfo, isclose(outputArr.data(), outputRefArr.data(), 3 * qkvDim) ? "PASS" : "FAIL");

LOG(kDefLog, kInfo, "Done");
}
10 changes: 0 additions & 10 deletions experimental/transformer/test_kernels.cpp
@@ -12,16 +12,6 @@

using namespace gpu;

bool isclose(float *a, float *b, size_t n, float tol = 1e-3) {
for (size_t i = 0; i < n; i++) {
if (std::abs(a[i] - b[i]) > tol || std::isnan(a[i]) || std::isnan(b[i])) {
LOG(kDefLog, kInfo, "Mismatch at index %d: %f != %f", i, a[i], b[i]);
return false;
}
}
return true;
}

void testResidual(Context &ctx) {
constexpr size_t N = 200000;
constexpr size_t workgroupSize = 256;
8 changes: 5 additions & 3 deletions gpu.h
@@ -155,7 +155,9 @@ struct ShaderCode {
: data(data), workgroupSize(workgroupSize), precision(precision) {}
std::string data;
Shape workgroupSize;
NumType precision;
NumType precision = kf32;
std::string label = "shader";
std::string entryPoint = "main";
};

/**
@@ -228,7 +230,7 @@ struct Context {
TensorPool pool = TensorPool(this);
KernelPool kernelPool = KernelPool(this);
~Context() {
LOG(kDefLog, kTrace, "Destroying context");
LOG(kDefLog, kInfo, "Destroying context");
if (queue) {
wgpuQueueRelease(queue);
wgpuInstanceProcessEvents(instance);
@@ -252,7 +254,7 @@ } else {
} else {
LOG(kDefLog, kWarn, "Instance is null");
}
LOG(kDefLog, kTrace, "Destroyed context");
LOG(kDefLog, kInfo, "Destroyed context");
}
};

10 changes: 10 additions & 0 deletions utils/array_utils.h
@@ -151,6 +151,16 @@ inline void flip(float* a, size_t R, size_t C, bool horizontal = true) {
}
}

bool isclose(float *a, float *b, size_t n, float tol = 1e-3) {
for (size_t i = 0; i < n; i++) {
if (std::abs(a[i] - b[i]) > tol || std::isnan(a[i]) || std::isnan(b[i])) {
LOG(kDefLog, kInfo, "Mismatch at index %d: %f != %f", i, a[i], b[i]);
return false;
}
}
return true;
}

} // namespace gpu

#endif
