Attention WIP: temp old softmax code, fixed transpose
dakenf committed Sep 27, 2023
1 parent 070f06a commit acc200e
Showing 4 changed files with 121 additions and 98 deletions.
122 changes: 58 additions & 64 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -15,7 +15,7 @@ import {
sumVector,
tensorTypeToWsglStorageType
} from './common';
-import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose';
+import {transposeProgramMetadata} from './transpose';

export enum AttentionQkvFormat {
unknown, // enum value not set, or depends on qkv projection implementation details
@@ -211,74 +211,44 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs
export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
createAttributeWithCacheKey({...attributes});

-const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});

export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, N: number, D: number) => {
-const components = getMaxComponents(D);
-const inputHelper = outputVariable('x', input.dataType, input.dims, components);
-
-let threadMaxValue = 'threadMaxVector';
-if (components === 2) {
-threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
-} else if (components === 4) {
-threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
-}
const dataType = tensorTypeToWsglStorageType(input.dataType);
-let WG = 64;
-const dComp = D / components;
-if (dComp < WG) {
-WG = 1;
-} else if (dComp / 8 < 64) {
-WG = Math.ceil(dComp / 8);
-}
-const elementsPerWG = Math.ceil(D / components / WG);

-// 6.2.4 in wgsl spec
-const threadMaxMinValue = dataType === 'f32' ? '-3.402823e+38f' : '-65504.0h';
const getShaderSource = (shaderHelper: ShaderHelper) => `
const dInv: ${dataType} = 1 / ${D};
-const dComp = ${D / components};
-var<workgroup> wgMax: array<${dataType}, ${WG}>;
-var<workgroup> wgSum: array<${dataType}, ${WG}>;
-${shaderHelper.declareVariables(inputHelper)}
-@compute @workgroup_size(${WG}, 1, 1)
-fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
-@builtin(local_invocation_index) local_index : u32) {
-let localOffset = local_index * ${elementsPerWG};
-let offset: u32 = workgroup_id.x * dComp + localOffset;
-var threadMaxVector = ${fillVector(dataType, components, threadMaxMinValue)};
-for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-threadMaxVector = max(x[offset + i], threadMaxVector);
-}
-wgMax[local_index] = ${threadMaxValue};
-workgroupBarrier();
+@group(0) @binding(0) var<storage, read_write> x: array<${dataType}>;
+@group(0) @binding(1) var<storage, read_write> x2: array<${dataType}>;
-var maxValue = ${threadMaxMinValue};
-for (var i = 0u; i < ${WG}; i++) {
-maxValue = max(wgMax[i], maxValue);
+${shaderHelper.mainStart()}
+if (global_idx >= ${N}) {
+return;
+}
+let offset: u32 = global_idx * ${D};
-var sumVector = ${fillVector(dataType, components, '0')};
-for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-sumVector += exp(x[offset + i] - maxValue);
+var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
+for (var i: u32 = 0; i < ${D}; i++) {
+threadMax = max(f32(x[offset + i]), threadMax);
+}
-wgSum[local_index] = ${sumVector('sumVector', components)};
-workgroupBarrier();
-var sum: ${dataType} = 0;
-for (var i = 0u; i < ${WG}; i++) {
-sum += wgSum[i];
+var sum: f32 = 0.0;
+for (var i: u32 = 0; i < ${D}; i++) {
+let val: f32 = exp(f32(x[offset + i]) - threadMax);
+// x[offset + i] = ${dataType}(val);
+sum += val;
}
+// for (var i: u32 = 0; i < ${D}; i++) {
+// sum += x[offset + i];
+// }
if (sum == 0) {
-for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-x[offset + i] = ${fillVector(dataType, components, 'dInv')};
+for (var i: u32 = 0; i < ${D}; i++) {
+x[offset + i] = dInv;
+x2[offset + i] = dInv;
}
} else {
-for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
-x[offset + i] = exp(x[offset + i] - maxValue) / sum;
+for (var i: u32 = 0; i < ${D}; i++) {
+x[offset + i] = ${dataType}(exp(f32(x[offset + i]) - threadMax) / sum);
+x2[offset + i] = x[offset + i];
}
}
}`;
@@ -288,11 +258,13 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
name: 'computeAttentionProbsSoftmax',
cacheHint: '0',
inputTypes: [GpuDataType.default],
-outputs: [],
+outputs: [
+{dims: input.dims, dataType: DataType.float, gpuDataType: GpuDataType.default}
+],
getShaderSource,
-dispatchGroup: () => ({x: N})
+dispatchGroup: () => ({x: Math.ceil(N / 64)})
},
-{inputs: [input], outputs: []});
+{inputs: [input], outputs: [-1]});
};
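
For reference, the per-row computation the "temp old" shader above performs is the standard numerically stable softmax: subtract the row maximum before exponentiating, normalize by the sum, and fall back to a uniform 1/D distribution when the sum is zero. A minimal CPU-side sketch in TypeScript (a hypothetical helper, not part of this commit):

const softmaxRowsInPlace = (x: Float32Array, N: number, D: number): void => {
  for (let row = 0; row < N; row++) {
    const offset = row * D;
    // Row max keeps exp() from overflowing for large logits.
    let max = -Infinity;
    for (let i = 0; i < D; i++) {
      max = Math.max(max, x[offset + i]);
    }
    let sum = 0;
    for (let i = 0; i < D; i++) {
      sum += Math.exp(x[offset + i] - max);
    }
    if (sum === 0) {
      // Degenerate row (e.g. fully masked): emit a uniform distribution.
      x.fill(1 / D, offset, offset + D);
    } else {
      for (let i = 0; i < D; i++) {
        x[offset + i] = Math.exp(x[offset + i] - max) / sum;
      }
    }
  }
};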

const computeAttentionProbs =
@@ -409,11 +381,11 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView

const dataType = tensorTypeToWsglStorageType(probs.dataType);

-const TILE_SIZE = 8;
+const TILE_SIZE = 1;
const dispatch = {
x: Math.ceil(params.vHeadSize / TILE_SIZE),
-y: Math.ceil(params.sequenceLength / TILE_SIZE),
-z: params.batchSize * params.numHeads
+y: Math.ceil(params.totalSequenceLength / TILE_SIZE),
+z: params.batchSize * params.numHeads,
};

const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -483,13 +455,35 @@ export const applyAttention

const attentionResult = computeVxAttentionScore(context, probs, v, parameters);

+const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize];
+const input = inputVariable('input', q.dataType, attentionResult.dims);
+const output = outputVariable('output', q.dataType, outputShape);
+const getShaderSource = (shaderHelper: ShaderHelper) => `
+${shaderHelper.declareVariables(input, output)}
+${shaderHelper.mainStart(parameters.numHeads * parameters.batchSize)}
+let batchIndex = global_idx / ${parameters.numHeads};
+let headIndex = global_idx % ${parameters.numHeads};
+// let in = input[0];
+var inputOffset = ${parameters.sequenceLength * parameters.vHeadSize} * global_idx;
+var outputOffset = (batchIndex * ${parameters.sequenceLength * parameters.numHeads} + headIndex) * ${parameters.vHeadSize};
+for (var j = 0; j < ${parameters.sequenceLength}; j++) {
+for (var i: u32 = 0; i < ${parameters.vHeadSize}; i++) {
+output[outputOffset + i] = input[inputOffset + i];
+}
+inputOffset += ${parameters.vHeadSize};
+outputOffset += ${parameters.vHiddenSize};
+}
+}`;

context.compute(
{
...transposeProgramMetadata,
-cacheHint: JSON.stringify(parameters) + JSON.stringify(attributes),
-get: () => createTransposeProgramInfo(
-attentionResult, weightTransposeAttribute.perm,
-[parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize])
+cacheHint: JSON.stringify(parameters),
+outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
+getShaderSource,
+dispatchGroup: () => ({ x: 1 }),
},
{inputs: [attentionResult], outputs: [0]});
};
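
The inline shader above replaces the earlier transpose program: each invocation handles one (batch, head) pair and copies its [sequenceLength, vHeadSize] slice from the BNSH-laid-out attention result into the final [batchSize, sequenceLength, vHiddenSize] output, where vHiddenSize = numHeads * vHeadSize. A hypothetical CPU sketch of the same index arithmetic (not part of this commit):

const gatherHeads = (
    input: Float32Array, batchSize: number, numHeads: number,
    sequenceLength: number, vHeadSize: number): Float32Array => {
  const vHiddenSize = numHeads * vHeadSize;
  const output = new Float32Array(batchSize * sequenceLength * vHiddenSize);
  for (let b = 0; b < batchSize; b++) {
    for (let n = 0; n < numHeads; n++) {
      // Matches the shader: global_idx = b * numHeads + n.
      let inputOffset = (b * numHeads + n) * sequenceLength * vHeadSize;
      let outputOffset = (b * sequenceLength * numHeads + n) * vHeadSize;
      for (let s = 0; s < sequenceLength; s++) {
        for (let i = 0; i < vHeadSize; i++) {
          output[outputOffset + i] = input[inputOffset + i];
        }
        inputOffset += vHeadSize;    // next sequence row within this head
        outputOffset += vHiddenSize; // same head slot, next sequence position
      }
    }
  }
  return output;
};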
68 changes: 34 additions & 34 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -57,40 +57,40 @@ export class ProgramManager {

this.backend.pendingDispatchNumber++;

-// this.backend.endComputePass();
-// const kernelId = this.backend.currentKernelId!;
-// const kernelName = this.backend.kernels.get(kernelId)![0];
-// for (const output of outputs) {
-// const stagingBuffer = this.backend.device.createBuffer({
-// size: output.buffer.size,
-// // eslint-disable-next-line no-bitwise
-// usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
-// });
-//
-// const commandEncoder = this.backend.getCommandEncoder();
-// commandEncoder?.copyBufferToBuffer(
-// output.buffer,
-// 0, // Source offset
-// stagingBuffer,
-// 0, // Destination offset
-// output.buffer.size,
-// );
-// this.backend.flush();
-//
-// stagingBuffer
-// .mapAsync(
-// GPUMapMode.READ,
-// 0, // Offset
-// output.buffer.size,
-// )
-// .then(() => {
-// const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size);
-// const data = copyArrayBuffer.slice(0);
-// stagingBuffer.unmap();
-// console.log(`${kernelId}|${kernelName}:`);
-// console.log(new Float32Array(data));
-// });
-// }
+this.backend.endComputePass();
+const kernelId = this.backend.currentKernelId!;
+const kernelName = this.backend.kernels.get(kernelId)![0];
+for (const output of outputs) {
+const stagingBuffer = this.backend.device.createBuffer({
+size: output.buffer.size,
+// eslint-disable-next-line no-bitwise
+usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
+});
+
+const commandEncoder = this.backend.getCommandEncoder();
+commandEncoder?.copyBufferToBuffer(
+output.buffer,
+0, // Source offset
+stagingBuffer,
+0, // Destination offset
+output.buffer.size,
+);
+this.backend.flush();
+
+stagingBuffer
+.mapAsync(
+GPUMapMode.READ,
+0, // Offset
+output.buffer.size,
+)
+.then(() => {
+const copyArrayBuffer = stagingBuffer.getMappedRange(0, output.buffer.size);
+const data = copyArrayBuffer.slice(0);
+stagingBuffer.unmap();
+console.log(`${kernelId}|${kernelName}:`);
+console.log(new Float32Array(data));
+});
+}

if (profilingEnabled) {
// profiling write end timestamp
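
The block enabled above is a debugging aid: after every dispatch it copies each output into a MAP_READ staging buffer, flushes the queue, and logs the contents once the map resolves. A standalone sketch of the same WebGPU readback pattern (a hypothetical helper, not part of this commit):

async function readBackBuffer(device: GPUDevice, source: GPUBuffer): Promise<Float32Array> {
  // Staging buffer the CPU is allowed to map for reading.
  const staging = device.createBuffer({
    size: source.size,
    // eslint-disable-next-line no-bitwise
    usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
  });
  const encoder = device.createCommandEncoder();
  encoder.copyBufferToBuffer(source, 0, staging, 0, source.size);
  device.queue.submit([encoder.finish()]);

  await staging.mapAsync(GPUMapMode.READ);
  // Copy the data out before unmap() invalidates the mapped range.
  const data = staging.getMappedRange().slice(0);
  staging.unmap();
  staging.destroy();
  return new Float32Array(data);
}

Mapping waits for the GPU queue to drain, so this pattern suits debugging rather than hot dispatch paths.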
17 changes: 17 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -327,6 +327,23 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
});
}

std::cout << "Prepare completed.";
std::cout << "First 10 values at Q: ";
for (size_t i = 0; i < qkv_head_size[0] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << Q[i] << " ";
}
std::cout << std::endl;
std::cout << "First 10 values at K: ";
for (size_t i = 0; i < qkv_head_size[1] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << K[i] << " ";
}
std::cout << std::endl;
std::cout << "First 10 values at V: ";
for (size_t i = 0; i < qkv_head_size[2] * sequence_length * batch_size * num_heads_; ++i) {
std::cout << V[i] << " ";
}
std::cout << std::endl;

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -200,12 +200,24 @@ class AttentionCPUBase : public AttentionBase {
});
}

std::cout << "Probs before softmax.";
for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
std::cout << attention_probs[i] << " ";
}
std::cout << std::endl;

// attention_probs(B, N, S, T) = Softmax(attention_probs)
{
const int N = batch_size * num_heads_ * sequence_length;
const int D = total_sequence_length;
ComputeAttentionSoftmaxInplace(attention_probs, N, D, tp);
}

std::cout << "Probs after softmax.";
for (size_t i = 0; i < total_sequence_length * sequence_length * batch_size * num_heads_; ++i) {
std::cout << attention_probs[i] << " ";
}
std::cout << std::endl;
}

template <typename T>
