Commit

clean up from X Y Z workgroups implementation for CreateKernel
austinvhuang committed Jun 16, 2024
1 parent ac7743b commit e1617bb
Showing 2 changed files with 11 additions and 26 deletions.
2 changes: 0 additions & 2 deletions examples/hello_world/run.cpp
@@ -1,6 +1,5 @@
#include "gpu.h"
#include "nn/shaders.h"
#include "utils/logging.h"
#include <array>
#include <cstdio>

@@ -24,7 +23,6 @@ fn main(
)";

int main(int argc, char **argv) {
log(kDefLog, kInfo, "Hello, gpu.cpp!");
GPUContext ctx = CreateContext();
fprintf(stdout, "\nHello, gpu.cpp\n\n");
static constexpr size_t N = 3072;
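After this change the example prints with plain fprintf instead of the logging utilities. For orientation, a rough sketch of the remaining flow; CreateTensor, kf32, kShader, and the dispatch step are hypothetical stand-ins for illustration, not part of this commit:

#include "gpu.h"
#include <array>
#include <cstdio>

int main(int argc, char **argv) {
  GPUContext ctx = CreateContext();   // shown in this diff
  fprintf(stdout, "\nHello, gpu.cpp\n\n");
  static constexpr size_t N = 3072;   // shown in this diff
  // The rest is a hypothetical sketch, not part of the commit:
  // GPUTensor input  = CreateTensor(ctx, {N}, kf32);   // assumed helper
  // GPUTensor output = CreateTensor(ctx, {N}, kf32);   // assumed helper
  // Kernel op = CreateKernel(ctx, ShaderCode{kShader}, &input, 1, output);
  // ... dispatch op and read back output ...
  return 0;
}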
35 changes: 11 additions & 24 deletions gpu.h
@@ -84,7 +84,7 @@ template <std::size_t N> GPUTensors(std::array<GPUTensor, N>) -> GPUTensors<N>;
 template <typename... Args> GPUTensors(Args...) -> GPUTensors<sizeof...(Args)>;

 struct TensorPool {
-  TensorPool(GPUContext *ctx) : ctx(ctx), data() {};
+  TensorPool(GPUContext *ctx) : ctx(ctx), data(){};
   GPUContext *ctx;
   std::unordered_map<WGPUBuffer, GPUTensor> data;
   ~TensorPool();
@@ -140,7 +140,7 @@ struct Kernel {
   size_t outputSize;
   size_t numBuffers;
   size_t numInputs;
-  WGPUCommandBuffer commandBuffer; // destroyed upon submission
+  WGPUCommandBuffer commandBuffer;     // destroyed upon submission
   WGPUComputePipeline computePipeline; // persists between submission
   WGPUBuffer readbackBuffer;
   CallbackDataDyn callbackData;
@@ -175,7 +175,8 @@ struct MultiKernel {
   // paramSizes = 0 means no params buffer
   std::unique_ptr<size_t[]> numInputs; // length = numShaders
   WGPUCommandBuffer commandBuffer;     // All kernels in the pipeline
-  WGPUComputePipeline computePipeline; // TODO(avh): decide how to handle compute pipelines for multikernel
+  WGPUComputePipeline computePipeline; // TODO(avh): decide how to handle
+                                       // compute pipelines for multikernel
   WGPUBuffer readbackBuffer;           // Readback buffer for the final output buffer
   CallbackDataDyn callbackData;
   std::promise<void> promise;
@@ -342,17 +343,6 @@ inline void check(bool condition, const char *message,
   }
 }

-void showDeviceInfo(WGPUAdapter &adapter) {
-  WGPUAdapterProperties properties;
-  wgpuAdapterGetProperties(adapter, &properties);
-  printf("Device Name: %s\n", properties.name);
-  printf("Vendor ID: %u\n", properties.vendorID);
-  printf("Device ID: %u\n", properties.deviceID);
-  WGPULimits limits;
-  WGPUSupportedLimits supportedLimits;
-  wgpuAdapterGetLimits(adapter, &supportedLimits);
-}
-
 GPUContext CreateContext(bool quietLogging = true,
                          const WGPUInstanceDescriptor &desc = {},
                          const WGPURequestAdapterOptions &adapterOpts = {},
@@ -525,13 +515,7 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensor *inputs, size_t numInputs,
                     const GPUTensor &output, const void *params,
                     size_t paramsSize, Shape nThreads) {
-  if (nThreads.rank < 3) {
-    const size_t rank = nThreads.rank;
-    nThreads.rank = 3;
-    for (size_t i = rank; i < 3; i++) {
-      nThreads[i] = 1;
-    }
-  }
+  assert(nThreads.rank == 3);
   WGPUDevice device = ctx.device;
   WGPUQueue queue = ctx.queue;
   Kernel op;
@@ -727,16 +711,19 @@ Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
   return op;
 }

+// default nThreads to output.shape
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensor *inputs, size_t numInputs,
                     const GPUTensor &output, const void *params = nullptr,
                     size_t paramsSize = 0) {
+  Shape nThreads = output.shape;
+  nThreads.rank = 3;
+  for (size_t i = output.shape.rank; i < 3; i++) {
+    nThreads[i] = 1;
+  }
   return CreateKernel(ctx, shader, inputs, numInputs, output, params,
-                      paramsSize, output.shape);
+                      paramsSize, nThreads);
 }

 // comptime template for paramtype - is this needed?
 template <typename ParamsType = NoParam>
 Kernel CreateKernel(GPUContext &ctx, const ShaderCode &shader,
                     const GPUTensor *inputs, size_t numInputs,
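Net effect of the gpu.h changes: the workgroup-dispatching CreateKernel now asserts that nThreads already has rank 3, and the convenience overload pads output.shape with trailing 1s instead of padding inside the low-level function. A minimal self-contained sketch of that padding rule, using a stand-in Shape type (an assumption for illustration) rather than the library's own:

#include <cassert>
#include <cstddef>
#include <cstdio>

// Stand-in for the library's Shape type; an assumption for illustration only.
struct Shape {
  size_t rank = 0;
  size_t data[3] = {1, 1, 1};
  size_t &operator[](size_t i) { return data[i]; }
};

int main() {
  Shape output{1, {3072}};      // a rank-1 output shape, as in hello_world
  Shape nThreads = output;      // same padding the new overload performs
  nThreads.rank = 3;
  for (size_t i = output.rank; i < 3; i++) {
    nThreads[i] = 1;
  }
  assert(nThreads.rank == 3);   // the precondition CreateKernel now enforces
  printf("nThreads = {%zu, %zu, %zu}\n", nThreads[0], nThreads[1], nThreads[2]);
  return 0;
}

So a rank-1 output of size 3072 dispatches as an X/Y/Z grid of {3072, 1, 1}, which is the X Y Z workgroups cleanup the commit message describes.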
