diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 27c5566ab9fed..95bfae033e222 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -94,11 +94,27 @@ const getProgramInfoUniqueKey = return key; }; +export class AdapterInfo { + private vendor: string; + + constructor(adapterInfo: GPUAdapterInfo) { + if (adapterInfo) { + this.vendor = adapterInfo.vendor; + } + } + + // vendor could be intel, nvidia, amd, etc. + isVendor(vendor: string): boolean { + return this.vendor === vendor; + } +} + /** * this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as * the first parameter so that it is stored for future use. */ export class WebGpuBackend { + adapterInfo: AdapterInfo; device: GPUDevice; /** * an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping @@ -212,6 +228,8 @@ export class WebGpuBackend { } this.device = await adapter.requestDevice(deviceDescriptor); + const adapterInfo = await adapter.requestAdapterInfo(); + this.adapterInfo = new AdapterInfo(adapterInfo); this.gpuDataManager = createGpuDataManager(this); this.programManager = new ProgramManager(this); this.kernels = new Map(); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b64abf9cc5424..3a53d3d510e47 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -6,7 +6,7 @@ import {Env} from 'onnxruntime-common'; import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; -import {WebGpuBackend} from './backend-webgpu'; +import {AdapterInfo, WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; import {TensorView} from './tensor-view'; import {ShapeUtil} from './util'; @@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView { } class ComputeContextImpl implements ComputeContext { + readonly adapterInfo: AdapterInfo; readonly opKernelContext: number; readonly inputs: readonly TensorView[]; readonly outputCount: number; @@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext { private customDataOffset = 0; private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { + this.adapterInfo = backend.adapterInfo; const heapU32 = module.HEAPU32; // extract context data diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 5afec0389fac8..3f6a242a39179 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -148,11 +148,12 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */ const isChannelsLast = attributes.format === 'NHWC'; if (attributes.group !== 1) { - // Temporarily disable createGroupedConvVectorizeProgramInfo path due to bots failures with below two cases: + // One CI bot with NVIDIA GPU fails with below 2 cases, but we couldn't repro them with any other GPUs, including NVIDIA ones. // [webgpu]Conv - conv - vectorize group - B // [webgpu]Conv - conv - vectorize group - D - const disableGroupedConvVectorize = true; - if (!disableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && + // Disable vectorize on NVIDIA to make bots happy. BTW, no obvious perf gain with vectorize is seen on NVIDIA GPUs. + const enableGroupedConvVectorize = context.adapterInfo.isVendor('nvidia') ? false : true; + if (enableGroupedConvVectorize && isChannelsLast && inputs[1].dims[0] === attributes.group && inputs[1].dims[1] === 1 && attributes.dilations[0] === 1 && attributes.dilations[1] === 1) { const outputShape = calculateOutputShape( inputs[0].dims, inputs[1].dims, attributes.dilations, adjustedAttributes.pads, attributes.strides, diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index ba5b84fcfe067..84fb832ed4b6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {DataType} from '../../wasm-common'; +import {AdapterInfo} from '../backend-webgpu' import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; @@ -146,6 +147,11 @@ export interface ComputeContextInputsOutputsMapping { * A ComputeContext instance carries the states that representing the current running of a kernel. */ export interface ComputeContext { + /** + * gpu adapter info + */ + readonly adapterInfo: AdapterInfo; + /** * stores the pointer to OpKernelContext */