diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 39f6ba3d947fc..a0a2d0f1e2a4a 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; -import {attention, parseAttentionAttributes} from './ops/attentiion'; +import {attention, parseAttentionAttributes} from './ops/attention'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts similarity index 90% rename from js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts rename to js/web/lib/wasm/jsep/webgpu/ops/attention.ts index d56d0b9fadd5b..033d3c265cb1e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -6,6 +6,7 @@ import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType} from '../types'; import { + castToF32, fillVector, getMaxComponents, inputVariable, @@ -33,13 +34,12 @@ export enum AttentionMaskType { mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., // key_start[batch_size - 1], key_end[batch_size - 1]] - MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. - MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length] - MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length] - MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] - MASK_UNKNOWN + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown } -; export interface AttentionParameters { batchSize: number; @@ -168,7 +168,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte const totalSequenceLength = kvSequenceLength + pastSequenceLength; const maxSequenceLength = -1; - let maskType = AttentionMaskType.none; + const maskType = AttentionMaskType.none; if (maskIndex) { // maskType = AttentionMaskType.MASK_UNKNOWN; // TODO: handle mask @@ -209,8 +209,8 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => createAttributeWithCacheKey({...attributes}); -export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, N: number, D: number) => { - const components = getMaxComponents(D); +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); const inputHelper = outputVariable('x', input.dataType, input.dims, components); let threadMaxValue = 'threadMaxVector'; @@ -221,18 +221,17 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } const dataType = tensorTypeToWsglStorageType(input.dataType); let WG = 64; - const dComp = D / components; + const dComp = d / components; if (dComp < WG) { WG = 1; } else if (dComp / 8 < 64) { WG = Math.ceil(dComp / 8); } - const elementsPerWG = Math.ceil(D / components / WG); - const castToF32 = components === 1 ? 'f32' : `vec${components}f`; + const elementsPerWG = Math.ceil(d / components / WG); const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dInv: ${dataType} = 1 / ${D}; - const dComp = ${D / components}; + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; var wgMax: array; var wgSum: array; @@ -245,7 +244,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - threadMaxVector = max(${castToF32}(x[offset + i]), threadMaxVector); + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); } wgMax[local_index] = ${threadMaxValue}; workgroupBarrier(); @@ -257,7 +256,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView var sumVector = ${fillVector('f32', components, '0')}; for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - sumVector += exp(${castToF32}(x[offset + i]) - maxValue); + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); } wgSum[local_index] = ${sumVector('sumVector', components)}; workgroupBarrier(); @@ -273,7 +272,8 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } } else { for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - x[offset + i] = ${inputHelper.type.storage}(exp(${castToF32}(x[offset + i]) - maxValue) / sum); + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); } } }`; @@ -285,7 +285,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView inputTypes: [GpuDataType.default], outputs: [], getShaderSource, - dispatchGroup: () => ({x: N}) + dispatchGroup: () => ({x: n}) }, {inputs: [input], outputs: []}); }; @@ -314,7 +314,6 @@ const computeAttentionProbs = const K = vectorizedHeadSize; const TILE_SIZE = 12; - const castToF32 = components === 1 ? 'f32' : `vec${components}f`; const dispatch = { x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), @@ -327,7 +326,7 @@ const computeAttentionProbs = const M: u32 = ${M}u; const N: u32 = ${N}u; const K: u32 = ${K}u; - const alpha: f32 = ${alpha}; + const alpha: ${dataType} = ${alpha}; const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; @@ -352,7 +351,7 @@ const computeAttentionProbs = let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; - var value = ${fillVector('f32', components)}; + var value = ${fillVector(dataType, components)}; for (var w: u32 = 0u; w < K; w += TILE_SIZE) { if (m + local_id.y < M && w + local_id.x < K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; @@ -363,7 +362,7 @@ const computeAttentionProbs = workgroupBarrier(); for (var k: u32 = 0u; k(${new Array(components).fill(value).join(',')})`; + return `vec${components}<${dataType}>(${value})`; +}; + +export const castToF32 = (dataType: string, components: number, value: string) => { + if (dataType === 'f32') { + return value; + } + if (components === 1) { + return `f32(${value})`; + } + + return `vec${components}f(${value})`; }; export const sumVector = (name: string, components: number) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index c9def46890c12..0f9e830bb14b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -7,11 +7,14 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; import { + castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, - ShaderHelper, sumVector, + ShaderHelper, + sumVector, + tensorTypeToWsglStorageType, } from './common'; import { DataType } from '../../../wasm-common'; @@ -55,9 +58,9 @@ const createLayerNormProgramInfo = } } - // TODO: for some reason it does not work with fp16 yet + // TODO: for some reason it does not work correctly with fp16 const components = inputs[0].dataType !== DataType.float16 ? getMaxComponents(normSize) : 1; - const castToF32 = components === 1 ? 'f32' : `vec${components}f`; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const variables = [ inputVariable('x', inputs[0].dataType, inputs[0].dims, components), inputVariable('scale', scale.dataType, scale.dims, components), @@ -89,7 +92,7 @@ const createLayerNormProgramInfo = var meanSquareVector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < normSize; h++) { - let value = ${castToF32}(x[h + offset]); + let value = ${castToF32(dataType, components, 'x[h + offset]')}; meanVector += value; meanSquareVector += value * value; } @@ -98,9 +101,10 @@ const createLayerNormProgramInfo = / f32(normSize) - mean * mean + epsilon); for (var j: u32 = 0; j < normSize; j++) { - output[j + offset] = ${variables[0].type.value}( - (${castToF32}(x[j + offset]) - mean) / meanSquare * ${castToF32}(scale[j]) - ${bias ? `+${castToF32}(bias[j])` : ''} + let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; + let f32scale = ${castToF32(dataType, components, 'scale[j]')}; + output[j + offset] = ${variables[0].type.value}((f32input - mean) / meanSquare * f32scale + ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index de39439a3df1c..608b799fb8bf7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,17 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType} from '../types'; -import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat, computeInPlaceSoftmax,} from './attentiion'; +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; import { - fillVector, - inputVariable, - outputVariable, ShaderHelper, - sumVector, tensorTypeToWsglStorageType -} from './common' +} from './common'; import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { @@ -74,47 +73,47 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const headSize = Math.floor(hiddenSize / attributes.numHeads); if (pastKey && pastValue) { if (pastKey.dims.length !== 4) { - throw new Error('Input \'past_key\' is expected to have 4 dimensions'); + throw new Error('Input "past_key" is expected to have 4 dimensions'); } if (pastValue.dims.length !== 4) { - throw new Error('Input \'past_value\' is expected to have 4 dimensions') + throw new Error('Input "past_value" is expected to have 4 dimensions'); } pastSequenceLength = pastKey.dims[2]; maxSequenceLength = pastKey.dims[2]; } else if (pastKey || pastValue) { - throw new Error('Input \'past_key\' and \'past_value\' shall be both present or both absent') + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } let qkvFormat: AttentionQkvFormat; if (key) { if (query.dims.length !== 3) { - throw new Error('Input \'query\' is expected to have 3 dimensions when key is given'); + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); } if (key.dims.length < 3 || key.dims.length > 5) { - throw new Error('Input \'key\' is expected to have 3, 4, or 5 dimensions'); + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); } if (query.dims[0] !== key.dims[0]) { - throw new Error('Input \'query\' and \'key\' shall have same dim 0 (batch size)'); + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); } if (key.dims.length === 3) { if (key.dims[2] !== query.dims[2]) { - throw new Error('Input \'query\' and \'key\' shall have same dim 2 (hidden_size)'); + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); } qkvFormat = AttentionQkvFormat.qkvBSNH; kvSequenceLength = key.dims[1]; } else if (key.dims.length === 5) { if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { - throw new Error('Expect \'key\' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); } if (value) { - throw new Error('Expect \'value\' be none when \'key\' has packed kv format.'); + throw new Error('Expect "value" be none when "key" has packed kv format.'); } qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; kvSequenceLength = key.dims[1]; } else { // key_dims.size() == 4 (cross-attention with past_key) if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { - throw new Error('Expect \'key\' shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } qkvFormat = AttentionQkvFormat.unknown; @@ -122,10 +121,10 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } } else { // packed QKV if (query.dims.length !== 3 && query.dims.length !== 5) { - throw new Error('Input \'query\' is expected to have 3 or 5 dimensions when key is empty'); + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); } if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { - throw new Error('Expect \'query\' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); } qkvFormat = AttentionQkvFormat.qkvBSN3H; @@ -133,7 +132,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr if (bias) { if (bias.dims.length !== 1) { - throw new Error('Input \'bias\' is expected to have 1 dimension'); + throw new Error('Input "bias" is expected to have 1 dimension'); } if (value) { @@ -145,19 +144,19 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr let maskType: AttentionMaskType = AttentionMaskType.none; if (keyPaddingMask) { - maskType = AttentionMaskType.MASK_UNKNOWN; + maskType = AttentionMaskType.maskUnknown; const maskDims = keyPaddingMask.dims; if (maskDims.length === 1) { if (maskDims[0] === batchSize) { maskType = AttentionMaskType.mask1dKeySeqLen; } else if (maskDims[0] === 3 * batchSize + 2) { - maskType = AttentionMaskType.mask1DKeySeqLenStart + maskType = AttentionMaskType.mask1DKeySeqLenStart; } } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { - maskType = AttentionMaskType.MASK_2D_KEY_PADDING; + maskType = AttentionMaskType.mask2dKeyPadding; } - if (maskType === AttentionMaskType.MASK_UNKNOWN) { - throw new Error('Input \'key_padding_mask\' shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); } throw new Error('Mask not supported'); } @@ -166,38 +165,35 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr let vHiddenSize = hiddenSize; if (value) { if (value.dims.length !== 3 && value.dims.length !== 4) { - throw new Error('Input \'value\' is expected to have 3 or 4 dimensions') + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } if (query.dims[0] !== value.dims[0]) { - throw new Error('Input \'query\' and \'value\' shall have same dim 0 (batch_size)') + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); } if (value.dims.length === 3) { if (kvSequenceLength !== value.dims[1]) { - throw new Error('Input \'key\' and \'value\' shall have the same dim 1 (kv_sequence_length)') + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); } vHiddenSize = value.dims[2]; } else { if (kvSequenceLength !== value.dims[2]) { - throw new Error('Input \'past_key\' and \'past_value\' shall have the same dim 2 (kv_sequence_length)') + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); } vHiddenSize = value.dims[1] * value.dims[3]; passPastInKv = true; } } - let totalSequenceLength = pastSequenceLength + kvSequenceLength; - let broadcastResPosBias = false; + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; // if (extraAddQk) { // if (extraAddQk.dims[0] === 1) { // broadcastResPosBias = true; // } // } - // if (bias) { - // throw new Error('bias is not supported'); - // } if (keyPaddingMask) { throw new Error('Key padding mask is not supported'); } @@ -240,7 +236,6 @@ export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): A createAttributeWithCacheKey({...attributes}); const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); -const packedWeightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3, 4]}); const addBiasTranspose = (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, @@ -315,206 +310,15 @@ const maybeTransposeToBNSHAndAddBias = } }; -// const getMaxComponents = (size: number) => { -// if (size % 4 === 0) { -// return 4; -// } else if (size % 3 === 0) { -// return 3; -// } else if (size % 2 === 0) { -// return 2; -// } -// -// return 1; -// }; - -const computeAttentionProbsBSN3H = - (context: ComputeContext, q: TensorView, key: TensorView, bias: TensorView|undefined, - parameters: AttentionParameters, attributes: AttentionAttrs) => { - const probsShape = [ - parameters.batchSize, parameters.sequenceLength, parameters.numHeads, - parameters.kvSequenceLength + parameters.pastSequenceLength - ]; - - const components = undefined; // getMaxComponents(parameters.headSize); - const qInput = inputVariable('q', q.dataType, q.dims, components); - const output = outputVariable('output', q.dataType, probsShape); - - const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - - const unitsOfWork = ShapeUtil.size(probsShape); - - const M = parameters.sequenceLength; - const N = parameters.totalSequenceLength; - const K = parameters.headSize; - - // since we are multiplying Q with transposed K and headSize = vHeadSize, - // we are multiplying Q head rows with K head rows for each head - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K / (components || 1)}u; - const numHeads: u32 = ${parameters.numHeads}; - const batchSize: u32 = ${parameters.batchSize}; - const alpha = f32(${alpha}); - const beta = 1.0; - - ${shaderHelper.declareVariables(qInput, output)} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - // batch and head index - let batchIdx = global_idx / (M * N * numHeads); - let headIdx = (global_idx / (M * N)) % numHeads; - let qSequenceIdx = (global_idx / N) % ${parameters.sequenceLength}; - let kSequenceIdx = global_idx % (M * N) % ${parameters.totalSequenceLength}; - - var headOffset = headIdx * ${parameters.headSize} * 3; - - var qOffset = qSequenceIdx * ${parameters.headSize} * numHeads * 3 + headOffset; - var batchOffset = batchIdx * ${parameters.headSize} * numHeads * 3 * M; - qOffset += batchOffset; // batch offset - let kOffset = ${parameters.headSize}u + batchOffset + headOffset + kSequenceIdx * - ${parameters.headSize} * numHeads * 3; - var value: ${qInput.type.storage} = ${fillVector(qInput.type.value, components)}; - for (var k: u32 = 0u; k<${K}u; k++) { - value += q[k + qOffset] * q[k + kOffset]; - } - - let sum = ${sumVector('value', components!)} * alpha; - // value += beta * output[global_id.x]; // no mask - output[global_idx] = sum; - }`; - - const inputTypes = [1].map(_ => GpuDataType.default); - - const probs = context.compute( - { - name: 'computeAttentionProbsBSN3H', - cacheHint: JSON.stringify(parameters), - inputTypes, - outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}) - }, - {inputs: [q], outputs: [-1]})[0]; - - computeInPlaceSoftmax( - context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, - parameters.totalSequenceLength); - - return probs; - }; - -const computeVxAttentionScoreBSN3H = (probs: TensorView, qkv: TensorView, params: AttentionParameters) => { - const attentionScoreMatMulProgramData = { - name: 'computeVxAttentionScore', - inputTypes: [GpuDataType.default, GpuDataType.default], - cacheHint: JSON.stringify(params), - }; - - const outputShape = [params.batchSize, params.sequenceLength, params.numHeads, params.vHeadSize]; - const outputSize = ShapeUtil.size(outputShape); - - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const qkvHelper = inputVariable('qkv', qkv.dataType, qkv.dims); - const output = outputVariable('output', probs.dataType, outputShape); - - const dataType = 'f32'; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${params.sequenceLength}u; - const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; - const numHeads: u32 = ${params.numHeads}u; - const batchSize: u32 = ${params.batchSize}; - - ${shaderHelper.declareVariables(probsHelper, qkvHelper, output)} - - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let batchIdx = global_idx / (M * N * numHeads); - let headIdx = (global_idx / (M * N)) % numHeads; - let probsSequenceIdx = (global_idx / N) % ${params.sequenceLength}; - - let offsetA = probsSequenceIdx * ${params.headSize} * numHeads + batchIdx * ${params.headSize} * numHeads * M; - - var headOffset = headIdx * ${params.vHeadSize} * 3; - var batchOffset = batchIdx * ${params.vHeadSize} * numHeads * 3 * M; - - - var value = ${dataType}(0); - for (var k: u32 = 0u; k ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) - }; -}; - -export const applyPackedAttention = - (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, maskIndex: TensorView|undefined, - past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined, - relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { - const probs = computeAttentionProbsBSN3H(context, q, k, relativePositionBias, parameters, attributes); - - const attentionScoreMatMulProgramData = { - name: 'PackedAttentionScore', - inputTypes: [GpuDataType.default, GpuDataType.default], - cacheHint: JSON.stringify(parameters) + JSON.stringify(attributes), - }; - - const attentionResult = context.compute( - { - ...attentionScoreMatMulProgramData, - cacheHint: JSON.stringify(parameters), - get: () => computeVxAttentionScoreBSN3H(probs, q, parameters) - }, - {inputs: [probs, v || q], outputs: [-1]})[0]; - - context.compute( - { - ...transposeProgramMetadata, - cacheHint: JSON.stringify(parameters) + JSON.stringify(attributes), - get: () => createTransposeProgramInfo( - attentionResult, weightTransposeAttribute.perm, - [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize]) - }, - {inputs: [attentionResult], outputs: [0]}); - }; - export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { - // transpose QKV from BSN3H to BNS3H - return applyPackedAttention( - context, context.inputs[0], context.inputs[1], context.inputs[2], context.inputs[4], undefined, - context.inputs[6], context.inputs[7], context.inputs[5], params, attributes); + throw new Error('Packed QKV is not implemented'); } if (context.inputs[1]?.dims.length === 5) { - // transpose Q from BSD (BSNH) to BNSH - const Q = maybeTransposeToBNSHAndAddBias( - context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], - context.inputs[3], 0); - - // transpose KV from BLN2H to BNS2H - const K = context.compute( - { - ...transposeProgramMetadata, - cacheHint: weightTransposeAttribute.cacheKey, - get: () => createTransposeProgramInfo(context.inputs[1], packedWeightTransposeAttribute.perm) - }, - {inputs: [context.inputs[0]], outputs: [-1]})[0]; - return applyAttention( - context, Q, K, context.inputs[2], context.inputs[4], undefined, context.inputs[6], context.inputs[7], - context.inputs[5], params, attributes); + throw new Error('Packed KV is not implemented'); } // applyAttention expects BNSH inputs diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index b787d2a1dc401..ccb159ae3c67e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -7,12 +7,13 @@ import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-w import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; import { + castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, -} from './common'; +} from './common' import { DataType } from '../../../wasm-common' export interface SkipLayerNormAttributes extends AttributeWithCacheKey { @@ -114,7 +115,6 @@ const createSkipLayerNormProgramInfo = variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); } const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const castToF32 = components === 1 ? 'f32' : `vec${components}f`; const getShaderSource = (shaderHelper: ShaderHelper) => ` const hiddenSize: u32 = ${hiddenSize}; const hiddenSizeVectorized: u32 = ${hiddenSize / components}; @@ -134,7 +134,7 @@ const createSkipLayerNormProgramInfo = let value = inputValue + skipValue + biasValue; ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32}(value); + let f32Value = ${castToF32(dataType, components, 'value')}; sum += f32Value; squareSum += f32Value * f32Value; } diff --git a/js/web/test/data/ops/attention-op-working.jsonc b/js/web/test/data/ops/attention-op-working.jsonc new file mode 100644 index 0000000000000..74c3d80f11938 --- /dev/null +++ b/js/web/test/data/ops/attention-op-working.jsonc @@ -0,0 +1,100 @@ +[ + { + "name": "Attention batch 1, 2 heads, hs 4", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6 + ], + "dims": [1, 2, 4], + "type": "float16" + }, + { + "data": [ + 0.1, -0.2, 0.3, 1.0, 1.1, 0.3, 0.5, 0.2, 0.3, -0.6, 1.5, 2.0, + 0.5, 0.1, 0.4, 1.6, 1.0, 2.0, 0.4, 0.8, 0.9, 0.1, -1.3, 0.7, + 0.3, 0.2, 4.0, 2.2, 1.6, 1.1, 0.7, 0.2, 0.4, 1.0, 1.2, 0.5, + 0.2, 0.1, 0.4, 1.6, 2.4, 3.3, 2.1, 4.2, 8.4, 0.0, 2.1, 3.2 + ], + "dims": [4, 12], + "type": "float16" + }, + { + "data": [ + -0.5, 0.6, 1.2, 2.1, 0.5, 0.7, 0.2, 1.2, 0.5, 0.4, 0.3, 1.2 + ], + "dims": [12], + "type": "float16" + } + ], + "outputs": [ + { + "data": [ + 3.1495983600616455, 0.10843668878078461, 4.25, 5.6499996185302734, + 3.9696791172027588, 0.073143675923347473, 4.2499995231628418, 5.6499991416931152 + ], + "dims": [1, 2, 4], + "type": "float16" + } + ] + } + ] + }, + { + "name": "AttentionBatch2RelativePositionBias", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6, + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6 + ], + "dims": [2, 2, 4], + "type": "float16" + }, + { + "data": [ + 0.1, -0.2, 0.3, 1.0, 1.1, 0.3, 0.5, 0.2, 0.3, -0.6, 1.5, 2.0, + 0.5, 0.1, 0.4, 1.6, 1.0, 2.0, 0.4, 0.8, 0.9, 0.1, -1.3, 0.7, + 0.3, 0.2, 4.0, 2.2, 1.6, 1.1, 0.7, 0.2, 0.4, 1.0, 1.2, 0.5, + 0.2, 0.1, 0.4, 1.6, 2.4, 3.3, 2.1, 4.2, 8.4, 0.0, 2.1, 3.2 + ], + "dims": [4, 12], + "type": "float16" + }, + { + "data": [ + -0.5, 0.6, 1.2, 2.1, 0.5, 0.7, + 0.2, 1.2, 0.5, 0.4, 0.3, 1.2 + ], + "dims": [12], + "type": "float16" + } + ], + "outputs": [ + { + "data": [ + 3.149597406387329,0.1084367036819458,4.25,5.649999618530273,3.9696788787841797,0.07314369082450867,4.249999523162842,5.649999141693115,3.149597406387329,0.1084367036819458,4.25,5.649999618530273,3.9696788787841797,0.07314369082450867,4.249999523162842,5.649999141693115 + ], + "dims": [2, 2, 4], + "type": "float16" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/attention-op.jsonc b/js/web/test/data/ops/attention-op.jsonc new file mode 100644 index 0000000000000..57c784b94e858 --- /dev/null +++ b/js/web/test/data/ops/attention-op.jsonc @@ -0,0 +1,52 @@ +[ + { + "name": "AttentionBatch2RelativePositionBias", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6, + 0.8, -0.5, 0.0, 1.0, + 0.5, 0.2, 0.3, -0.6 + ], + "dims": [2, 2, 4], + "type": "float32" + }, + { + "data": [ + 0.1, -0.2, 0.3, 1.0, 1.1, 0.3, 0.5, 0.2, 0.3, -0.6, 1.5, 2.0, + 0.5, 0.1, 0.4, 1.6, 1.0, 2.0, 0.4, 0.8, 0.9, 0.1, -1.3, 0.7, + 0.3, 0.2, 4.0, 2.2, 1.6, 1.1, 0.7, 0.2, 0.4, 1.0, 1.2, 0.5, + 0.2, 0.1, 0.4, 1.6, 2.4, 3.3, 2.1, 4.2, 8.4, 0.0, 2.1, 3.2 + ], + "dims": [4, 12], + "type": "float32" + }, + { + "data": [ + -0.5, 0.6, 1.2, 2.1, 0.5, 0.7, + 0.2, 1.2, 0.5, 0.4, 0.3, 1.2 + ], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 3.149597406387329,0.1084367036819458,4.25,5.649999618530273,3.9696788787841797,0.07314369082450867,4.249999523162842,5.649999141693115,3.149597406387329,0.1084367036819458,4.25,5.649999618530273,3.9696788787841797,0.07314369082450867,4.249999523162842,5.649999141693115 + ], + "dims": [2, 2, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/attention2.jsonc b/js/web/test/data/ops/attention2.jsonc index 3203f8ba52fc7..6ab17564bab42 100644 --- a/js/web/test/data/ops/attention2.jsonc +++ b/js/web/test/data/ops/attention2.jsonc @@ -1,86 +1,5 @@ [ - { - "name": "Attention Basic 1 head, batch 3", - "operator": "Attention", - "opset": { "domain": "com.microsoft", "version": 1 }, - "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, - 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, - -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, - -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, - 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, - -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, - -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, - 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, - 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987 - ], - "dims": [3, 3, 10], - "type": "float32" - }, - { - "data": [ - 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, - 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, - -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, - 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, - -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, - -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, - 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, - 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, - 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, - 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, - 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, - -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, - -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, - -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, - -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, - -1.8803634643554688, 2.1661579608917236 - ], - "dims": [10, 15], - "type": "float32" - }, - { - "data": [ - -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688 - ], - "dims": [15], - "type": "float32" - } - ], - "outputs": [ - { - "data": [ - -5.545168399810791,-0.1483151912689209,9.630617141723633,-2.6852126121520996,-6.09011697769165,-5.545279026031494,-0.1482163518667221,9.630945205688477,-2.6851868629455566,-6.090045928955078,-5.545281887054443,-0.14821410179138184,9.630952835083008,-2.6851863861083984,-6.090044975280762,-3.6367568969726562,-4.700472354888916,2.8948633670806885,-5.297839641571045,-8.878729820251465,-3.6367805004119873,-4.7004804611206055,2.8948824405670166,-5.297871112823486,-8.878717422485352,1.1043529510498047,-1.6557061672210693,3.653333902359009,3.3767333030700684,-4.849300861358643,0.06528311967849731,-8.90493106842041,-4.284236431121826,9.792343139648438,-10.294163703918457,0.06407018005847931,-8.901933670043945,-4.28200626373291,9.785630226135254,-10.293055534362793,-2.3217902183532715,-3.0068209171295166,0.1039692759513855,-3.417605400085449,-8.115047454833984 - ], - "dims": [3, 3, 5], - "type": "float32" - } - ] - } - ] - }, - { + { "name": "Attention Basic 2 head, batch 3", "operator": "Attention", "opset": { "domain": "com.microsoft", "version": 1 }, @@ -152,7 +71,7 @@ "outputs": [ { "data": [ - -2.485170602798462,-3.214747667312622,9.630952835083008,-2.6851868629455566,0,-4.958700656890869,-0.669141948223114,9.630952835083008,-2.6851863861083984,0,-2.882824182510376,-2.520465850830078,-3.6364173889160156,-4.700357437133789,2.8907482624053955,-5.291064739227295,1.401298464324817e-45,-3.6367807388305664,-4.7004804611206055,2.8949179649353027,-5.2971978187561035,1.555441295400547e-43,1.208133578300476,-1.6758347749710083,-2.1411385536193848,-2.8134799003601074,-4.256412029266357,9.76078987121582,1.1210387714598537e-43,0.06528311967849731,-8.90493106842041,-4.257097244262695,9.71064567565918,2.8074450492858887,-2.3217902183532715,-3.0068209171295166,7.57670783996582,-3.6506946086883545,2.6705641746520996,3.859153985977173,5.887318134307861,-2.1149888038635254,4.807590961456299,2.8616080284118652,-1.1043622493743896 + -2.485170602798462,-3.214747667312622,9.630952835083008,-2.6851868629455566,0,-4.958700656890869,-0.669141948223114,9.630952835083008,-2.6851863861083984,0,-2.882824182510376,-2.520465850830078,-3.6364173889160156,-4.700357437133789,2.8907482624053955,-5.291064739227295,0,-3.6367807388305664,-4.7004804611206055,2.8949179649353027,-5.2971978187561035,0,1.208133578300476,-1.6758347749710083,-2.1411385536193848,-2.8134799003601074,-4.256412029266357,9.76078987121582,0,0.06528311967849731,-8.90493106842041,-4.257097244262695,9.71064567565918,0,-2.3217902183532715,-3.0068209171295166,7.57670783996582,-3.6506946086883545,0,0,0,0,0,0,0 ], "dims": [3, 3, 5], "type": "float32" diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 928192bb219f2..d32001110cf5e 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -10,19 +10,19 @@ { "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], "dims": [1, 1, 3, 3], - "type": "float32" + "type": "float16" }, { "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [370, 470, 670, 770], "dims": [1, 1, 2, 2], - "type": "float32" + "type": "float16" } ] } @@ -39,19 +39,19 @@ { "data": [10, 20, 30, 40, 50, 60, 70, 80], "dims": [1, 2, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [1, 2, 3, 4, 5, 6, 7, 8], "dims": [1, 2, 2, 2], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [2040], "dims": [1, 1, 1, 1], - "type": "float32" + "type": "float16" } ] } @@ -68,24 +68,24 @@ { "data": [10, 20, 30, 40, 50, 60, 70, 80], "dims": [1, 2, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "dims": [4, 2, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [0.1, 0.2, 0.3, 0.4], "dims": [4], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [360.1, 360.2, 360.3, 360.4], "dims": [1, 4, 1, 1], - "type": "float32" + "type": "float16" } ] } @@ -102,24 +102,24 @@ { "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [1, 1, 1, 1], "dims": [1, 1, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [5], "dims": [1], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [15], "dims": [1, 1, 1, 1], - "type": "float32" + "type": "float16" } ] } @@ -139,19 +139,19 @@ { "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0], "dims": [1, 2, 3, 3], - "type": "float32" + "type": "float16" }, { "data": [1.0, 2.0], "dims": [2, 1, 1, 1], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], "dims": [1, 2, 3, 3], - "type": "float32" + "type": "float16" } ] } @@ -174,24 +174,24 @@ 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 ], "dims": [1, 3, 3, 3], - "type": "float32" + "type": "float16" }, { "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], "dims": [3, 1, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [0.1, 0.2, 0.3], "dims": [3], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3], "dims": [1, 3, 2, 2], - "type": "float32" + "type": "float16" } ] } @@ -214,19 +214,19 @@ 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 ], "dims": [1, 3, 3, 4], - "type": "float32" + "type": "float16" }, { "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], "dims": [3, 1, 2, 2], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [34, 44, 54, 74, 84, 94, 386, 412, 438, 490, 516, 542, 1122, 1164, 1206, 1290, 1332, 1374], "dims": [1, 3, 2, 3], - "type": "float32" + "type": "float16" } ] } @@ -247,24 +247,24 @@ 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0 ], "dims": [1, 8, 2, 2], - "type": "float32" + "type": "float16" }, { "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0], "dims": [2, 8, 1, 1], - "type": "float32" + "type": "float16" }, { "data": [0.5, 0.4], "dims": [2], - "type": "float32" + "type": "float16" } ], "outputs": [ { "data": [560.5, 588.5, 616.5, 644.5, 1456.4, 1548.4, 1640.4, 1732.4], "dims": [1, 2, 2, 2], - "type": "float32" + "type": "float16" } ] } diff --git a/js/web/webpack.config.js b/js/web/webpack.config.js index 45cc85dd99795..b510914945b0e 100644 --- a/js/web/webpack.config.js +++ b/js/web/webpack.config.js @@ -203,9 +203,10 @@ function buildTestRunnerConfig({ }, extensions: ['.ts', '.js'], fallback: { - './binding/ort-wasm.js': false, - './binding/ort-wasm-threaded.js': false, - './binding/ort-wasm-threaded.worker.js': false + // './binding/ort-wasm.js': false, + // './binding/ort-wasm-threaded.js': false, + // './binding/ort-wasm-threaded.worker.js': false, + // './binding/ort-wasm-simd-threaded.worker.js': false } }, plugins: [ diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 1020d3bdeab77..895e75156a8f5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -283,6 +283,12 @@ class AttentionCPUBase : public AttentionBase { } } }); + +// std::cout << "Before transpose."; +// for (size_t i = 0; i < batch_size * num_heads_ * sequence_length * v_head_size; ++i) { +// std::cout << tmp_buffer[i] << " "; +// } +// std::cout << std::endl; } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804c61..c4a7ddf2b3c1e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -139,12 +139,30 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat }); } +// std::cout << "After bias add."; +// std::cout << std::endl; +// auto tensor = qkv_with_bias.GetMutable(); +// auto data = tensor->MutableData(); +// for (size_t i = 0; i < batch_size * sequence_length * hidden_size; ++i) { +// std::cout << data[i] << " "; +// } +// std::cout << std::endl; + // Reshape Q from BxSxD to BxSxNxH ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); // Transpose Q from BxSxNxH to BxNxSxH ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed)); +// std::cout << "After transpose."; +// std::cout << std::endl; +// tensor = qkv_with_bias_transposed.GetMutable(); +// data = tensor->MutableData(); +// for (size_t i = 0; i < batch_size * sequence_length * hidden_size; ++i) { +// std::cout << data[i] << " "; +// } +// std::cout << std::endl; + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/js/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc similarity index 100% rename from onnxruntime/contrib_ops/js/attention.cc rename to onnxruntime/contrib_ops/js/bert/attention.cc diff --git a/onnxruntime/contrib_ops/js/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h similarity index 100% rename from onnxruntime/contrib_ops/js/attention.h rename to onnxruntime/contrib_ops/js/bert/attention.h diff --git a/onnxruntime/contrib_ops/js/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc similarity index 100% rename from onnxruntime/contrib_ops/js/multi_head_attention.cc rename to onnxruntime/contrib_ops/js/bert/multi_head_attention.cc diff --git a/onnxruntime/contrib_ops/js/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h similarity index 100% rename from onnxruntime/contrib_ops/js/multi_head_attention.h rename to onnxruntime/contrib_ops/js/bert/multi_head_attention.h