From b07317dc18991c570e40127d9ce24310d2872b19 Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Fri, 29 Sep 2023 18:03:17 +0400
Subject: [PATCH] JSEP: compute Attention and LayerNorm intermediates in f32
 for fp16 support

---
 js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts | 41 +++++++++---------
 js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 43 ++++++++++---------
 .../jsep/webgpu/ops/multi-head-attentiion.ts  |  5 +--
 .../wasm/jsep/webgpu/ops/skip-layer-norm.ts   |  4 +-
 js/web/lib/wasm/wasm-core-impl.ts             |  2 +-
 onnxruntime/contrib_ops/cpu/bert/attention.cc | 32 +++++++-------
 .../contrib_ops/cpu/bert/attention_cpu_base.h | 20 ++++-----
 tools/ci_build/build.py                       |  2 +-
 8 files changed, 74 insertions(+), 75 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
index 4f64adc847972..d56d0b9fadd5b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
 import {ComputeContext, GpuDataType} from '../types';
@@ -229,14 +228,13 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
     WG = Math.ceil(dComp / 8);
   }
   const elementsPerWG = Math.ceil(D / components / WG);
+  const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
 
-  // 6.2.4 in wgsl spec
-  const threadMaxMinValue = dataType === 'f32' ? '-3.402823e+38f' : '-65504.0h';
   const getShaderSource = (shaderHelper: ShaderHelper) => `
   const dInv: ${dataType} = 1 / ${D};
   const dComp = ${D / components};
-  var<workgroup> wgMax: array<${dataType}, ${WG}>;
-  var<workgroup> wgSum: array<${dataType}, ${WG}>;
+  var<workgroup> wgMax: array<f32, ${WG}>;
+  var<workgroup> wgSum: array<f32, ${WG}>;
 
   ${shaderHelper.declareVariables(inputHelper)}
 
   @compute @workgroup_size(${WG}, 1, 1)
@@ -245,26 +243,26 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
     let localOffset = local_index * ${elementsPerWG};
     let offset: u32 = workgroup_id.x * dComp + localOffset;
 
-    var threadMaxVector = ${fillVector(dataType, components, threadMaxMinValue)};
+    var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')};
     for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-      threadMaxVector = max(x[offset + i], threadMaxVector);
+      threadMaxVector = max(${castToF32}(x[offset + i]), threadMaxVector);
     }
     wgMax[local_index] = ${threadMaxValue};
     workgroupBarrier();
 
-    var maxValue = ${threadMaxMinValue};
+    var maxValue = -3.402823e+38f;
     for (var i = 0u; i < ${WG}; i++) {
       maxValue = max(wgMax[i], maxValue);
     }
 
-    var sumVector = ${fillVector(dataType, components, '0')};
+    var sumVector = ${fillVector('f32', components, '0')};
     for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-      sumVector += exp(x[offset + i] - maxValue);
+      sumVector += exp(${castToF32}(x[offset + i]) - maxValue);
     }
     wgSum[local_index] = ${sumVector('sumVector', components)};
     workgroupBarrier();
 
-    var sum: ${dataType} = 0;
+    var sum: f32 = 0;
     for (var i = 0u; i < ${WG}; i++) {
       sum += wgSum[i];
     }
@@ -275,7 +273,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
       }
     } else {
       for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-        x[offset + i] = exp(x[offset + i] - maxValue) / sum;
+        x[offset + i] = ${inputHelper.type.storage}(exp(${castToF32}(x[offset + i]) - maxValue) / sum);
      }
     }
   }`;
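The hunks above switch the softmax reduction's intermediates (running max, exp sum) to f32 regardless of the tensor's element type: fp16 saturates at 65504 and loses precision when many exponentials are summed, so only the final quotient is narrowed back to the storage type. A minimal TypeScript sketch of the same widen-accumulate-narrow pattern (the helper below is illustrative, not part of this patch):

    // Accumulate max and sum at f32 precision, narrow only on the final store.
    // `row` stands in for one softmax row read from an fp16 tensor.
    const softmaxRow = (row: Float32Array): Float32Array => {
      let max = -3.402823e+38;                 // f32 lowest, as in the shader
      for (const v of row) max = Math.max(max, v);
      let sum = 0;                             // f32 accumulator
      for (const v of row) sum += Math.exp(v - max);
      return row.map((v) => Math.exp(v - max) / sum);
    };
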
@@ -315,7 +313,8 @@ const computeAttentionProbs =
       const N = parameters.totalSequenceLength;
       const K = vectorizedHeadSize;
 
-      const TILE_SIZE = 8;
+      const TILE_SIZE = 12;
+      const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
 
       const dispatch = {
         x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
@@ -328,7 +327,7 @@ const computeAttentionProbs =
   const M: u32 = ${M}u;
   const N: u32 = ${N}u;
   const K: u32 = ${K}u;
-  const alpha = ${dataType}(${alpha});
+  const alpha: f32 = ${alpha};
   const beta: ${dataType} = 1.0;
   const TILE_SIZE = ${TILE_SIZE}u;
 
@@ -353,7 +352,7 @@ const computeAttentionProbs =
     let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
     let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
 
-    var value = ${fillVector(dataType, components)};
+    var value = ${fillVector('f32', components)};
     for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
       if (m + local_id.y < M && w + local_id.x < K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x];
@@ -364,7 +363,7 @@ const computeAttentionProbs =
       workgroupBarrier();
 
       for (var k: u32 = 0u; k < TILE_SIZE && w + k < K; k++) {
-        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
+        value += ${castToF32}(tileQ[TILE_SIZE * local_id.y + k]) * ${castToF32}(tileK[TILE_SIZE * local_id.x + k]);
       }
 
         dispatchGroup: () => (dispatch)
       },
@@ -508,7 +507,7 @@ let h = global_idx % ${parameters.vHeadSize};
         name: 'AttentionTranspose',
         cacheHint: JSON.stringify(parameters),
         inputTypes: [GpuDataType.default],
-        outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
+        outputs: [{dims: outputShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
         getShaderSource,
         dispatchGroup: () => ({ x: Math.ceil(outputSize / 64) }),
       },
@@ -531,7 +530,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
       const K = parameters.inputHiddenSize;
       const N = parameters.headSize;
 
-      const TILE_SIZE = 8;
+      const TILE_SIZE = 12;
       const dispatch = {
         x: Math.ceil(parameters.headSize / TILE_SIZE),
         y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
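Both tiled matmul kernels in this file (the QK^T probs kernel above and the input projection in prepare) move from 8x8 to 12x12 workgroup tiles; the dispatch grid is simply the output extent divided by the tile edge, rounded up. A sketch of that relationship, with an illustrative helper that is not part of the patch:

    // One workgroup per 12x12 output tile; the grid is the ceiling-divided
    // output shape, with one z-slice per attention head.
    const TILE_SIZE = 12;
    const dispatchFor = (rows: number, cols: number, heads: number) => ({
      x: Math.ceil(cols / TILE_SIZE),  // tiles along N
      y: Math.ceil(rows / TILE_SIZE),  // tiles along M
      z: heads,
    });
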
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
index ecf0ee33d83ad..c9def46890c12 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -12,8 +12,8 @@ import {
   inputVariable, outputVariable, ShaderHelper, sumVector,
-  tensorTypeToWsglStorageType
-} from './common'
+} from './common';
+import {DataType} from '../../../wasm-common';
 
 export interface LayerNormAttributes extends AttributeWithCacheKey {
   axis: number;
@@ -55,8 +55,9 @@ const createLayerNormProgramInfo =
       }
     }
 
-    const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
-    const components = getMaxComponents(normSize);
+    // TODO: vectorization does not work with fp16 yet; fall back to scalar access
+    const components = inputs[0].dataType !== DataType.float16 ? getMaxComponents(normSize) : 1;
+    const castToF32 = components === 1 ? 'f32' : `vec${components}f`;
     const variables = [
       inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
       inputVariable('scale', scale.dataType, scale.dims, components),
@@ -70,33 +71,37 @@ const createLayerNormProgramInfo =
     const hasInvStdOutput = outputCount > 2;
 
     if (hasMeanDataOutput) {
-      variables.push(outputVariable('meanDataOutput', inputs[0].dataType, meanInvStdDevDim));
+      variables.push(outputVariable('meanDataOutput', DataType.float, meanInvStdDevDim));
     }
     if (hasInvStdOutput) {
-      variables.push(outputVariable('invStdOutput', inputs[0].dataType, meanInvStdDevDim));
+      variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
     }
 
     const getShaderSource = (shaderHelper: ShaderHelper) => `
   const normSize: u32 = ${normSize / components};
-  const normSizeTyped: ${dataType} = ${normSize};
-  const epsilon: ${dataType} = ${attributes.epsilon};
+  const epsilon: f32 = ${attributes.epsilon};
 
   ${shaderHelper.declareVariables(...variables)}
   ${shaderHelper.mainStart()}
     ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(normCount)}
     let offset = global_idx * normSize;
-    var meanVector = ${fillVector(dataType, components)};
-    var meanSquareVector = ${fillVector(dataType, components)};
+    var meanVector = ${fillVector('f32', components)};
+    var meanSquareVector = ${fillVector('f32', components)};
     for (var h: u32 = 0u; h < normSize; h++) {
-      meanVector += x[h + offset];
-      meanSquareVector += x[h + offset] * x[h + offset];
+      let value = ${castToF32}(x[h + offset]);
+      meanVector += value;
+      meanSquareVector += value * value;
     }
-    let mean = ${sumVector('meanVector', components)} / normSizeTyped;
-    let meanSquare = sqrt(${sumVector('meanSquareVector', components)} / normSizeTyped - mean * mean + epsilon);
+    let mean = ${sumVector('meanVector', components)} / f32(normSize);
+    let meanSquare = sqrt(${sumVector('meanSquareVector', components)}
+      / f32(normSize) - mean * mean + epsilon);
 
     for (var j: u32 = 0; j < normSize; j++) {
-      output[j + offset] = (x[j + offset] - mean) / meanSquare * scale[j] ${bias ? '+ bias[j]' : ''};
+      output[j + offset] = ${variables[0].type.value}(
+        (${castToF32}(x[j + offset]) - mean) / meanSquare * ${castToF32}(scale[j])
+        ${bias ? `+ ${castToF32}(bias[j])` : ''}
+      );
     }
 
     ${hasMeanDataOutput ? 'meanDataOutput[global_idx] = mean' : ''};
@@ -104,14 +109,10 @@ const createLayerNormProgramInfo =
   }`;
     const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
     if (hasMeanDataOutput) {
-      outputs.push(
-          {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
-      );
+      outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
     }
     if (hasInvStdOutput) {
-      outputs.push(
-          {dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
-      );
+      outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
     }
 
     return {
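Layer normalization gets the same treatment as softmax: fp16 inputs are widened to f32 for the mean and mean-square accumulation, and only the normalized value is narrowed back on the output store. In scalar TypeScript terms (illustrative helper, not patch code):

    // E[x] and E[x*x] accumulated in f32, using the shader's
    // sqrt(E[x*x] - mean*mean + epsilon) formulation.
    const layerNormRow = (x: Float32Array, scale: Float32Array, epsilon = 1e-5): Float32Array => {
      let sum = 0, sumSq = 0;
      for (const v of x) { sum += v; sumSq += v * v; }
      const mean = sum / x.length;
      const std = Math.sqrt(sumSq / x.length - mean * mean + epsilon);
      return x.map((v, j) => (v - mean) / std * scale[j]);
    };
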
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
index e362f587dd2d0..de39439a3df1c 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -1,4 +1,3 @@
-import {DataType} from '../../../wasm-common';
 import {TensorView} from '../../tensor';
 import {ShapeUtil} from '../../util';
 import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -274,7 +273,7 @@ const addBiasTranspose =
       return context.compute(
           {
             ...addBiasTransposeMetadata,
-            outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
+            outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}],
             getShaderSource,
             dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
           },
@@ -452,7 +451,7 @@ const computeVxAttentionScoreBSN3H = (probs: TensorView, qkv: TensorView, params
   }`;
   return {
     ...attentionScoreMatMulProgramData,
-    outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
+    outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}],
     getShaderSource,
     dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
   };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index 2de86729e7660..b787d2a1dc401 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -149,10 +149,10 @@ const createSkipLayerNormProgramInfo =
   }`;
   const outputs = [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}];
   if (outputCount > 1) {
-    outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default});
+    outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
   }
   if (outputCount > 2) {
-    outputs.push({dims: meanInvStdDevDim, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default});
+    outputs.push({dims: meanInvStdDevDim, dataType: DataType.float, gpuDataType: GpuDataType.default});
   }
   if (outputCount > 3) {
     outputs.push({dims: inputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default});
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 924078d8085c3..36328d5f7e459 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -50,7 +50,7 @@ const initOrt = (numThreads: number, loggingLevel: number): void => {
  */
 export const initRuntime = async(env: Env): Promise<void> => {
   // init ORT
-  initOrt(10, logLevelStringToEnum(env.logLevel));
+  initOrt(navigator.hardwareConcurrency || 4, logLevelStringToEnum(env.logLevel));
 
   if (!BUILD_DEFS.DISABLE_WEBGPU) {
     // init JSEP if available
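Sizing the wasm thread pool from navigator.hardwareConcurrency instead of a hardcoded 10 matches the pool to the actual machine, falling back to 4 where the API is unavailable. A hedged sketch of the same rule; the upper clamp is an assumption, not something this patch adds:

    // Derive a worker-pool size from the environment; the cap of 16 is an
    // illustrative safety margin for many-core machines.
    const pickThreadCount = (): number =>
      Math.min(navigator.hardwareConcurrency || 4, 16);
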
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 3d2f7a29d41f8..f408af62ac128 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -327,22 +327,22 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
     });
   }
 
-  std::cout << "Prepare completed.";
-  std::cout << "First 10 values at Q: ";
-  for (size_t i = 0; i < qkv_head_size[0] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << Q[i] << " ";
-  }
-  std::cout << std::endl;
-  std::cout << "First 10 values at K: ";
-  for (size_t i = 0; i < qkv_head_size[1] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << K[i] << " ";
-  }
-  std::cout << std::endl;
-  std::cout << "First 10 values at V: ";
-  for (size_t i = 0; i < qkv_head_size[2] * sequence_length * batch_size * num_heads_; ++i) {
-    std::cout << V[i] << " ";
-  }
-  std::cout << std::endl;
+//  std::cout << "Prepare completed.";
+//  std::cout << "First 10 values at Q: ";
+//  for (size_t i = 0; i < qkv_head_size[0] * sequence_length * batch_size * num_heads_; ++i) {
+//    std::cout << Q[i] << " ";
+//  }
+//  std::cout << std::endl;
+//  std::cout << "First 10 values at K: ";
+//  for (size_t i = 0; i < qkv_head_size[1] * sequence_length * batch_size * num_heads_; ++i) {
+//    std::cout << K[i] << " ";
+//  }
+//  std::cout << std::endl;
+//  std::cout << "First 10 values at V: ";
+//  for (size_t i = 0; i < qkv_head_size[2] * sequence_length * batch_size * num_heads_; ++i) {
+//    std::cout << V[i] << " ";
+//  }
+//  std::cout << std::endl;
 
   // Compute the attention score and apply the score to V
   return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 040364d256b7b..1020d3bdeab77 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -200,11 +200,11 @@ class AttentionCPUBase : public AttentionBase {
       });
     }
 
-    std::cout << "Probs before softmax.";
-    for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
-      std::cout << attention_probs[i] << " ";
-    }
-    std::cout << std::endl;
+//    std::cout << "Probs before softmax.";
+//    for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
+//      std::cout << attention_probs[i] << " ";
+//    }
+//    std::cout << std::endl;
 
     //  attention_probs(B, N, S, T) = Softmax(attention_probs)
     {
@@ -213,11 +213,11 @@ class AttentionCPUBase : public AttentionBase {
       ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
     }
 
-    std::cout << "Probs after softmax.";
-    for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
-      std::cout << attention_probs[i] << " ";
-    }
-    std::cout << std::endl;
+//    std::cout << "Probs after softmax.";
+//    for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
+//      std::cout << attention_probs[i] << " ";
+//    }
+//    std::cout << std::endl;
   }
 
   template <typename T>
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 74003e07a606c..500906172a4e4 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -402,7 +402,7 @@ def convert_arg_line_to_args(self, arg_line):
     # WebAssembly build
     parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly")
     parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build for WebAssembly static library")
-    parser.add_argument("--emsdk_version", default="3.1.46", help="Specify version of emsdk")
+    parser.add_argument("--emsdk_version", default="3.1.45", help="Specify version of emsdk")
     parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD")
     parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threads support")
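A minimal onnxruntime-web snippet that exercises these fp16 kernels end to end is sketched below. The model path, input name, and shape are placeholders, and fp16 tensor data is passed as Uint16Array bit patterns since JavaScript has no native float16 array type:

    import * as ort from 'onnxruntime-web';

    // Hypothetical fp16 model and input name -- adjust to your model.
    const session = await ort.InferenceSession.create('model_fp16.onnx', {
      executionProviders: ['webgpu'],
    });
    const input = new ort.Tensor('float16', new Uint16Array(1 * 8 * 64), [1, 8, 64]);
    const results = await session.run({input: input});
    console.log(results);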