diff --git a/CITATION.cff b/CITATION.cff
index 82bcac5a7b750..10b7290022aef 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -3,8 +3,7 @@ title: ONNX Runtime
message: "Please use this information to cite ONNX Runtime in
research or other publications."
authors:
- - affiliation: Microsoft Corporation
- given-names: ONNX Runtime developers
+ - name: ONNX Runtime developers
date-released: 2018-11-29
url: "https://onnxruntime.ai"
repository-code: "https://github.com/microsoft/onnxruntime"
diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake
index a56864ebf4644..8161ea574b8cc 100644
--- a/cmake/adjust_global_compile_flags.cmake
+++ b/cmake/adjust_global_compile_flags.cmake
@@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()
-# Enable stream for all the non-minimal build, except for DML. There's currently a bug
-# in the allocation planner when reusing buffers and more than one streams are used that
-# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
-# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
-# safest option for now.
-# https://github.com/microsoft/onnxruntime/issues/19480
-if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
+# Enable stream for all non-minimal builds
+if (NOT onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 85a9bf50460d3..1bb70e9c2ed27 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
"bert/fastertransformer_decoder_attention/*"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
- "bert/fast_gelu_impl.cu"
- "bert/fast_gelu_impl.h"
- "bert/fast_gelu.cc"
- "bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index 715aed7e1d64f..7f3d5d6624b07 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
-
+
int deviceId = 0;
string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8ff2135c6b1f6..b0ed68d595c42 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -127,6 +127,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
@@ -606,6 +607,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
@@ -617,6 +619,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
+|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
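The Gelu rows added above refer to the ONNX opset-20 operator with its default (erf-based) formulation. As a point of orientation only (this is not the registered kernel's code), the per-element math is:

```cpp
#include <cmath>

// Reference per-element Gelu (exact, erf-based form used when approximate="none"):
//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
float GeluReference(float x) {
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}
```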
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index e5a2804894148..e79ce1c10d2d9 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -33,6 +33,8 @@ class Node;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"
+struct OrtRunOptions;
+
namespace onnxruntime {
/**
@@ -51,6 +53,8 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};
+using RunOptions = OrtRunOptions;
+
enum class DataLayout {
NCHW,
NHWC,
@@ -189,7 +193,7 @@ class IExecutionProvider {
Run may not be finished on device This function should be regarded as the
point after which a new Run would start to submit commands from CPU
*/
- virtual common::Status OnRunStart() { return Status::OK(); }
+ virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }
/**
Called when InferenceSession::Run ended
@@ -197,7 +201,9 @@ class IExecutionProvider {
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
- virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
+ virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
+ return Status::OK();
+ }
/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
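For provider implementers, `OnRunStart`/`OnRunEnd` now receive the per-run options. A minimal sketch of a hypothetical provider overriding both hooks; the class name and body are illustrative, only the signatures come from the header above:

```cpp
#include "core/framework/execution_provider.h"

// Hypothetical provider used only to illustrate the new hook signatures.
class MyExecutionProvider : public onnxruntime::IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider{"MyExecutionProvider"} {}

  onnxruntime::common::Status OnRunStart(const onnxruntime::RunOptions& run_options) override {
    // Per-run configuration entries set by the caller are now visible here.
    (void)run_options;
    return onnxruntime::common::Status::OK();
  }

  onnxruntime::common::Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) override {
    if (sync_stream) {
      // Synchronize the provider's compute stream here, if it owns one.
    }
    return onnxruntime::common::Status::OK();
  }
};
```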
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index 1fef077860be3..00e7dec5727d1 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -19,4 +19,4 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
-};
\ No newline at end of file
+};
diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
index ebd573dca893c..6ada9702031a8 100644
--- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
@@ -31,5 +31,17 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";
+// Set HTP performance mode for QNN HTP backend before session run.
+// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
+// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
+// "sustained_high_performance". Default to "default".
+static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";
+
+// Set HTP performance mode for QNN HTP backend post session run.
+static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";
+
+// Set RPC control latency for QNN HTP backend
+static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
+
// TODO
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "ep.cuda.cuda_graph_annotation";
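These keys are consumed per Run() call. A minimal C++ sketch of requesting burst mode for one run and dropping back to a power-saving mode afterwards; session setup is omitted and the session is assumed to already target the QNN execution provider:

```cpp
#include <vector>

#include "onnxruntime_cxx_api.h"
#include "onnxruntime_run_options_config_keys.h"

// Sketch: per-run HTP performance mode for a session that uses the QNN EP.
std::vector<Ort::Value> RunWithBurst(Ort::Session& session,
                                     const char* const* input_names, const Ort::Value* inputs, size_t input_count,
                                     const char* const* output_names, size_t output_count) {
  Ort::RunOptions run_options;
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, "burst");
  run_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, "low_power_saver");
  return session.Run(run_options, input_names, inputs, input_count, output_names, output_count);
}
```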
diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts
index e8eb0e9babf5a..927953b4f1dd6 100644
--- a/js/node/lib/backend.ts
+++ b/js/node/lib/backend.ts
@@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise<SessionHandler.ReturnType> {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(this.#inferenceSession.run(feeds, fetches, options));
} catch (e) {
@@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend {
async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler> {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {}));
} catch (e) {
diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock
index 9e20a286c4e27..6f05faf046098 100644
--- a/js/react_native/e2e/yarn.lock
+++ b/js/react_native/e2e/yarn.lock
@@ -3351,9 +3351,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"
ip@^1.1.5:
- version "1.1.8"
- resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
- integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+ version "1.1.9"
+ resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+ integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
is-accessor-descriptor@^0.1.6:
version "0.1.6"
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 98990a6fe477b..3e3a191ec3ead 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -385,11 +385,16 @@ export class WebGpuBackend {
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
- const gpuData = this.gpuDataManager.get(inputTensorViews[i].data);
+ const data = inputTensorViews[i].data;
+ // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
+ if (data === 0) {
+ continue;
+ }
+ const gpuData = this.gpuDataManager.get(data);
if (!gpuData) {
- throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`);
+ throw new Error(`no GPU data for input: ${data}`);
}
- inputDatas[i] = gpuData;
+ inputDatas.push(gpuData);
}
const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
@@ -419,6 +424,11 @@ export class WebGpuBackend {
const tensorView = (isTemporary || isPersistent) ?
createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
+ outputTensorViews.push(tensorView);
+ // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it.
+ if (tensorView.data === 0) {
+ continue;
+ }
const gpuData = this.gpuDataManager.get(tensorView.data);
if (!gpuData) {
throw new Error(`no GPU data for output: ${tensorView.data}`);
@@ -434,10 +444,24 @@ export class WebGpuBackend {
}
persistentData.push(gpuData);
}
- outputTensorViews.push(tensorView);
outputDatas.push(gpuData);
}
+  // when there are any zero-sized tensors in the inputs or outputs, we should report an error unless all outputs are
+  // zero-sized tensors.
+ if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
+ // if all outputs are zero-sized tensors, there is no need to run the program.
+ if (outputDatas.length === 0) {
+ TRACE_FUNC_END(program.name);
+ return outputTensorViews;
+ }
+ // if some outputs are zero-sized tensors, report an error.
+ //
+  // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors.
+  // If we see such a use case, we need to make a change here to support it.
+ throw new Error(
+ `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
+ }
// load uniforms
// TODO: add cache for uniform (is it necessary?)
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 786ae41646554..b64abf9cc5424 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext {
throw new Error(`Unsupported data type: ${dataType}`);
}
const bufferSize = elementSize * ShapeUtil.size(dims);
- return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
+ const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
+ return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
};
return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
}
diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts
index c0517ce363644..9a1d5463f7843 100644
--- a/js/web/lib/wasm/jsep/util.ts
+++ b/js/web/lib/wasm/jsep/util.ts
@@ -56,7 +56,16 @@ export class BroadcastUtil {
if (aLen !== bLen && aLen > 1 && bLen > 1) {
return undefined;
}
- cdims[crank - i] = Math.max(aLen, bLen);
+ const max = Math.max(aLen, bLen);
+ if (aLen && bLen) {
+ cdims[crank - i] = Math.max(aLen, bLen);
+ } else {
+ // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
+ if (max > 1) {
+ return undefined;
+ }
+ cdims[crank - i] = 0;
+ }
}
return cdims;
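The new branch encodes the rule that a zero-length dimension only broadcasts against a dimension of 0 or 1, yielding 0. The same per-dimension rule restated as a small C++ sketch for clarity (not shared library code):

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// Per-dimension broadcast rule, including zero-sized dimensions:
// returns the output length, or nullopt when the pair is not broadcastable.
std::optional<int64_t> BroadcastDim(int64_t a_len, int64_t b_len) {
  if (a_len != b_len && a_len > 1 && b_len > 1) return std::nullopt;
  const int64_t max_len = std::max(a_len, b_len);
  if (a_len != 0 && b_len != 0) return max_len;
  // One side is 0: the other must be 0 or 1, and the result dimension is 0.
  return max_len > 1 ? std::nullopt : std::optional<int64_t>{0};
}
```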
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index b5b6a2a15cd8c..11c8778b72335 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
-import {biasSnippet, typeSnippet} from './activation_util';
+import {biasSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
const conv2dTransposeCommonSnippet =
- (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
- const type = typeSnippet(innerElementSize, 'f32');
+ (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
+ innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
@@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))];
let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))];
let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))];
- return vec4(v0, v1, v2, v3);
+ return ${type}(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
@@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
- fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} {
+ fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
@@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
{name: 'pads', type: 'i32', length: pads.length}
];
appendActivationUniforms(attributes, uniforms);
+ const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
+ if (elemType !== 'f16' && elemType !== 'f32') {
+ throw new Error(`elemType ${elemType} is not supported.`);
+ }
return `
${utilFunctions('uniforms.result_strides')}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
- ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
+ ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
${
isVec4 ? makeMatMulPackedVec4Source(
- elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
+ elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
- elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
+ elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`;
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
index b06c9fb496d15..b142a82e551a7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
- context.compute(createConcatProgramInfo(context.inputs, attributes.axis));
+ // 0 length tensors are valid for concat, remove them
+ const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
+ context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
};
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index 5c31e6dd86c00..d48bb909f7f8f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
if (idx${x} < 0) {
idx${x} = idx${x} + uniforms.axisDimLimit;
}
- var dataIndices${x} = ${data.type.indices}(0);
+ var dataIndices${x} : ${data.type.indices};
`;
for (let i = 0, j = 0; i < inputRank; i++) {
if (i === axis) {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
index cfee07a9239d7..a6375847fc42f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
@@ -27,7 +27,7 @@ const createWhereOpProgramShader =
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
- const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
+ const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
@@ -38,6 +38,7 @@ const createWhereOpProgramShader =
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
+ let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
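The fix accounts for the condition tensor being stored one byte per element while the shader indexes it through a u32 view; the selected byte is `offset_c % 4` rather than a position derived from the output lane. The equivalent scalar lookup, as a C++ sketch:

```cpp
#include <cstddef>
#include <cstdint>

// Read bool element `offset` from data packed one byte per element but viewed
// as 32-bit words, mirroring the shader's component_c indexing.
bool ConditionAt(const uint32_t* c_data, size_t offset) {
  const uint32_t word = c_data[offset / 4];
  const uint32_t component = static_cast<uint32_t>(offset % 4);
  return (word & (0xffu << (component * 8))) != 0;
}
```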
diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc
index e5b4ff2b53148..dd15134861ef0 100644
--- a/js/web/test/data/ops/add.jsonc
+++ b/js/web/test/data/ops/add.jsonc
@@ -157,6 +157,28 @@
"type": "float32"
}
]
+ },
+ {
+ "name": "T[2,0] T[2,1]",
+ "inputs": [
+ {
+ "data": [],
+ "dims": [2, 0],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2],
+ "dims": [2, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [],
+ "dims": [2, 0],
+ "type": "float32"
+ }
+ ]
}
]
}
diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc
index 047fd6fd7511b..990120dd3708e 100644
--- a/js/web/test/data/ops/where.jsonc
+++ b/js/web/test/data/ops/where.jsonc
@@ -168,5 +168,39 @@
]
}
]
+ },
+ {
+ "name": "Where with no attributes",
+ "operator": "Where",
+ "attributes": [],
+ "cases": [
+ {
+ "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
+ "inputs": [
+ {
+ "data": [true, false],
+ "dims": [1, 1, 2, 1],
+ "type": "bool"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 4],
+ "type": "float32"
+ },
+ {
+ "data": [5, 6, 7, 8, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 4, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
}
]
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index b43b1ac37e37d..88555a27be82e 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -1231,7 +1231,7 @@
"test_split_variable_parts_1d",
"test_split_variable_parts_2d",
"test_split_variable_parts_default_axis",
- // // "test_split_zero_size_splits",
+ "test_split_zero_size_splits",
"test_sqrt_example",
"test_sqrt",
"test_squeeze_negative_axes",
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index ecc7d4b4a09a5..7c03e5b915fd7 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -573,7 +573,9 @@ export async function sessionRun(options: {
// replace the CPU tensors in feeds into GPU tensors
for (const name in feeds) {
if (Object.hasOwnProperty.call(feeds, name)) {
- feeds[name] = createGpuTensorForInput(feeds[name]);
+ if (feeds[name].size > 0) {
+ feeds[name] = createGpuTensorForInput(feeds[name]);
+ }
}
}
}
@@ -582,7 +584,11 @@ export async function sessionRun(options: {
for (const name in options.outputsMetaInfo) {
if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) {
const {type, dims} = options.outputsMetaInfo[name];
- fetches[name] = createGpuTensorForOutput(type, dims);
+ if (dims.some(d => d === 0)) {
+ fetches[name] = new ort.Tensor(type, [], dims);
+ } else {
+ fetches[name] = createGpuTensorForOutput(type, dims);
+ }
}
}
}
@@ -627,8 +633,8 @@ export async function runModelTestSet(
try {
const feeds: Record<string, ort.Tensor> = {};
const outputsMetaInfo: Record<string, Pick<ort.Tensor, 'dims'|'type'>> = {};
- testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
- testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
+ testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
+ testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
const [start, end, outputs] =
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {
diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc
index 556699192d2eb..3e0533dd8b9e5 100644
--- a/onnxruntime/contrib_ops/cpu/activations.cc
+++ b/onnxruntime/contrib_ops/cpu/activations.cc
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cpu/activation/activations.h"
-#include "activations.h"
+#include "contrib_ops/cpu/activations.h"
namespace onnxruntime {
namespace contrib {
@@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu);
-ONNX_OPERATOR_KERNEL_EX(
- Gelu,
- kMSDomain,
- 1,
- kCpuExecutionProvider,
- KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
- Gelu);
-
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index aed4c2229215d..7e64235d3fc3d 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -54,47 +54,6 @@ namespace contrib {
DEFINE_ELE_KERNEL(ScaledTanh);
DEFINE_ELE_KERNEL(ParametricSoftplus);
-template <typename T>
-class Gelu : public OpKernel {
- public:
- Gelu(const OpKernelInfo& info) : OpKernel(info) {
- }
-
- Status Compute(OpKernelContext* context) const override {
- const Tensor* input = context->Input<Tensor>(0);
- const T* input_data = input->Data<T>();
-
- Tensor* output = context->Output(0, input->Shape());
- T* output_data = output->MutableData<T>();
-
- concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
- int64_t elem_count = input->Shape().Size();
- constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
- int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
- concurrency::ThreadPool::TryBatchParallelFor(
- tp, static_cast<int32_t>(task_count),
- [&](ptrdiff_t task_idx) {
- const auto start = task_idx * length_per_task;
- const T* p_input = input_data + start;
- T* p_output = output_data + start;
- int64_t count = std::min(length_per_task, elem_count - start);
-
- for (int64_t i = 0; i < count; i++) {
- T value = p_input[i];
- p_output[i] = value * static_cast<T>(M_SQRT1_2);
- }
-
- MlasComputeErf(p_output, p_output, narrow<size_t>(count));
-
- for (int64_t i = 0; i < count; i++) {
- p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
- }
- },
- 0);
- return Status::OK();
- }
-};
-
// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
// MlasComputeLogistic instead of using Eigen for better perf.
template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc
index 1a86c5dbece5a..6303858b9bd48 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.cc
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -49,7 +49,6 @@ namespace cuda {
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
-UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h
index ab339f276c2bd..fc9a71b0b7fa1 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.h
@@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise {
float beta_;
};
-template <typename T>
-class Gelu final : public UnaryElementwise {
- public:
- Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
-
- Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
- MAKE_FUNC_CTX_NULL()
-};
-
template <typename T>
class QuickGelu final : public UnaryElementwise {
public:
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
index 0c856815fd437..36f33fbb24c18 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh {
}
};
-template <typename T>
-struct OP_Gelu : public CtxGelu {
- __device__ __inline__ T operator()(const T& a) const {
- return _Gelu(a);
- }
-};
-
-template <>
-struct OP_Gelu<half> : public CtxGelu {
- __device__ __inline__ half operator()(const half& a) const {
- return static_cast<half>(_Gelu(static_cast<float>(a)));
- }
-};
-
template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
index 5d18283a395e3..782d4bf59a5ad 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -11,14 +11,12 @@ namespace cuda {
typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
-typedef onnxruntime::cuda::CtxNull CtxGelu;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
- UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index 892f5c181a607..e8974a29476b6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -4,9 +4,14 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "fast_gelu.h"
-#include "fast_gelu_impl.h"
+#include "core/providers/cuda/tensor/gelu_impl.h"
#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "transformer_common.h"
+#ifdef USE_ROCM
+#include "contrib_ops/rocm/bert/elementwise.h"
+#endif
+#ifdef USE_CUDA
+#include "contrib_ops/cuda/bert/transformer_common.h"
+#endif
namespace onnxruntime {
namespace contrib {
@@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE;
template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
+#ifdef USE_CUDA
const TransformerOptions* options = TransformerOptions::GetInstance();
use_half2_ = !options->DisableHalf2();
+#endif
}
template <typename T>
@@ -50,6 +57,14 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;
+#ifdef USE_ROCM
+ return LaunchElementwiseKernel<functor::FastGeLU, CudaT>(
+ GetTuningContext(), context->GetComputeStream(),
+ reinterpret_cast<const CudaT*>(input->Data<T>()), static_cast<int>(input_length),
+ (nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, static_cast<int>(bias_length),
+ reinterpret_cast<CudaT*>(output->MutableData<T>()));
+#endif
+#ifdef USE_CUDA
return LaunchFastGeluKernel(GetDeviceProp(),
Stream(context),
static_cast<int>(input_length),
@@ -58,6 +73,7 @@ Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
(nullptr != bias) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<CudaT*>(output->MutableData<T>()),
use_half2_);
+#endif
}
} // namespace cuda
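Both launch paths compute the same tanh-based FastGelu approximation, with the optional bias added to the input first. A scalar reference, for orientation only:

```cpp
#include <cmath>

// Scalar reference for the FastGelu (tanh) approximation:
//   fastgelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
float FastGeluReference(float x) {
  constexpr float kAlpha = 0.7978845608028654f;  // sqrt(2 / pi)
  constexpr float kBeta = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(kAlpha * (x + kBeta * x * x * x)));
}
```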
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
index 3e642a70afef5..d563556593e6e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h
@@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel {
Status ComputeInternal(OpKernelContext* ctx) const override;
private:
- bool use_half2_;
+ bool use_half2_; // Only applicable to CUDA kernel (not ROCM).
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index be8c0dc86c135..57e951d3a68ff 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
+#endif
+
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
@@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample)>,
+#endif
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc
index 4c2999c279e0a..2500de39d3536 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.cc
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc
@@ -9,22 +9,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL_TYPED(T) \
+#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
- kMSDomain, \
- 1, \
+ DOMAIN, \
+ VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- GridSample<T>);
+ onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);
-REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
-template <typename T>
-GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
+template <typename T, bool IsNHWC>
+GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
@@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
}
}
-template <typename T>
-Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool IsNHWC>
+Status GridSample<T, IsNHWC>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input<Tensor>(1);
@@ -61,11 +62,13 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
+ using Ch = Channels;
+
TensorShapeVector dims_output(4);
- dims_output[0] = dims_input[0];
- dims_output[1] = dims_input[1];
- dims_output[2] = dims_grid[1];
- dims_output[3] = dims_grid[2];
+ dims_output[Ch::N] = dims_input[Ch::N];
+ dims_output[Ch::C] = dims_input[Ch::C];
+ dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
+ dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
@@ -74,7 +77,7 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
- GridSampleImpl<CudaT>(
+ GridSampleImpl<CudaT, IsNHWC>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
@@ -89,4 +92,8 @@ Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
}
} // namespace cuda
} // namespace contrib
+
+namespace cuda {
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
+} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h
index 08ca58c7cc458..16581bfe77482 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.h
@@ -12,7 +12,7 @@ namespace cuda {
using namespace onnxruntime::cuda;
-template <typename T>
+template <typename T, bool IsNHWC>
class GridSample final : public CudaKernel {
public:
explicit GridSample(const OpKernelInfo& info);
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
index 8a391eca7e86a..b23da635bc83d 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
@@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) {
return static_cast(fx);
}
-template <typename T>
+template <typename T, bool Layout>
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
- int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
+ int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;
+
+ auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t {
+ return Layout == LAYOUT_NCHW
+ ? (bIdx * C * H * W + cIdx * H * W + y * W + x)
+ : (bIdx * H * W * C + y * W * C + x * C + cIdx);
+ };
+
if (padding_mode == 0) { // zeros
if (x >= 0 && x < W && y >= 0 && y < H) {
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];
}
- } else if (padding_mode == 1) { //border
+ } else if (padding_mode == 1) { // border
x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];
} else { // Reflection
- x = (int64_t) GsReflect(x, border[0], border[2]);
- y = (int64_t) GsReflect(y, border[1], border[3]);
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ x = (int64_t)GsReflect(x, border[0], border[2]);
+ y = (int64_t)GsReflect(y, border[1], border[3]);
+ pixel = input_data[PixelOffset(x, y)];
}
return pixel;
}
-__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
-{
+__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
float cubic_alpha = -0.75f;
x = abs(x);
coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
@@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) {
return pixel;
}
-template <typename T>
+template <typename T, bool Layout>
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
@@ -110,16 +116,32 @@ __global__ void _GridSampleKernel(
{
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
- int BIdx = idx / (C * H_out * W_out );
- int tmpBCnt = BIdx * (C * H_out * W_out);
+ int BIdx, yIdx, xIdx, cIdx;
+ if constexpr (Layout == LAYOUT_NCHW) {
+ BIdx = idx / (C * H_out * W_out);
+ int tmpBCnt = BIdx * (C * H_out * W_out);
+
+ cIdx = (idx - tmpBCnt) / (H_out * W_out);
+ int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
- int cIdx = (idx - tmpBCnt) / (H_out * W_out);
- int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
+ yIdx = (idx - tmpCCnt) / W_out;
+ int tmpHCnt = tmpCCnt + yIdx * W_out;
- int yIdx = (idx - tmpCCnt) / W_out;
- int tmpHCnt = tmpCCnt + yIdx * W_out;
+ xIdx = (idx - tmpHCnt);
+ } else {
+ static_assert(Layout == LAYOUT_NHWC, "Unsupported layout");
- int xIdx = (idx - tmpHCnt);
+ BIdx = idx / (H_out * W_out * C);
+ int tmpBCnt = BIdx * (H_out * W_out * C);
+
+ yIdx = (idx - tmpBCnt) / (W_out * C);
+ int tmpHCnt = tmpBCnt + yIdx * (W_out * C);
+
+ xIdx = (idx - tmpHCnt) / C;
+ int tmpWCnt = tmpHCnt + xIdx * C;
+
+ cIdx = (idx - tmpWCnt);
+ }
int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
T grid_X = grid_data[grid_idx * 2 + 0];
@@ -147,8 +169,9 @@ __global__ void _GridSampleKernel(
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
- grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
- grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
+ // Clamping must not be done here, see #10607
+ // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
+ // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
@@ -175,10 +198,10 @@ __global__ void _GridSampleKernel(
w_lb = w_b * w_l;
w_rb = w_b * w_r;
- T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
- T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
- T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
- T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
+ T lt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
+ T rt_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
+ T lb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
+ T rb_v = PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
@@ -186,7 +209,8 @@ __global__ void _GridSampleKernel(
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
- output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
+ output_data[outIdx] =
+ PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
@@ -195,7 +219,8 @@ __global__ void _GridSampleKernel(
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
- p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
+ p[h][w] =
+ PixelAtGrid<T, Layout>(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
@@ -204,7 +229,7 @@ __global__ void _GridSampleKernel(
}
}
-template <typename T>
+template <typename T, bool IsNHWC>
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
@@ -216,17 +241,23 @@ void GridSampleImpl(
const int64_t H_out,
const int64_t W_out,
T* output_data) {
- int blocksPerGrid = (int)(ceil(static_cast<float>(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
- _GridSampleKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
- input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
+ using Ch = Channels;
+
+ int blocksPerGrid = static_cast<int>(
+ ceil(static_cast<float>(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
+ _GridSampleKernel<T, IsNHWC><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
+ input_data, grid_data, mode, padding_mode, align_corners,
+ dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W],
+ H_out, W_out, output_data);
}
-#define SPECIALIZED_IMPL(T) \
- template void GridSampleImpl<T>(cudaStream_t stream, const T* input_data, const T* grid_data, \
- const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
- const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
+#define SPECIALIZED_IMPL(T, IsNHWC) \
+ template void GridSampleImpl<T, IsNHWC>(cudaStream_t stream, const T* input_data, const T* grid_data, \
+ const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
+ const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
-SPECIALIZED_IMPL(float)
+SPECIALIZED_IMPL(float, false) // NCHW
+SPECIALIZED_IMPL(float, true) // NHWC
} // namespace cuda
} // namespace contrib
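The kernel's layout handling boils down to two linear-offset formulas, which the kernel computes inline via the PixelOffset lambda and the index decomposition above; written out as plain C++ for clarity:

```cpp
#include <cstdint>

// Element (n, c, y, x) of a 4-D tensor under the two supported layouts.
inline int64_t OffsetNCHW(int64_t n, int64_t c, int64_t y, int64_t x,
                          int64_t C, int64_t H, int64_t W) {
  return n * C * H * W + c * H * W + y * W + x;
}

inline int64_t OffsetNHWC(int64_t n, int64_t c, int64_t y, int64_t x,
                          int64_t C, int64_t H, int64_t W) {
  return n * H * W * C + y * W * C + x * C + c;
}
```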
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
index 6df86ce161908..62cd66a48fa84 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
@@ -8,7 +8,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-template <typename T>
+template <typename T, bool IsNHWC>
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
deleted file mode 100644
index 9cb414e4e8980..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "contrib_ops/rocm/bert/fast_gelu.h"
-
-#include "core/providers/rocm/rocm_common.h"
-#include "core/providers/rocm/miopen_common.h"
-#include "contrib_ops/cpu/bert/bias_gelu_helper.h"
-#include "contrib_ops/rocm/bert/elementwise.h"
-#include "contrib_ops/rocm/bert/transformer_common.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- FastGelu, \
- kMSDomain, \
- 1, \
- T, \
- kRocmExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
- FastGelu);
-
-REGISTER_KERNEL_TYPED(float)
-REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
-
-using namespace ONNX_NAMESPACE;
-
-template <typename T>
-Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
- ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
-
- const Tensor* input = context->Input<Tensor>(0);
- const Tensor* bias = context->Input<Tensor>(1);
- Tensor* output = context->Output(0, input->Shape());
-
- int64_t input_length = input->Shape().Size();
- if (input_length == 0) {
- return Status::OK();
- }
- int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
- typedef typename ToHipType<T>::MappedType HipT;
-
- const HipT* input_buffer = reinterpret_cast<const HipT*>(input->Data<T>());
- const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr;
- return LaunchElementwiseKernel<functor::FastGeLU, HipT>(
- GetTuningContext(), context->GetComputeStream(),
- input_buffer, static_cast<int>(input_length),
- bias_buffer, static_cast<int>(bias_length),
- reinterpret_cast<HipT*>(output->MutableData<T>()));
-}
-
-} // namespace rocm
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
deleted file mode 100644
index 42bfe5a0b0246..0000000000000
--- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/common/common.h"
-#include "core/providers/rocm/rocm_kernel.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace rocm {
-
-using namespace onnxruntime::rocm;
-
-template <typename T>
-class FastGelu final : public RocmKernel {
- public:
- FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {}
- Status ComputeInternal(OpKernelContext* ctx) const override;
-};
-
-} // namespace rocm
-} // namespace contrib
-} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
index b3d3e92209b39..c6ca16bfdfc80 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
auto block_size = metadata->constants.at("BLOCK_SIZE");
auto hw_size = metadata->constants.at("HW_SIZE");
auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
- "Input skip or bias is not supported by triton kernel.");
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
"Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
@@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
}
// Construct args for launch kernel
struct {
- void* X;
- void* Y;
+ const void* src;
+ const void* skip;
+ const void* bias;
+ void* out;
+ void* add_out;
const void* gamma;
const void* beta;
int hw;
int c;
int c_per_group;
float eps;
+ bool has_skip;
+ bool has_bias;
+ bool broadcast_skip;
} args = {
- (void*)params->src,
+ (const void*)params->src,
+ (const void*)params->skip,
+ (const void*)params->bias,
(void*)params->dst,
+ (void*)params->skip_workspace,
(const void*)params->gamma,
(const void*)params->beta,
params->hw,
params->c,
params->channels_per_group,
- params->epsilon};
+ params->epsilon,
+ params->skip != nullptr,
+ params->bias != nullptr,
+ params->broadcast_skip,
+ };
// Grid dim is (batch_count, groups, 1)
return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
index 5368cb1cf635b..5ba96ebc117f0 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py
@@ -12,13 +12,19 @@
@triton.jit
def group_norm_kernel(
input_ptr,
+ skip_ptr,
+ bias_ptr,
output_ptr,
+ add_out_ptr,
gamma_ptr,
beta_ptr,
img_size,
c,
c_per_group,
eps,
+ has_skip,
+ has_bias,
+ broadcast_skip,
BLOCK_SIZE: tl.constexpr,
HW_SIZE: tl.constexpr,
ACTIVATION_SILU: tl.constexpr,
@@ -36,14 +42,35 @@ def group_norm_kernel(
offsets = hw[:, None] * c + cols[None, :]
mask = (cols < c_per_group)[None, :]
+ bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ if has_skip:
+ add_out_ptr += row_x * stride + row_y * c_per_group
+ if broadcast_skip:
+ broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group
+ bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+ else:
+ skip_ptr += row_x * stride + row_y * c_per_group
+ if has_bias:
+ bias_ptr += row_y * c_per_group
+ bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32)
+
# Calculate mean and variance
_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
_square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
x_ptr = input_ptr + i * HW_SIZE * c
a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ if has_skip and not broadcast_skip:
+ s_ptr = skip_ptr + i * HW_SIZE * c
+ s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ a += s
+ if has_bias or broadcast_skip:
+ a += bias
_sum += a
_square_sum += a * a
+ if has_skip:
+ add_y_ptr = add_out_ptr + i * HW_SIZE * c
+ tl.store(add_y_ptr + offsets, a, mask=mask)
# Set axis=None (or leave it unspecified) to reduce all axes.
# TODO: In older Triton we have to reduce an axis at a time, but in our case
@@ -57,9 +84,13 @@ def group_norm_kernel(
gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32)
beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32)
for i in range(tl.cdiv(img_size, HW_SIZE)):
- x_ptr = input_ptr + i * HW_SIZE * c
y_ptr = output_ptr + i * HW_SIZE * c
- x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ if has_skip:
+ add_y_ptr = add_out_ptr + i * HW_SIZE * c
+ x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ else:
+ x_ptr = input_ptr + i * HW_SIZE * c
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
x_hat = (x - group_mean) * rstd
y = x_hat * gamma + beta
if ACTIVATION_SILU:
@@ -77,7 +108,7 @@ def group_norm_kernel(
hw_sizes = [8, 16, 32, 64, 128, 256]
warps = [1, 2, 4, 8, 16]
name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}"
-sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32"
+sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1"
group_pattern = "GroupNormTriton_{}_{}"
@@ -88,7 +119,7 @@ def get_function_table():
silu_suffix = "Silu" if silu else "Pass"
name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp)
group = group_pattern.format(silu_suffix, dtype)
- sig = sig_pattern.format(dtype, dtype)
+ sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype)
kwargs = {
"num_warps": warp,
"constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)},
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index ea7a6432a7507..158ab8ed610f4 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -182,7 +182,6 @@ class PlannerImpl {
// upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node
// upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream
InlinedHashMap<NodeIndex, InlinedHashSet<NodeIndex>> dependence_graph_;
- InlinedHashMap<OrtValueIndex, InlinedHashSet<NodeIndex>> value_consumer_map_;
InlinedHashMap<OrtValueIndex, NodeIndex> value_node_map_;
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
@@ -295,7 +294,7 @@ class PlannerImpl {
}
#endif
- // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node.
+ // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node.
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input,
bool* is_strided_tensor) {
*is_strided_tensor = false;
@@ -530,6 +529,7 @@ class PlannerImpl {
// Initialize allocation plan:
plan_.allocation_plan.resize(num_ml_values);
+ for (int i = 0; static_cast<size_t>(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i;
}
bool HasExternalOutputs(const Node& node) const {
@@ -1065,7 +1065,8 @@ class PlannerImpl {
// build the consumer list for each value
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
- value_consumer_map_.reserve(num_ml_values);
+ InlinedHashMap<OrtValueIndex, InlinedHashSet<NodeIndex>> value_consumer_map;
+ value_consumer_map.reserve(num_ml_values);
// iterate each stream from back, so the first element is the last consumer in single stream case
for (auto& stream : stream_nodes_) {
@@ -1078,10 +1079,10 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
- auto origin = Buffer(value_idx);
- if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+ auto origin = AllocPlan(value_idx).reused_buffer;
+ if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
- value_consumer_map_[origin].insert(node_index);
+ value_consumer_map[origin].insert(node_index);
}
}
return Status::OK();
@@ -1138,8 +1139,8 @@ class PlannerImpl {
std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl;
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
- value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
found_reusable = true;
break;
@@ -1168,8 +1169,8 @@ class PlannerImpl {
allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = reusable_input;
- value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(reusable_input);
continue;
} // if
@@ -1187,11 +1188,11 @@ class PlannerImpl {
OrtValueIndex input_arg_index{};
if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() &&
allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) {
- if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
+ if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) {
allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse;
allocation_plan[output_idx_global].reused_buffer = input_arg_index;
- value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(),
- value_consumer_map_[output_idx_global].end());
+ value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(),
+ value_consumer_map[output_idx_global].end());
reused.insert(input_arg_index);
}
}
@@ -1266,7 +1267,7 @@ class PlannerImpl {
}
bool all_covered = true;
- for (auto consumer : value_consumer_map_[output_idx_global]) {
+ for (auto consumer : value_consumer_map[output_idx_global]) {
if (deps->find(consumer) == deps->end()) {
all_covered = false;
break;
@@ -1277,9 +1278,9 @@ class PlannerImpl {
allocation_plan[downstream_value].reused_buffer = output_idx_global;
get_reused = true;
// add new consumer for the value to be reused
- value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]);
- value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(),
- value_consumer_map_[downstream_value].end());
+ value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]);
+ value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(),
+ value_consumer_map[downstream_value].end());
node_iter = size_iter->second.erase(node_iter);
if (size_iter->second.empty()) {
local_iter->second.erase(size_iter);
@@ -1342,8 +1343,9 @@ class PlannerImpl {
ort_value_usecount.reserve(ort_value_info_.size());
#endif
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
- // compute use count first
+ // compute use count first. TODO(leca): calling ComputeReuseCount() once is enough!
ORT_RETURN_IF_ERROR(ComputeReuseCount());
+ for (int j = 0; static_cast<size_t>(j) < ort_value_info_.size(); j++) Buffer(j) = j;
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
if (i == 0) {
for (auto ort_value_info : ort_value_info_) {
@@ -1693,8 +1695,8 @@ class PlannerImpl {
const auto& name = input.Name();
int value_idx;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx));
- auto origin = Buffer(value_idx);
- if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) {
+ auto origin = AllocPlan(value_idx).reused_buffer;
+ if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) {
// add current node as consumer for origin buffer
value_consumers[origin].push_back(node_index);
}
@@ -1889,7 +1891,7 @@ class PlannerImpl {
// 2. the consumer is in the same stream (non-cpu device), but it consumes a CPU tensor from a non-shape op.
// for example, a resize cuda kernel consumes a tensor from a MemCpyToHost cuda kernel on the same stream.
// in this case, the FIFO can't guarantee the cpu tensor is ready when the resize kernel is launched
- OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type();
+ OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type();
WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device);
if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) {
if (node_to_notification.find(node_index) == node_to_notification.end()) {
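To make the allocation-planner change easier to follow: value_consumer_map is now a local rebuilt during each planning pass rather than a member, and every allocation plan entry starts out with reused_buffer pointing at itself, so AllocPlan(value_idx).reused_buffer always resolves to a valid origin buffer. A toy Python model of that bookkeeping (hypothetical names, not the planner's real types):

    from collections import defaultdict

    # Each value initially "reuses" its own buffer (reused_buffer = i).
    reused_buffer = {i: i for i in range(4)}
    # Local consumer map, rebuilt for each planning pass.
    value_consumer_map = defaultdict(set)

    def record_consumer(value_idx, node_index):
        origin = reused_buffer[value_idx]   # resolve to the origin buffer
        value_consumer_map[origin].add(node_index)

    reused_buffer[3] = 1                    # value 3 reuses value 1's buffer
    record_consumer(3, node_index=7)
    assert value_consumer_map[1] == {7}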
diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc
index 875e7f395bfa8..dd7f4d35b34bd 100644
--- a/onnxruntime/core/framework/stream_execution_context.cc
+++ b/onnxruntime/core/framework/stream_execution_context.cc
@@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
}
#ifdef USE_CANN
+ // Leave it to CANN EP to fill the gap if they want to use run_options
+ static onnxruntime::RunOptions run_options;
// For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool,
// which is different from CUDA Runtime API, but similar to CUDA Driver API.
auto& execution_providers = ctx.GetSessionState().GetExecutionProviders();
for (auto& xp : execution_providers) {
- auto status = xp->OnRunStart();
+ auto status = xp->OnRunStart(run_options);
if (!status.IsOK()) {
ctx.SetStatus(status);
return;
diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
index 4505d4afdf1e0..a8717b99a8750 100644
--- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
+++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
@@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
}
#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
+// TODO(mtavenrath) generate list from registered kernels using nhwc domain
const std::unordered_set<std::string_view>& GetCUDALayoutSensitiveOps() {
static std::unordered_set<std::string_view> cuda_nhwc_ops = []() {
return std::unordered_set<std::string_view>{
@@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() {
"MaxPool",
"GlobalAveragePool",
"AveragePool",
+ "GridSample",
};
}();
return cuda_nhwc_ops;
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index 752b742805a7c..9a242919665bb 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() {
}
// All threads share the same context and stream
-Status CANNExecutionProvider::OnRunStart() {
+Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id));
return Status::OK();
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h
index 63ae980869c65..d83bd88d6958f 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.h
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.h
@@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider {
explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info);
virtual ~CANNExecutionProvider();
- Status OnRunStart() override;
+ Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
template <typename T>
Status Fill(Tensor* y, void* addr, aclrtStream stream) const {
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 813fdc54ecd0d..48e4617b33b4d 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu);
#if !defined(DISABLE_FLOAT8_TYPES)
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
@@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu)>,
#if !defined(DISABLE_FLOAT8_TYPES)
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc
new file mode 100644
index 0000000000000..d55973eda180f
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/common.h"
+#include "core/common/narrow.h"
+#include "core/framework/op_kernel.h"
+#include "core/util/math_cpuonly.h"
+#include "core/mlas/inc/mlas.h"
+
+#include "core/platform/threadpool.h"
+#include
+#include "core/providers/cpu/element_wise_ranged_transform.h"
+#include "core/providers/cpu/tensor/gelu.h"
+
+using onnxruntime::narrow;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+// May revisit the implementations to support inplace computation, if needed.
+
+ONNX_CPU_OPERATOR_KERNEL(
+ Gelu,
+ 20,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ Gelu<float>);
+
+#ifndef DISABLE_CONTRIB_OPS
+namespace contrib {
+ONNX_OPERATOR_KERNEL_EX(
+ Gelu,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+ Gelu<float>);
+}
+#endif
+
+template <typename T>
+Status Gelu<T>::Compute(OpKernelContext* context) const {
+ const Tensor* input = context->Input<Tensor>(0);
+ const T* input_data = input->Data<T>();
+
+ Tensor* output = context->Output(0, input->Shape());
+ T* output_data = output->MutableData<T>();
+
+ concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+ int64_t elem_count = input->Shape().Size();
+ constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
+ int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+
+ if (approximation_algorithm_ == "tanh") {
+ // FastGelu allows an optional bias. Here we split the input data into chunks. Each chunk
+ // has N elements (except the last chunk), and a thread pool processes the chunks in parallel.
+ // N = 4096 is selected based on performance test results on input shape 1x128x768.
+ // FastGelu uses an approximation of Gelu. The formula is 0.5 * (1 + Tanh(x * (C * x * x + B))) * x.
+ static constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI)
+ static constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0 / M_PI)
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ tp, static_cast<int32_t>(task_count),
+ [&](ptrdiff_t task_idx) {
+ const auto start = task_idx * length_per_task;
+ const T* p_input = input_data + start;
+ T* p_output = output_data + start;
+ int64_t count = std::min(length_per_task, elem_count - start);
+
+ for (int64_t i = 0; i < count; i++) {
+ T value = p_input[i];
+ p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
+ }
+
+ MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
+
+ for (int64_t i = 0; i < count; i++) {
+ p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+ }
+ },
+ 0);
+ return Status::OK();
+ } else if (approximation_algorithm_ == "none") {
+ concurrency::ThreadPool::TryBatchParallelFor(
+ tp, static_cast<int32_t>(task_count),
+ [&](ptrdiff_t task_idx) {
+ const auto start = task_idx * length_per_task;
+ const T* p_input = input_data + start;
+ T* p_output = output_data + start;
+ int64_t count = std::min(length_per_task, elem_count - start);
+
+ for (int64_t i = 0; i < count; i++) {
+ T value = p_input[i];
+ p_output[i] = value * static_cast<T>(M_SQRT1_2);
+ }
+
+ MlasComputeErf(p_output, p_output, narrow<size_t>(count));
+
+ for (int64_t i = 0; i < count; i++) {
+ p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
+ }
+ },
+ 0);
+ return Status::OK();
+ }
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_);
+}
+
+} // namespace onnxruntime
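The two branches of Gelu::Compute implement the exact erf-based Gelu and the tanh ("FastGelu") approximation whose constants appear above. A small scalar Python sketch of both formulas, handy for sanity-checking the kernel's output (illustration only; the constants mirror the C++ code):

    import math

    B = 0.7978845608028654    # sqrt(2 / pi)
    C = 0.035677408136300125  # 0.044715 * sqrt(2 / pi)

    def gelu_exact(x):
        # approximation_algorithm == "none": 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + math.erf(x * math.sqrt(0.5)))

    def gelu_tanh(x):
        # approximation_algorithm == "tanh": 0.5 * x * (1 + tanh(x * (C * x^2 + B)))
        return 0.5 * x * (1.0 + math.tanh(x * (C * x * x + B)))

    for v in (-1.0, 0.0, 0.5, 3.0):
        assert abs(gelu_exact(v) - gelu_tanh(v)) < 1e-2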
diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h
new file mode 100644
index 0000000000000..13238028d878a
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/tensor/gelu.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace onnxruntime {
+
+template <typename T>
+class Gelu final : public OpKernel {
+ public:
+ explicit Gelu(const OpKernelInfo& info) : OpKernel(info) {
+ approximation_algorithm_ = info.GetAttrOrDefault<std::string>("approximate", "none");
+ }
+ Status Compute(OpKernelContext* ctx) const override;
+
+ private:
+ std::string approximation_algorithm_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index e006a0deeb4e8..f42a71de24337 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -400,7 +400,7 @@ Status CUDAExecutionProvider::Sync() const {
return Status::OK();
}
-Status CUDAExecutionProvider::OnRunStart() {
+Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptureSkippedOnRun() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
@@ -410,7 +410,7 @@ Status CUDAExecutionProvider::OnRunStart() {
return Status::OK();
}
-Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
+Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptureSkippedOnRun() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
@@ -1274,6 +1274,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1347,6 +1348,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
#endif
+// Opset 20
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
+
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
return {};
@@ -2161,6 +2167,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
// Opset 17
BuildKernelCreateInfo,
@@ -2240,6 +2247,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+
+ // Opset 20
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
#endif
};
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
index 41e337e4dd691..deb1d9d4230c8 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h
@@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider {
Status Sync() const override;
- Status OnRunStart() override;
+ Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
- Status OnRunEnd(bool sync_stream) override;
+ Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override;
DataLayout GetPreferredLayout() const override;
@@ -116,6 +116,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy,
CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg);
~PerThreadContext();
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);
cublasHandle_t CublasHandle() const {
return cublas_handle_;
diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h
index fdd14dedad47e..2cbeb13696270 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.h
+++ b/onnxruntime/core/providers/cuda/cudnn_common.h
@@ -24,12 +24,12 @@ class CudnnTensor final {
operator cudnnTensorDescriptor_t() const { return tensor_; }
+ Status CreateTensorIfNeeded();
+
template <typename ElemType>
static cudnnDataType_t GetDataType();
private:
- Status CreateTensorIfNeeded();
-
cudnnTensorDescriptor_t tensor_;
};
diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
index 99c1f48e21c74..b61b104790fe5 100644
--- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
+++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -9,40 +9,49 @@ namespace onnxruntime {
namespace cuda {
template <typename T>
-void CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
- const cudnnRNNDescriptor_t rnn_desc,
- const int pseudo_layer,
- const cudnnTensorDescriptor_t x_desc,
- const cudnnFilterDescriptor_t w_desc,
- const cudnnFilterDescriptor_t filter_desc,
- const void* reorganized_w_data,
- const int lin_layer_id,
- const T* pos,
- int& offset,
- bool is_matrix,
- cudaStream_t cuda_stream) const {
+Status CudnnRnnBase<T>::SetWeightBias(const cudnnHandle_t handle,
+ const cudnnRNNDescriptor_t rnn_desc,
+ const int pseudo_layer,
+ size_t reorganized_w_data_size,
+ const void* reorganized_w_data,
+ const int lin_layer_id,
+ const T* pos,
+ int& offset,
+ bool is_matrix,
+ cudaStream_t cuda_stream) const {
int numDims;
- std::vector<int> matDims(3);
+ std::array<int, 3> matDims;
+ std::array<int, 3> strideA;
cudnnDataType_t dt;
- cudnnTensorFormat_t tf;
T* mem_offset;
- if (is_matrix) {
- cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
- } else {
- cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset);
- }
+ CudnnTensor tensor_desc_matrix, tensor_desc_bias;
+ ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded());
+ ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded());
- cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data());
+ T *mem_offset_matrix, *mem_offset_bias;
+ CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams(
+ handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data,
+ lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias));
+ CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor(
+ is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data()));
+
+ mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias;
int count = matDims[0] * matDims[1] * matDims[2];
+
+ if (strideA[0] != count) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed");
+ }
CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream));
+
offset += count;
+
+ return Status::OK();
}
template <typename T>
Status CudnnRnnBase<T>::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
const cudnnRNNDescriptor_t rnn_desc,
- const cudnnTensorDescriptor_t x_desc,
- const cudnnFilterDescriptor_t w_desc,
+ size_t reorganized_w_data_size,
void* reorganized_w_data,
const T* W_data,
const T* R_data,
@@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
int w_offset = 0;
int r_offset = 0;
int bias_offset = 0;
- CudnnFilterDescriptor filter_desc;
for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) {
for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(
+ cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream));
if (B_data != nullptr) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
}
}
for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream));
if (B_data != nullptr) {
- SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream);
+ ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data,
+ R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream));
}
}
}
@@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
template <typename T>
Status CudnnRnnBase<T>::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
+ size_t& reorganized_w_data_size_in_bytes,
IAllocatorUniquePtr<void>& reorganized_w_data,
CudnnFilterDescriptor& target_w_desc,
CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const {
@@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
TensorShapeVector dims_w({w_size, 1, 1});
ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType<CudaT>()));
- TensorShapeVector fake_dims_x({1, input_size, 1});
- CudnnTensor fake_x_desc;
- ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType<CudaT>()));
-
// Prepare the weight data
- reorganized_w_data = GetScratchBuffer<void>(w_size * sizeof(T), ort_stream);
+ reorganized_w_data_size_in_bytes = w_size * sizeof(T);
+ reorganized_w_data = GetScratchBuffer<void>(reorganized_w_data_size_in_bytes, ort_stream);
// In many cases, this allocation is bigger than needed, leaving part of
- // the buffer unintialized. non-zero garbage data leads to wrong result
+ // the buffer uninitialized. non-zero garbage data leads to wrong result
// in call to cudnnRNNForwardInference()
// TODO! refine allocation size for each case.
cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
- cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream);
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream));
const T* W_data = W->Data<T>();
const T* R_data = R->Data<T>();
@@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons
auto* ort_cuda_stream = dynamic_cast<CudaStream*>(ort_stream);
cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle();
- ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc,
- reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream));
+ ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc,
+ reorganized_w_data_size_in_bytes, reorganized_w_data.get(),
+ W_data, R_data, B_data, cuda_stream));
return Status::OK();
}
@@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) {
bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R);
bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B);
+ bool has_bias = B != nullptr;
+
if (get_W && get_R) {
CudnnRNN tmp_rnn_desc;
- ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(),
+ auto proj_size = hidden_size_;
+ ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size
hidden_size_,
+ proj_size,
RNN_NUM_LAYERS,
cudnn_dropout_desc_,
cudnn_direction_mode_,
rnn_mode_,
- CudnnTensor::GetDataType<CudaT>(),
- GetDeviceProp()));
+ has_bias,
+ CudnnTensor::GetDataType<CudaT>()));
if (get_B) {
- ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B,
+ w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+ tmp_rnn_desc, nullptr));
} else {
- ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr,
+ w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_,
+ tmp_rnn_desc, nullptr));
}
cudaStreamSynchronize(nullptr);
+
weight_cached_ = true;
}
@@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
ORT_ENFORCE(nullptr != X);
// optional inputs
- const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens); // [batch_size]
- const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_]
+ // [batch_size]
+ const Tensor* sequence_lens = ctx->Input<Tensor>(RNN_Input_Index::sequence_lens);
+ // initial hidden. [num_directions_, batch_size, hidden_size_]
+ const Tensor* initial_h = ctx->Input<Tensor>(RNN_Input_Index::initial_h);
const Tensor* initial_c(nullptr);
if (rnn_mode_ == CUDNN_LSTM) {
- initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_]
+ // initial cell. [num_directions_, batch_size, hidden_size_]
+ initial_c = ctx->Input<Tensor>(RNN_Input_Index::initial_c);
}
+ size_t proj_size = hidden_size_;
int64_t seq_length = X->Shape()[0];
int64_t batch_size = X->Shape()[1];
int64_t input_size = X->Shape()[2];
+ // if sequence_lens is not provided, treat every sequence as full length and expand to [batch_size]
+ std::vector<int32_t> sequence_lengths_temp;
+ if (!sequence_lens) {
+ sequence_lengths_temp.resize(batch_size, gsl::narrow_cast<int32_t>(seq_length));
+ }
+
+ const int32_t* sequence_lens_data = (sequence_lens == nullptr)
+ ? sequence_lengths_temp.data()
+ : sequence_lens->Data<int32_t>();
+
+ // cuDNN doesn't support zero-length sequences inside the batch; find them and set their length to 1.
+ // there's a ZeroMask kernel to reset the result to 0 for those sequences
+ int64_t zero_seq_count = 0;
+ std::vector<int32_t> zero_seq_index_cache(batch_size, 0);
+
+ CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, batch_size);
+ int32_t* seq_len_array = sequence_lens_buffer.CpuPtr();
+
+ // 0-len sequences are not supported by cuDNN.
+ // Replace them by sequences of len 1 and mask them out with SetZeroSequences
+ for (int i = 0; i < batch_size; ++i) {
+ if (0 == sequence_lens_data[i]) {
+ seq_len_array[i] = 1;
+ zero_seq_index_cache[zero_seq_count] = i;
+ ++zero_seq_count;
+ } else {
+ seq_len_array[i] = sequence_lens_data[i];
+ }
+ }
+
+ // Calculate the zero position cache for the reverse direction if it's bidirectional.
+ // The cache is for Y_h or Y_c, and the 1st sequence for Y; no need to do it for the other sequences in Y since
+ // we hacked the 0-length sequences to 1
+ if (zero_seq_count && num_directions_ > 1) {
+ zero_seq_index_cache.resize(zero_seq_count * num_directions_);
+ for (int64_t i = 0; i < zero_seq_count; ++i) {
+ zero_seq_index_cache[static_cast(zero_seq_count) + i] =
+ static_cast(batch_size + zero_seq_index_cache[i]);
+ }
+ zero_seq_count *= num_directions_;
+ }
+
+ // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus it must
+ // be copied to the GPU always.
+ ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+ // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must
+ // be copied to the GPU only for the ReverseBySequence kernels.
+ // if (reverse_) {
+ // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream()));
+ // }
+
// optional outputs
TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_});
TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_});
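The zero-length-sequence handling in the hunk above is self-contained enough to model outside cuDNN: zero entries are bumped to length 1 so cuDNN accepts the batch, their indices are cached, and the cache is mirrored for the reverse direction so the corresponding outputs can be zeroed afterwards. A plain-Python sketch of that preprocessing (hypothetical helper, not code from the patch):

    def prepare_seq_lens(sequence_lens_data, num_directions=1):
        batch_size = len(sequence_lens_data)
        seq_len_array = []
        zero_seq_index_cache = []
        for i, length in enumerate(sequence_lens_data):
            if length == 0:
                seq_len_array.append(1)          # cuDNN rejects 0-length sequences
                zero_seq_index_cache.append(i)   # remember where to zero Y/Y_h/Y_c later
            else:
                seq_len_array.append(length)
        if zero_seq_index_cache and num_directions > 1:
            # mirror the cache for the reverse direction (offset by batch_size)
            zero_seq_index_cache += [batch_size + i for i in zero_seq_index_cache]
        return seq_len_array, zero_seq_index_cache

    print(prepare_seq_lens([3, 0, 2], num_directions=2))   # -> ([3, 1, 2], [1, 4])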
@@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy);
Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc);
- std::vector<int64_t> dims_x({batch_size, input_size, 1});
- std::vector<int64_t> dims_y({batch_size, hidden_size_ * num_directions_, 1});
-
- CudnnTensor x_desc_temp;
- ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType<CudaT>()));
- CudnnTensor y_desc_temp;
- ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType<CudaT>()));
- std::vector<cudnnTensorDescriptor_t> x_desc(seq_length, x_desc_temp);
- std::vector<cudnnTensorDescriptor_t> y_desc(seq_length, y_desc_temp);
-
- CudnnTensor hx_desc;
- CudnnTensor cx_desc;
- CudnnTensor y_h_desc;
- CudnnTensor y_c_desc;
- ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
- ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
- ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
- ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType<CudaT>()));
-
IAllocatorUniquePtr<T> x_reversed_data;
const T* x_data = X->Data<T>();
if (reverse_) {
@@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
x_reversed_data = GetScratchBuffer<T>(seq_length * batch_size * input_size, ctx->GetComputeStream());
ReverseBySequence(Stream(ctx),
gsl::narrow_cast(seq_length),
+ sequence_lens_buffer.GpuPtr(),
gsl::narrow_cast(batch_size),
gsl::narrow_cast(input_size),
reinterpret_cast(x_data),
@@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const {
y_data = y_alloc_data.get();
}
- const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data<int32_t>();
+ const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
+ bool has_bias = B != nullptr;
CudnnRNN rnn_desc;
- ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx),
+ ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size,
hidden_size_,
+ proj_size,
RNN_NUM_LAYERS,
cudnn_dropout_desc_,
cudnn_direction_mode_,
rnn_mode_,
- CudnnTensor::GetDataType<CudaT>(),
- GetDeviceProp()));
+ has_bias,
+ CudnnTensor::GetDataType<CudaT>()));
// Prepare the weight data
+ size_t w_data_size_in_bytes = 0;
IAllocatorUniquePtr<void> w_data;
CudnnFilterDescriptor w_desc;
if (!weight_cached_) {
const Tensor& W = *ctx->Input<Tensor>(RNN_Input_Index::W);
const Tensor& R = *ctx->Input<Tensor>(RNN_Input_Index::R);
const Tensor* B = ctx->Input<Tensor>(RNN_Input_Index::B);
- ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream()));
+ ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc,
+ rnn_desc, ctx->GetComputeStream()));
}
- // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
- CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED));
+ CudnnDataTensor x_desc1;
+ ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size,
+ input_size, seq_len_array));
+ CudnnDataTensor y_desc1;
+ ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType<CudaT>(), seq_length, batch_size,
+ ((rnn_mode_ == CUDNN_LSTM) ? proj_size : hidden_size_) * num_directions_,
+ seq_len_array));
- size_t workspace_bytes;
- CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes));
- auto workspace_cuda = GetScratchBuffer