Commit

Everything works
dakenf committed Sep 11, 2023
1 parent edce8d0 commit 8b492e8
Showing 5 changed files with 52 additions and 70 deletions.
59 changes: 31 additions & 28 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -14,25 +14,25 @@ import {
ShaderHelper,
sumVector,
tensorTypeToWsglStorageType
- } from './common'
- import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'
+ } from './common';
+ import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';

export enum AttentionQkvFormat {
- UNKNOWN, // enum value not set, or depends on qkv projection implementation details
- Q_K_V_BNSH, // for non-packed qkv, permuted
- Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
- QKV_BSN3H, // for TRT fused attention, qkv are packed
- Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
- Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
- Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
- QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed
+ unknown, // enum value not set, or depends on qkv projection implementation details
+ qkvBNSH, // for non-packed qkv, permuted
+ qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
+ qkvBSN3H, // for TRT fused attention, qkv are packed
+ qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
+ qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed
+ qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
+ qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed
}
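
For reference, the two unpacked layouts differ only in axis order; a minimal sketch of the flat-index math (illustrative only, not part of this diff; B = batch, N = num_heads, S = sequence length, H = head size):

    // qkvBNSH: [batch, num_heads, seq_len, head_size] - each head's rows are contiguous
    const offsetBNSH = (b: number, n: number, s: number, h: number,
                        N: number, S: number, H: number): number =>
      ((b * N + n) * S + s) * H + h;

    // qkvBSNH: [batch, seq_len, num_heads, head_size] - all heads of one token are contiguous
    const offsetBSNH = (b: number, n: number, s: number, h: number,
                        N: number, S: number, H: number): number =>
      ((b * S + s) * N + n) * H + h;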

export enum AttentionMaskType {
- MASK_NONE, // No mask
- MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
- MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
- MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
+ none, // No mask
+ mask1dKeySeqLen, // [batch_size], key sequence length
+ mask1dEndStart, // [2 * batch_size] with end positions and start positions
+ mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
// ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
// key_start[batch_size - 1], key_end[batch_size - 1]]
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
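
As a concrete example of the mask1DKeySeqLenStart layout described above (hypothetical values, not part of this diff), batch_size = 2 packs 3 * 2 + 2 = 8 entries:

    // [key_len[0], key_len[1],
    //  query_start[0], query_start[1], query_end[1],
    //  key_start[0], key_start[1], key_end[1]]
    const maskIndex: number[] = [7, 5, 0, 16, 32, 0, 16, 32];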
@@ -170,7 +170,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
const totalSequenceLength = kvSequenceLength + pastSequenceLength;
const maxSequenceLength = -1;

- let maskType = AttentionMaskType.MASK_NONE;
+ let maskType = AttentionMaskType.none;
if (maskIndex) {
// maskType = AttentionMaskType.MASK_UNKNOWN;
// TODO: handle mask
@@ -204,7 +204,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
scale: attributes.scale,
broadcastResPosBias: false,
passPastInKv: false,
- qkvFormat: AttentionQkvFormat.Q_K_V_BNSH,
+ qkvFormat: AttentionQkvFormat.qkvBNSH,
};
};

@@ -422,9 +422,8 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
const M: u32 = ${params.sequenceLength}u;
const N: u32 = ${params.vHeadSize}u;
const K: u32 = ${params.totalSequenceLength / components}u;
- const numHeads: u32 = ${params.numHeads}u;
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
@@ -437,30 +436,32 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
let headIdx = workgroup_id.z;
- let m = workgroup_id.y * TILE_SIZE + local_id.y;
- let n = workgroup_id.x * TILE_SIZE + local_id.x;
+ let m = workgroup_id.y * TILE_SIZE;
+ let n = workgroup_id.x * TILE_SIZE;
+ let lm = m + local_id.y;
+ let ln = n + local_id.x;
let offsetA = headIdx * (M * K) + m * K;
let offsetB = headIdx * (N * K) + n;
var value = ${fillVector(dataType, components)};
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
- if (m < M && w + local_id.x < K) {
-   tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
+ if (m + local_id.y < M && w + local_id.x < K) {
+   tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + local_id.y * K + w + local_id.x];
}
- if (n < N && w + local_id.y < K) {
-   tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
+ if (n + local_id.y < N && w + local_id.x < K) {
+   tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + local_id.y * K + w + local_id.x];
}
workgroupBarrier();
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
- value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
+ value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
}
workgroupBarrier();
}
let headOffset = headIdx * M * N;
- if (m < M && n < N) {
-   let outputIdx = headOffset + m * N + n;
+ if (lm < M && ln < N) {
+   let outputIdx = headOffset + lm * N + ln;
output[outputIdx] = ${sumVector('value', components)};
}
}`;
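
A CPU sketch of what the rewritten tile loop computes (illustrative only, not part of this diff): with V written out transposed by prepare, both operands are read row-major along K, which is what lets the shader issue contiguous (vec4-friendly) loads. This assumes probs is [numHeads, M, K] and the transposed V is [numHeads, N, K], both row-major:

    function scoreTimesV(probs: Float32Array, vT: Float32Array,
                         numHeads: number, M: number, N: number, K: number): Float32Array {
      const out = new Float32Array(numHeads * M * N);
      for (let head = 0; head < numHeads; head++) {
        for (let m = 0; m < M; m++) {
          for (let n = 0; n < N; n++) {
            let acc = 0;
            for (let k = 0; k < K; k++) {
              // both reads walk contiguous memory, mirroring the shader's
              // `... + local_id.y * K + w + local_id.x` tile loads
              acc += probs[head * M * K + m * K + k] * vT[head * N * K + n * K + k];
            }
            out[head * M * N + m * N + n] = acc;
          }
        }
      }
      return out;
    }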
@@ -591,7 +592,9 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
let outputIdx = offset + m * N + n;
outputQ[outputIdx] = valueQ;
outputK[outputIdx] = valueK;
- outputV[outputIdx] = valueV;
+ // transpose V to use vec4 optimizations in compute score
+ let outputIdxV = offset + n * M + m;
+ outputV[outputIdxV] = valueV;
}
}`;
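
The new outputIdxV is the transposed write that the score pass above relies on; schematically (illustrative, assuming an M x N output block per head):

    // row-major write:  V[m][n]  lands at offset + m * N + n
    // transposed write: Vt[n][m] lands at offset + n * M + m

so the later pass can walk V along its former column dimension with stride-1 reads.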

31 changes: 14 additions & 17 deletions js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -1,20 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

- import {TensorView} from '../../tensor';
- import {ShapeUtil} from '../../util';
- import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
- import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
+ import { TensorView } from '../../tensor';
+ import { ShapeUtil } from '../../util';
+ import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
+ import { ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata } from '../types';

import {
- ShaderHelper,
+ fillVector,
+ getMaxComponents,
inputVariable,
- tensorTypeToWsglStorageType,
outputVariable,
- getMaxComponents,
- fillVector
- } from './common'
- import { DataType } from '../../../wasm-common'
+ ShaderHelper,
+ tensorTypeToWsglStorageType
+ } from './common';
+ import { DataType } from '../../../wasm-common';

export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@@ -99,11 +99,6 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi
const unitsOfWork = n * c / components;
const wgSize = Math.ceil(h / WG);

- let divisor = `${dataType}(H)`;
- if (input.dataType === DataType.float16 && h > 65504) {
-   divisor = `f16(${h / 2}) / 2.0h`;
- }
-
const getMeanShaderSource = (shaderHelper: ShaderHelper) => `
const H: u32 = ${h};
const C: u32 = ${c / components};
@@ -131,8 +126,8 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi
squaredSum += value * value;
}
// we need to divide it here to avoid fp16 overflow
- sum = sum / ${divisor};
- squaredSum = squaredSum / ${divisor};
+ sum = sum / ${wgSize};
+ squaredSum = squaredSum / ${wgSize};
output[global_idx] = ${setOutputValue('sum', 'squaredSum')};
}`;

@@ -172,6 +167,8 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorVi
sum += value[0];
squaredSum += value[1];
}
+ sum = sum / ${h / wgSize};
+ squaredSum = squaredSum / ${h / wgSize};
let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon);
let channelScale = invStdDev * scale[currentChannelNumber];
let channelShift = bias[currentChannelNumber] - sum * channelScale;
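
A minimal TypeScript sketch of the two-stage scaling now spread across the two shaders (illustrative, assuming h is an exact multiple of wgSize): stage one divides each partial sum by wgSize, stage two divides the accumulated partials by h / wgSize, so the combined divisor is h and every intermediate stays small enough for fp16:

    function meanAndVariance(x: Float32Array, h: number, wgSize: number): [number, number] {
      // stage 1: one partial [sum, squaredSum] per chunk, pre-scaled by wgSize
      const partials: Array<[number, number]> = [];
      for (let start = 0; start < h; start += wgSize) {
        let sum = 0, squaredSum = 0;
        for (let i = start; i < start + wgSize; i++) {
          sum += x[i];
          squaredSum += x[i] * x[i];
        }
        partials.push([sum / wgSize, squaredSum / wgSize]);
      }
      // stage 2: fold the partials and apply the remaining h / wgSize divisor
      let mean = 0, meanSq = 0;
      for (const [s, q] of partials) {
        mean += s;
        meanSq += q;
      }
      mean /= h / wgSize;
      meanSq /= h / wgSize;
      return [mean, meanSq - mean * mean]; // Var[x] = E[x^2] - E[x]^2
    }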
14 changes: 7 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -102,7 +102,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
if (key.dims[2] !== query.dims[2]) {
throw new Error('Input \'query\' and \'key\' shall have same dim 2 (hidden_size)');
}
- qkvFormat = AttentionQkvFormat.Q_K_V_BSNH;
+ qkvFormat = AttentionQkvFormat.qkvBSNH;
kvSequenceLength = key.dims[1];
} else if (key.dims.length === 5) {
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
@@ -111,14 +111,14 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
if (value) {
throw new Error('Expect \'value\' be none when \'key\' has packed kv format.');
}
- qkvFormat = AttentionQkvFormat.Q_KV_BSNH_BSN2H;
+ qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
kvSequenceLength = key.dims[1];
} else { // key_dims.size() == 4 (cross-attention with past_key)
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
throw new Error('Expect \'key\' shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
}

- qkvFormat = AttentionQkvFormat.UNKNOWN;
+ qkvFormat = AttentionQkvFormat.unknown;
kvSequenceLength = key.dims[2];
}
} else { // packed QKV
@@ -129,7 +129,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
throw new Error('Expect \'query\' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
}

- qkvFormat = AttentionQkvFormat.QKV_BSN3H;
+ qkvFormat = AttentionQkvFormat.qkvBSN3H;
}

if (bias) {
@@ -144,15 +144,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
}
}

- let maskType: AttentionMaskType = AttentionMaskType.MASK_NONE;
+ let maskType: AttentionMaskType = AttentionMaskType.none;
if (keyPaddingMask) {
maskType = AttentionMaskType.MASK_UNKNOWN;
const maskDims = keyPaddingMask.dims;
if (maskDims.length === 1) {
if (maskDims[0] === batchSize) {
- maskType = AttentionMaskType.MASK_1D_KEY_SEQ_LEN;
+ maskType = AttentionMaskType.mask1dKeySeqLen;
} else if (maskDims[0] === 3 * batchSize + 2) {
- maskType = AttentionMaskType.MASK_1D_KEY_SEQ_LEN_START
+ maskType = AttentionMaskType.mask1DKeySeqLenStart
}
} else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) {
maskType = AttentionMaskType.MASK_2D_KEY_PADDING;
1 change: 0 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -67,7 +67,6 @@ export class ProgramManager {
// usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
// });
//
- //
// const commandEncoder = this.backend.getCommandEncoder();
// commandEncoder?.copyBufferToBuffer(
// output.buffer,
17 changes: 0 additions & 17 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -327,23 +327,6 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
});
}

std::cout << "Prepare completed.";
std::cout << "First 10 values at Q: ";
for (size_t i = 0; i < qkv_head_size[0] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << Q[i] << " ";
}
std::cout << std::endl;
std::cout << "First 10 values at K: ";
for (size_t i = 0; i < qkv_head_size[1] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << K[i] << " ";
}
std::cout << std::endl;
std::cout << "First 10 values at V: ";
for (size_t i = 0; i < qkv_head_size[2] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << V[i] << " ";
}
std::cout << std::endl;

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
