Skip to content

Commit

Permalink
FP16 extension registration
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 11, 2023
1 parent 24f0893 commit 0d34fec
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
11 changes: 9 additions & 2 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export class WebGpuBackend {
}

this.env = env;
const requiredFeatures = [];
const deviceDescriptor: GPUDeviceDescriptor = {
requiredLimits: {
maxComputeWorkgroupStorageSize: adapter.limits.maxComputeWorkgroupStorageSize,
Expand All @@ -121,14 +122,20 @@ export class WebGpuBackend {
maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY,
maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ,
},
requiredFeatures: [],
};
// WebGPU Spec: Timestamp Queries Inside Passes
// https://github.com/gpuweb/gpuweb/blob/main/proposals/timestamp-query-inside-passes.md
if (adapter.features.has('timestamp-query-inside-passes')) {
this.supportTimestampQuery = true;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
deviceDescriptor.requiredFeatures = ['timestamp-query-inside-passes' as any];
requiredFeatures.push('timestamp-query-inside-passes');
}
if (adapter.features.has('shader-f16')) {
requiredFeatures.push('shader-f16');
}
//
// eslint-disable-next-line @typescript-eslint/no-explicit-any
deviceDescriptor.requiredFeatures = requiredFeatures as any;

this.device = await adapter.requestDevice(deviceDescriptor);
this.gpuDataManager = createGpuDataManager(this);
Expand Down
5 changes: 2 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,8 @@ export interface IndicesHelper {
const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
// return type is [ storage type, runtime type ] or a single string for both
switch (type) {
// TODO: enable after "shader-f16" WSGL extension release
// case DataType.float16:
// return components > 1 ? `vec${components}<f16>` : 'f16';
case DataType.float16:
return components > 1 ? `vec${components}<f16>` : 'f16';
case DataType.float:
return components > 1 ? `vec${components}<f32>` : 'f32';
case DataType.int32:
Expand Down
7 changes: 5 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ export class ProgramManager {
}
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
const device = this.backend.device;

const extensions: string[] = [];
if (this.backend.device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const shaderHelper = createShaderHelper(normalizedDispatchGroupSize);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${shaderHelper.additionalImplementations}\n${userCode}`;
const code = `${extensions.join('\n')}${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({code, label: programInfo.name});
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);

Expand Down

0 comments on commit 0d34fec

Please sign in to comment.