Merge branch 'main' of https://github.com/microsoft/onnxruntime into origin/user/pavignol/add-bfc-allocator-4
Patrice Vignola committed Aug 16, 2023
2 parents e6ae058 + 8998b68 commit b7e40e8
Showing 36 changed files with 6,089 additions and 6,685 deletions.
12 changes: 6 additions & 6 deletions js/node/package-lock.json

Some generated files are not rendered by default.

js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
@@ -159,7 +159,7 @@ const createConvTranspose2DOpProgramShaderSource =
 }
 for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
-${output.set('batch', 'r', 'c+i', 'd1', 'dotProd[i]')}
+${output.set('batch', 'r', 'c+i', 'd1', 'dotProd[i]')};
 }
 }`;
 const codeSnippet = `
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -118,7 +118,7 @@ const createConcatProgramInfo =
 ${shaderHelper.mainStart()}
 ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
-let indices = ${output.offsetToIndices('global_idx')};
+var indices = ${output.offsetToIndices('global_idx')};
 let inputIndex = calculateInputIndex(${indicesAxis});
 if (inputIndex != 0u) {
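Why 'var' here: WGSL 'let' bindings are immutable, so the generated shader can only reassign indices if it is declared with 'var'. A minimal TypeScript sketch of the kind of WGSL this template produces (the mutation line is illustrative, not taken from concat.ts):

// Hypothetical fragment of generated WGSL, held in a TypeScript template string.
const generatedWgsl = `
  var indices = offset_to_indices(global_idx);  // mutable binding
  indices[1] = indices[1] - 4u;                 // reassignment; a 'let' binding would reject this
`;
console.log(generatedWgsl);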
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -100,7 +100,7 @@ export const gather = (context: ComputeContext, attributes: GatherAttributes): v
 const metadata = {
 name: 'Gather',
 inputTypes: [GpuDataType.default, GpuDataType.default],
-cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10),
+cacheHint: attributes.cacheKey,
 };
 
 context.compute(createGatherProgramInfo(metadata, context.inputs, attributes));
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -190,7 +190,7 @@ export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes)
 const idxZero = [];
 for (let k = 0; k < input.shape.length; k++) {
 if (axes.indexOf(k) >= 0 || axes.length === 0) {
-idxZero.push(`inputIndices[${k}] = 0;`); // first element
+idxZero.push(input.indicesSet('inputIndices', k, 0));
 }
 }
 
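The change swaps a hand-built WGSL string for the shared indices helper, which generates an equivalent assignment. A hypothetical sketch of such a helper, assuming an indicesSet(variable, index, value) signature (the real IndicesHelper in jsep also handles rank-0/rank-1 indices and other cases omitted here):

// Sketch only: emits a WGSL assignment statement as a string, like the literal it replaces.
const indicesSet = (variable: string, index: number, value: number): string =>
    `${variable}[${index}] = ${value};`;

console.log(indicesSet('inputIndices', 2, 0));  // inputIndices[2] = 0;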
6 changes: 3 additions & 3 deletions js/web/script/test-runner-cli-args.ts
@@ -66,7 +66,7 @@ Options:
 *** Backend Options ***
---wasm-number-threads Set the WebAssembly number of threads
+-x, --wasm-number-threads Set the WebAssembly number of threads
 --wasm-init-timeout Set the timeout for WebAssembly backend initialization, in milliseconds
 --wasm-enable-simd Set whether to enable SIMD
 --wasm-enable-proxy Set whether to enable proxy worker
@@ -264,9 +264,9 @@ function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssem
 }
 
 function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags {
-const numThreads = args['wasm-number-threads'];
+const numThreads = args.x || args['wasm-number-threads'];
 if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') {
-throw new Error('Flag "wasm-number-threads" must be a number value');
+throw new Error('Flag "x"/"wasm-number-threads" must be a number value');
 }
 const initTimeout = args['wasm-init-timeout'];
 if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') {
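For reference, the new '-x' short flag resolves through minimist's ordinary parsing; a hedged sketch of the behavior (the command line is a made-up example):

import minimist from 'minimist';

const args = minimist(['-x', '4']);                        // e.g. invoking the runner with '-x 4'
const numThreads = args.x || args['wasm-number-threads'];  // short flag takes precedence when set
console.log(numThreads);                                   // 4 — minimist coerces bare numeric strings

One subtlety of '||': an explicit '-x 0' is falsy and falls through to the long flag, while the typeof guard that follows still rejects any non-numeric value.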
97 changes: 96 additions & 1 deletion js/web/test/data/ops/reshape.jsonc
@@ -1,6 +1,6 @@
 [
 {
-"name": "Reshape with '0' and '-1' in the shape tensor input",
+"name": "Reshape float32 with '0' and '-1' in the shape tensor input",
 "operator": "Reshape",
 "attributes": [],
 "cases": [
@@ -93,5 +93,100 @@
 ]
 }
 ]
-}
+},
+{
+"name": "Reshape int32 with '0' and '-1' in the shape tensor input",
+"operator": "Reshape",
+"attributes": [],
+"cases": [
+{
+"name": "T[0]",
+"inputs": [
+{
+"data": [1, 1, 1, 1],
+"dims": [2, 2],
+"type": "int32"
+},
+{
+"data": [-1, 0],
+"dims": [2],
+"type": "int64"
+}
+],
+"outputs": [
+{
+"data": [1, 1, 1, 1],
+"dims": [2, 2],
+"type": "int32"
+}
+]
+},
+{
+"name": "Scalar to 1D",
+"inputs": [
+{
+"data": [1],
+"dims": [],
+"type": "int32"
+},
+{
+"data": [1],
+"dims": [1],
+"type": "int64"
+}
+],
+"outputs": [
+{
+"data": [1],
+"dims": [1],
+"type": "int32"
+}
+]
+},
+{
+"name": "Scalar to 2D",
+"inputs": [
+{
+"data": [1],
+"dims": [],
+"type": "int32"
+},
+{
+"data": [1, 1],
+"dims": [2],
+"type": "int64"
+}
+],
+"outputs": [
+{
+"data": [1],
+"dims": [1, 1],
+"type": "int32"
+}
+]
+},
+{
+"name": "Scalar to 2D with -1 in shape hints",
+"inputs": [
+{
+"data": [1],
+"dims": [],
+"type": "int32"
+},
+{
+"data": [-1, 1],
+"dims": [2],
+"type": "int64"
+}
+],
+"outputs": [
+{
+"data": [1],
+"dims": [1, 1],
+"type": "int32"
+}
+]
+}
+]
+}
 ]
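The new cases exercise Reshape's rules for 0 and -1 in the shape tensor: with the default allowzero behavior, a 0 copies the corresponding input dimension and a single -1 is inferred from the remaining element count. A TypeScript sketch of that rule (the ONNX-spec'd behavior, not ORT's actual implementation):

const resolveReshape = (inputDims: number[], shape: number[]): number[] => {
  const out = shape.map((d, i) => (d === 0 ? inputDims[i] : d));     // 0 => copy input dim
  const total = inputDims.reduce((a, b) => a * b, 1);                // a scalar has 1 element
  const known = out.filter(d => d !== -1).reduce((a, b) => a * b, 1);
  return out.map(d => (d === -1 ? total / known : d));               // -1 => inferred dim
};

console.log(resolveReshape([2, 2], [-1, 0]));  // [2, 2] — the "T[0]" case above
console.log(resolveReshape([], [-1, 1]));      // [1, 1] — the "Scalar to 2D with -1 in shape hints" case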
3 changes: 2 additions & 1 deletion js/web/test/test-main.ts
@@ -137,7 +137,8 @@ for (const group of ORT_WEB_TEST_CONFIG.op) {
 let context: ProtoOpTestContext|OpTestContext;
 
 before('Initialize Context', async () => {
-context = useProtoOpTest ? new ProtoOpTestContext(test) : new OpTestContext(test);
+context = useProtoOpTest ? new ProtoOpTestContext(test, ORT_WEB_TEST_CONFIG.options.sessionOptions) :
+new OpTestContext(test);
 await context.init();
 if (ORT_WEB_TEST_CONFIG.profile) {
 if (context instanceof ProtoOpTestContext) {
5 changes: 3 additions & 2 deletions js/web/test/test-runner.ts
@@ -574,7 +574,7 @@ export class ProtoOpTestContext {
 private readonly loadedData: Uint8Array; // model data, inputs, outputs
 session: ort.InferenceSession;
 readonly backendHint: string;
-constructor(test: Test.OperatorTest) {
+constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) {
 const opsetImport = onnx.OperatorSetIdProto.create(test.opset);
 const operator = test.operator;
 const attribute = (test.attributes || []).map(attr => {
@@ -714,7 +714,8 @@ export class ProtoOpTestContext {
 }
 }
 async init(): Promise<void> {
-this.session = await ort.InferenceSession.create(this.loadedData, {executionProviders: [this.backendHint]});
+this.session = await ort.InferenceSession.create(
+this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions});
 }
 
 async dispose(): Promise<void> {
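The spread order in init() follows plain object-spread semantics: later entries win, so options supplied through the test config can override the executionProviders default. A small sketch with illustrative option values (not taken from the diff):

const backendHint = 'webgpu';
const sessionOptions = {executionProviders: ['wasm'], graphOptimizationLevel: 'disabled'};

const merged = {executionProviders: [backendHint], ...sessionOptions};
console.log(merged);  // { executionProviders: ['wasm'], graphOptimizationLevel: 'disabled' }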
32 changes: 16 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -336,6 +336,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 }
 // attention with past/present state
 else if (data.past_key != nullptr || data.present_key != nullptr) {
+// Below logic does not support memory efficient attention with past (like pass_past_in_kv) but without bias
 if (data.bias == nullptr) {
 // cross attention with past state
 if (data.past_key != nullptr && data.present_key == nullptr) {
@@ -344,7 +345,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 assert(data.key == nullptr);
 assert(data.value == nullptr);
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.query, q));
+max_threads_per_block, false, data.query, q));
 }
 // cross attention with present state or self attention with present state
 else if (data.past_key == nullptr && data.present_key != nullptr) {
@@ -356,13 +357,13 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 
 // TODO: supporting packed qkv for self attention may benefit performance
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.query, q));
+max_threads_per_block, false, data.query, q));
 
 // TODO: supporting packed kv for cross attention may benefit performance
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.key, data.present_key));
+max_threads_per_block, false, data.key, data.present_key));
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-max_threads_per_block, false, data.value, data.present_value));
+max_threads_per_block, false, data.value, data.present_value));
 }
 // self attention with past and present state
 else {
@@ -375,11 +376,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 assert(data.value != nullptr);
 // TODO: supporting packed qkv for self attention may benefit performance
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.query, q));
+max_threads_per_block, false, data.query, q));
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.key, k));
+max_threads_per_block, false, data.key, k));
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-max_threads_per_block, false, data.value, v));
+max_threads_per_block, false, data.value, v));
 }
 qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
 }
@@ -397,9 +398,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 
 // query => q, temp_k_workspace => k, temp_v_workspace => v
 LaunchAddBias(stream, max_threads_per_block,
-batch_size, sequence_length, kv_sequence_length,
-num_heads, qk_head_size, v_head_size,
-data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
+batch_size, sequence_length, kv_sequence_length,
+num_heads, qk_head_size, v_head_size,
+data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
 
 DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
 DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
@@ -419,11 +420,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 
 // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-max_threads_per_block, false, data.temp_k_workspace, data.present_key));
+max_threads_per_block, false, data.temp_k_workspace, data.present_key));
 
 // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
 ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-max_threads_per_block, false, data.temp_v_workspace, data.present_value));
+max_threads_per_block, false, data.temp_v_workspace, data.present_value));
 
 DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
 DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size * kv_sequence_length, num_heads, qk_head_size);
@@ -688,8 +689,7 @@ Status QkvToContext(
 if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
 k = data.present_key;
 v = data.present_value;
-}
-else {
+} else {
 assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
 k = data.temp_k_workspace;
 v = data.temp_v_workspace;
@@ -1111,12 +1111,12 @@ Status DecoderQkvToContext(
 constexpr int max_sequence_length = 0;
 ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(ort_stream, kv_sequence_length, sequence_length, batch_size,
 num_heads, nullptr, key_padding_mask, add_before_softmax,
-false/*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
+false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
 1.0f, mask_dimension, max_sequence_length, false, nullptr,
 mask_filter_value));
 } else {
 ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, kv_sequence_length, sequence_length, batch_size, num_heads,
-add_before_softmax, false/*broadcast rpb*/, scratch1, scratch2,
+add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
 is_unidirectional));
 }
 
7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -166,17 +166,23 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
 }
 }
 
+const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value);
+
 #if USE_FLASH_ATTENTION
 bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
 parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
 parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;
 
+// Exclude this case since PrepareQkv will convert the format to BNSH.
+bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;
+
 bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
 
 bool use_memory_efficient_attention = fused_runner == nullptr &&
 fused_cross_attention_kernel == nullptr &&
 !disable_memory_efficient_attention_ &&
 is_long_sequence &&
+!past_no_bias &&
 (relative_position_bias == nullptr || is_good_for_rpb) &&
 (nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) &&
 has_memory_efficient_attention(sm, sizeof(T) == 2);
@@ -226,7 +232,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
 data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
 data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
 data.past = nullptr;
-const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value);
 data.past_key = pass_key_value_as_past ? reinterpret_cast<const CudaT*>(key->Data<T>())
 : (nullptr == past_key) ? nullptr
 : reinterpret_cast<const CudaT*>(past_key->Data<T>());
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/js/js_data_types.cc
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cpu/tensor/shape_op.h"
+
+namespace onnxruntime {
+namespace js {
+
+using SupportedTypes =
+TypeList<
+float,
+int32_t,
+uint32_t>;
+
+const std::vector<MLDataType>& JsepSupportedDataTypes() {
+static const std::vector<MLDataType> supportedDataTypes = BuildKernelDefConstraintsFromTypeList<SupportedTypes>();
+return supportedDataTypes;
+}
+} // namespace js
+} // namespace onnxruntime
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/js/js_data_types.h
@@ -0,0 +1,10 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/framework/data_types.h"
+
+namespace onnxruntime {
+namespace js {
+std::vector<MLDataType>& JsepSupportedDataTypes();
+}
+} // namespace onnxruntime