Skip to content

Commit

Permalink
[JS/WebGPU] Creating devices with subgroup features enabled if possib…
Browse files Browse the repository at this point in the history
…le (#21833)

This CL make WebGPU backend support subgroup features and thus allow
using subgroup optimizations in the future.

### Description
With this CL WebGPU backends will create devices with subgroups and
subgroups-f16 features (both are under origin trial in Chrome) or
chromium-experimental-subgroups feature enabled whenever available.

### Motivation and Context
This CL would allow WebGPU operator shaders to use subgroup
optimizations in the future, and might get some significant speedup with
these optimization.
  • Loading branch information
jiangzhaoming authored Nov 7, 2024
1 parent 3b7a6eb commit 6a295eb
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 26 deletions.
40 changes: 34 additions & 6 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { ProgramManager } from './webgpu/program-manager';
import {
AdapterInfo,
ComputeContext,
DeviceInfo,
GpuArchitecture,
GpuData,
GpuVendor,
Expand Down Expand Up @@ -134,13 +135,34 @@ class AdapterInfoImpl implements AdapterInfo {
}
}

class DeviceInfoImpl implements DeviceInfo {
readonly subgroupsSupported: boolean;
readonly subgroupsF16Supported: boolean;
readonly subgroupSizeRange?: readonly [number, number];

constructor(device: GPUDevice) {
this.subgroupsSupported = device.features.has('subgroups' as GPUFeatureName);
this.subgroupsF16Supported = device.features.has('subgroups' as GPUFeatureName);
// Currently subgroups feature is still experimental and size attributes are not in the WebGPU IDL, so we have to
// workaround the IDL type checks.
// TODO: clean this after subgroups feature is settled in IDL.
const deviceSubgroupsLimits = device.limits as { minSubgroupSize?: number; maxSubgroupSize?: number };
if (!this.subgroupsSupported || !deviceSubgroupsLimits.minSubgroupSize || !deviceSubgroupsLimits.maxSubgroupSize) {
this.subgroupSizeRange = undefined;
} else {
this.subgroupSizeRange = [deviceSubgroupsLimits.minSubgroupSize, deviceSubgroupsLimits.maxSubgroupSize];
}
}
}

/**
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
* the first parameter so that it is stored for future use.
*/
export class WebGpuBackend {
adapterInfo: AdapterInfoImpl;
device: GPUDevice;
deviceInfo: DeviceInfoImpl;
/**
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
*/
Expand Down Expand Up @@ -243,16 +265,22 @@ export class WebGpuBackend {
requiredFeatures,
};

if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
} else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
// Try requiring WebGPU features
const requireFeatureIfAvailable = (feature: GPUFeatureName) =>
adapter.features.has(feature) && requiredFeatures.push(feature) && true;
// Try chromium-experimental-timestamp-query-inside-passes and fallback to timestamp-query
if (!requireFeatureIfAvailable('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName)) {
requireFeatureIfAvailable('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
requiredFeatures.push('shader-f16');
requireFeatureIfAvailable('shader-f16');
// Try subgroups
if (requireFeatureIfAvailable('subgroups' as GPUFeatureName)) {
// If subgroups feature is available, also try subgroups-f16
requireFeatureIfAvailable('subgroups-f16' as GPUFeatureName);
}

this.device = await adapter.requestDevice(deviceDescriptor);
this.deviceInfo = new DeviceInfoImpl(this.device);
this.adapterInfo = new AdapterInfoImpl(adapter.info || (await adapter.requestAdapterInfo()));
this.gpuDataManager = createGpuDataManager(this);
this.programManager = new ProgramManager(this);
Expand Down
22 changes: 9 additions & 13 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
import { TensorView } from './tensor-view';
import { ShapeUtil } from './util';
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
import {
AdapterInfo,
ComputeContext,
ComputeContextInputsOutputsMapping,
DeviceInfo,
ProgramInfo,
} from './webgpu/types';
import { WebNNBackend } from './backend-webnn';

/* eslint-disable no-bitwise */
Expand Down Expand Up @@ -70,6 +76,7 @@ class TensorViewImpl implements TensorView {

class ComputeContextImpl implements ComputeContext {
readonly adapterInfo: AdapterInfo;
readonly deviceInfo: DeviceInfo;
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
Expand All @@ -87,6 +94,7 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
this.deviceInfo = backend.deviceInfo;

// extract context data
const ptrSize = module.PTR_SIZE;
Expand All @@ -112,18 +120,6 @@ class ComputeContextImpl implements ComputeContext {
this.inputs = inputs;
}

getMaxComputeWorkgroupSizes(): [number, number, number] {
return [
this.backend.device.limits.maxComputeWorkgroupSizeX,
this.backend.device.limits.maxComputeWorkgroupSizeY,
this.backend.device.limits.maxComputeWorkgroupSizeZ,
];
}

getMaxComputeWorkgroupStoragesize(): number {
return this.backend.device.limits.maxComputeWorkgroupStorageSize;
}

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
Expand Down
20 changes: 15 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,23 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const enableDirectives: string[] = [];

// Enable WGSL extensions based on available WebGPU features
const extensionsInfo: Array<{ feature: GPUFeatureName; extension: string }> = [
{ feature: 'shader-f16', extension: 'f16' },
{ feature: 'subgroups' as GPUFeatureName, extension: 'subgroups' },
{ feature: 'subgroups-f16' as GPUFeatureName, extension: 'subgroups_f16' },
];
extensionsInfo.forEach((info) => {
if (device.features.has(info.feature)) {
enableDirectives.push(`enable ${info.extension};`);
}
});

const shaderHelper = createShaderHelper(normalizedDispatchGroupSize, this.backend.device.limits);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const code = `${enableDirectives.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({ code, label: programInfo.name });
LOG_DEBUG('verbose', () => `[WebGPU] ${programInfo.name} shader code: ${code}`);

Expand Down
12 changes: 10 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ export interface AdapterInfo {
isArchitecture: (architecture: GpuArchitecture) => boolean;
isVendor: (vendor: GpuVendor) => boolean;
}
export interface DeviceInfo {
readonly subgroupsSupported: boolean;
readonly subgroupsF16Supported: boolean;
readonly subgroupSizeRange?: readonly [number, number];
}

export interface GpuData {
type: GpuDataType;
Expand Down Expand Up @@ -160,6 +165,11 @@ export interface ComputeContext {
*/
readonly adapterInfo: AdapterInfo;

/**
* gpu device info
*/
readonly deviceInfo: DeviceInfo;

/**
* stores the pointer to OpKernelContext
*/
Expand Down Expand Up @@ -187,8 +197,6 @@ export interface ComputeContext {

compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[];
output(index: number, dims: readonly number[]): number;
getMaxComputeWorkgroupSizes(): [number, number, number];
getMaxComputeWorkgroupStoragesize(): number;
}

export type TimestampQuery = 'none' | 'inside-passes' | 'at-passes';

0 comments on commit 6a295eb

Please sign in to comment.