From 1c2dca95d813e3bf7a2b59a70fcedae9c84bed7d Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Wed, 3 Jan 2024 04:38:33 +0000
Subject: [PATCH 01/20] pass rotary embedding to attention op (#18846)
### Description
Pass an optional `rotary_embedding_dim` attribute (32, 64 or 128; defaults to `head_size`) from the com.microsoft Attention operator through the CPU attribute parsing and the CUDA add-bias-transpose rotary kernel.
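A minimal sketch of building an Attention node that uses the new attribute, mirroring the updated parity test below (the tensor names and `num_heads` value are illustrative placeholders):

```python
from onnx import helper

# Sketch only: input/output names and num_heads are placeholders, not part of this change.
attn_node = helper.make_node(
    "Attention",
    inputs=["input", "weight", "bias"],
    outputs=["output"],
    num_heads=12,
    unidirectional=1,
    do_rotary=1,
    rotary_embedding_dim=32,  # new attribute: 32, 64 or 128; defaults to head_size
    domain="com.microsoft",
)
```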
### Motivation and Context
---
docs/ContribOperators.md | 2 ++
.../contrib_ops/cpu/bert/attention_base.cc | 1 +
.../contrib_ops/cpu/bert/attention_base.h | 2 ++
.../contrib_ops/cpu/bert/attention_common.h | 1 +
.../cuda/bert/add_bias_transpose.cu | 19 +++++-----
.../cuda/bert/add_bias_transpose.h | 2 +-
.../cuda/bert/attention_prepare_qkv.cu | 3 +-
.../core/graph/contrib_ops/bert_defs.cc | 4 +++
.../test_parity_neox_attention.py | 36 +++++++++++--------
9 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 131db5d8d9b37..38fceef67de25 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+rotary_embedding_dim : int
+Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 5d224bdc2235f..515a967aa2386 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
+ output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 5ee40c4b98664..a6782daa58f1a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -38,6 +38,7 @@ class AttentionBase {
is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1;
+ rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault("rotary_embedding_dim", 0));
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault("scale", 0.0f);
@@ -72,6 +73,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
bool do_rotary_; // whether or not to use rotary embeddings
+ int rotary_embedding_; // rotary embedding dimension
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index a7f83469a768d..c9ed23895b60c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -56,6 +56,7 @@ struct AttentionParameters {
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
+ int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 626e4c0b87a3c..1ea2540db486f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
- bool do_rotary = false, int past_sequence_length = 0) {
+ bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
assert(num_heads <= max_threads_per_block);
if (do_rotary) {
@@ -650,20 +650,20 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
- if (qk_head_size != 64 && qk_head_size != 128) {
- ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
+ if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
+ ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
}
const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
- size_t smem_size = 2 * qk_head_size * sizeof(T);
+ size_t smem_size = 2 * rotary_embedding * sizeof(T);
const dim3 grid(sequence_length, num_heads, batch_size);
const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
- qkv_add_bias, qk_head_size, qk_head_size,
+ qkv_add_bias, rotary_embedding, qk_head_size,
step, format);
#else
ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
@@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
- half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
+ half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -753,7 +753,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
}
}
@@ -763,7 +763,7 @@ void LaunchAddBiasTranspose(
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output, bool /*enable_half4*/,
const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
- int past_sequence_length) {
+ int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -789,7 +789,8 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
+ past_sequence_length);
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
index d903267c99a01..efc31db43bcdb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
- int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
+ int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);
// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 5c65a30918ece..a513d9e8d2211 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.past_sequence_length);
+ 3, parameters.do_rotary, parameters.rotary_embedding,
+ parameters.past_sequence_length);
}
return Status::OK();
}
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index ea67218b5c927..f8f63650615fd 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
+ .Attr("rotary_embedding_dim",
+ "Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
index 8c8e871a854b0..a98bb623beaea 100644
--- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py
+++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
qkv_weight,
qkv_bias,
num_heads,
+ rotary_embedding,
):
nodes = [
helper.make_node(
@@ -43,6 +44,7 @@ def create_neox_attention_graph(
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
+ rotary_embedding=rotary_embedding,
domain="com.microsoft",
),
]
@@ -174,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
class GPTNeoXAttention(nn.Module):
- def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
super().__init__()
self.do_rotary = True
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
- self.rotary_ndims = int(self.head_size)
+ self.rotary_ndims = rotary_ndims
max_positions = 2048
self.register_buffer(
"bias",
@@ -197,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
# self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))
if past_seq_len > 0:
+ assert self.rotary_ndims == self.head_size
self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
batch_size,
seq_len,
@@ -220,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
.transpose(0, 1),
self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
self.num_attention_heads,
+ self.rotary_ndims,
)
@classmethod
@@ -422,17 +426,21 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
- for hidden_size in [768]:
- attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
-
- hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
- torch.float32
- )
-
- torch_output = attn.torch_forward(hidden_states)
- ort_output = attn.onnx_forward(hidden_states)
- if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ for rotary_ndims in [32, 64]:
+ for hidden_size in [768, 960]:
+ attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)
+
+ hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+ torch.float32
+ )
+
+ torch_output = attn.torch_forward(hidden_states)
+ ort_output = attn.onnx_forward(hidden_states)
+ if ort_output is not None:
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
+ print(
+ f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+ )
def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
@@ -466,7 +474,7 @@ def test_gpt_neox_decoder_masked_self_attention(self):
hidden_states, attention_mask=attention_mask, layer_past=layer_past
)
if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
if __name__ == "__main__":
From c97e3f48216d66dfbc6aa951ddcb7f32e313d314 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Wed, 3 Jan 2024 14:53:31 +0800
Subject: [PATCH 02/20] [Fix] exception in Fuzz Test pipeline (#18984)
### Description
### Motivation and Context
The file path is not correct.
---
.../github/azure-pipelines/win-ci-fuzz-testing.yml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
index 98f1bf7ea1a16..b8f9566274acc 100644
--- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
+++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
@@ -20,7 +20,11 @@ jobs:
workspace:
clean: all
steps:
- - template: win-ci-prebuild-steps.yml
+ - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+ displayName: 'Clean Agent Directories'
+ condition: always()
+
+ - template: templates/jobs/win-ci-prebuild-steps.yml
parameters:
EnvSetupScript: $(EnvSetupScript)
DownloadCUDA: false
@@ -69,7 +73,3 @@ jobs:
script: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)\onnxruntime_security_fuzz.exe /t /f "$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)\testdata\mnist.onnx" 1 m'
workingDirectory: $(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)
failOnStderr: false # Optional
-
- - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
- displayName: 'Clean Agent Directories'
- condition: always()
From 7a454acd6197f4ba1ffca13ec9948915ce82d20e Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Wed, 3 Jan 2024 17:25:15 +0800
Subject: [PATCH 03/20] [ROCm] Update CI/Packaging pipeline to ROCm6.0 (#18985)
Update CI/Packaging pipeline to ROCm 6.0
---
...mi200.huggingface.bert-large-rocm6.0.json} | 28 +++++++++----------
.../linux-migraphx-ci-pipeline.yml | 2 +-
.../orttraining-pai-ci-pipeline.yml | 2 +-
...orttraining-py-packaging-pipeline-rocm.yml | 24 ++++++++--------
.../docker/Dockerfile.manylinux2_28_rocm | 2 +-
.../migraphx-ci-pipeline-env.Dockerfile | 2 +-
.../docker/scripts/setup_rocm_yum_repo.sh | 6 ++--
.../pai/rocm-ci-pipeline-env.Dockerfile | 4 +--
8 files changed, 35 insertions(+), 35 deletions(-)
rename orttraining/tools/ci_test/results/{ci-mi200.huggingface.bert-large-rocm5.7.json => ci-mi200.huggingface.bert-large-rocm6.0.json} (61%)
diff --git a/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
similarity index 61%
rename from orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json
rename to orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
index a4ac02b566848..05fcf08cd3232 100644
--- a/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json
+++ b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
@@ -2,56 +2,56 @@
"steps": [
{
"step": 20,
- "loss": 2.0017
+ "loss": 2.0136
},
{
"step": 40,
- "loss": 1.8337
+ "loss": 1.8466
},
{
"step": 60,
- "loss": 1.7538
+ "loss": 1.7525
},
{
"step": 80,
- "loss": 1.6728
+ "loss": 1.6682
},
{
"step": 100,
- "loss": 1.6656
+ "loss": 1.658
},
{
"step": 120,
- "loss": 1.6752
+ "loss": 1.6749
},
{
"step": 140,
- "loss": 1.6335
+ "loss": 1.6263
},
{
"step": 160,
- "loss": 1.6815
+ "loss": 1.6828
},
{
"step": 180,
- "loss": 1.6155
+ "loss": 1.6145
},
{
"step": 200,
- "loss": 1.6177
+ "loss": 1.6197
},
{
"step": 220,
- "loss": 1.632
+ "loss": 1.6353
},
{
"step": 240,
- "loss": 1.5161
+ "loss": 1.5266
},
{
"step": 260,
- "loss": 1.5433
+ "loss": 1.5441
}
],
- "samples_per_second": 32.335
+ "samples_per_second": 34.561
}
diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
index 5dac8fc9cda63..f7571a3b7eab6 100644
--- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
@@ -36,7 +36,7 @@ variables:
- name: render
value: 109
- name: RocmVersion
- value: 5.7
+ value: 6.0
jobs:
- job: Linux_Build
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
index 8d02a5e5809a2..a53f91fb317cb 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
@@ -25,7 +25,7 @@ variables:
- name: render
value: 109
- name: RocmVersion
- value: 5.7
+ value: 6.0
- name: BuildConfig
value: Release
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
index f2ba99369c144..bbdbe0fd8e376 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
@@ -9,51 +9,51 @@ resources:
ref: 5eda9aded5462201e6310105728d33016e637ea7
stages:
-- stage: "Python_Packaging_ROCm57_Release"
+- stage: "Python_Packaging_ROCm60_Release"
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
-- stage: "Python_Packaging_ROCm57_Debug"
+- stage: "Python_Packaging_ROCm60_Debug"
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
-- stage: "Python_Packaging_ROCm56_Release"
+- stage: "Python_Packaging_ROCm57_Release"
condition: ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true')
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
index 9e12fe8c75451..b9fd88083f218 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
@@ -31,7 +31,7 @@ RUN yum install -y hipify-clang
RUN yum -y install wget
# rocm lib
-RUN yum install -y miopen-hip-devel rocblas-devel rocrand-devel rccl-devel hipsparse-devel hipfft-devel hipcub-devel hipblas-devel rocthrust-devel migraphx-devel
+RUN yum install -y migraphx-devel
ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM}
ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8
diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
index d02e7d8b91d11..85d738d2167e1 100644
--- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
+++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
@@ -1,7 +1,7 @@
# Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete
FROM ubuntu:22.04
-ARG ROCM_VERSION=5.7
+ARG ROCM_VERSION=6.0
ARG AMDGPU_VERSION=${ROCM_VERSION}
ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600'
diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
index fcd9086061227..269337bbba042 100755
--- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
+++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
@@ -2,7 +2,7 @@
set -e -x
# version
-ROCM_VERSION=5.6
+ROCM_VERSION=6.0
while getopts "r:" parameter_Option
do case "${parameter_Option}"
@@ -14,7 +14,7 @@ done
tee /etc/yum.repos.d/amdgpu.repo <
Date: Thu, 4 Jan 2024 02:13:17 +0800
Subject: [PATCH 04/20] [js/webgpu] Introduce trace support (#18928)
This leverages console.timeStamp to add a single marker to the browser's
performance tool (only Chromium and Firefox support it). With this support,
we can dump both CPU and GPU timestamps and use a post-processing tool to
clearly understand the calibrated timeline. A demo tool can be found at
https://github.com/webatintel/ort-test, and more detailed info at
https://docs.google.com/document/d/1TuVxjE8jnELBXdhI4QGFgMnUqQn6Q53QA9y4a_dH688/edit.
---
js/common/lib/env.ts | 7 +++
js/common/lib/index.ts | 1 +
js/common/lib/inference-session-impl.ts | 5 +++
js/common/lib/trace.ts | 44 +++++++++++++++++++
js/web/lib/backend-wasm.ts | 4 ++
js/web/lib/wasm/jsep/backend-webgpu.ts | 4 +-
.../lib/wasm/jsep/webgpu/program-manager.ts | 6 +++
js/web/lib/wasm/session-handler-inference.ts | 6 ++-
8 files changed, 75 insertions(+), 2 deletions(-)
create mode 100644 js/common/lib/trace.ts
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 0cded7e5edbcb..b007b5e164bf3 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -33,6 +33,13 @@ export declare namespace Env {
*/
simd?: boolean;
+ /**
+ * set or get a boolean value indicating whether to enable trace.
+ *
+ * @defaultValue `false`
+ */
+ trace?: boolean;
+
/**
* Set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds. A zero
* value indicates no timeout is set.
diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts
index 9cbfcc4e8bcdc..d7c98380f3fa4 100644
--- a/js/common/lib/index.ts
+++ b/js/common/lib/index.ts
@@ -21,5 +21,6 @@ export * from './backend.js';
export * from './env.js';
export * from './inference-session.js';
export * from './tensor.js';
+export * from './trace.js';
export * from './onnx-value.js';
export * from './training-session.js';
diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts
index 9bc2088f2088a..55f40c8907a89 100644
--- a/js/common/lib/inference-session-impl.ts
+++ b/js/common/lib/inference-session-impl.ts
@@ -6,6 +6,7 @@ import {InferenceSessionHandler} from './backend.js';
import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
+import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js';
type SessionOptions = InferenceSessionInterface.SessionOptions;
type RunOptions = InferenceSessionInterface.RunOptions;
@@ -20,6 +21,7 @@ export class InferenceSession implements InferenceSessionInterface {
run(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
+ TRACE_FUNC_BEGIN();
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@@ -117,6 +119,7 @@ export class InferenceSession implements InferenceSessionInterface {
}
}
}
+ TRACE_FUNC_END();
return returnValue;
}
@@ -132,6 +135,7 @@ export class InferenceSession implements InferenceSessionInterface {
static async create(
arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number,
arg3?: SessionOptions): Promise<InferenceSessionInterface> {
+ TRACE_FUNC_BEGIN();
// either load from a file or buffer
let filePathOrUint8Array: string|Uint8Array;
let options: SessionOptions = {};
@@ -196,6 +200,7 @@ export class InferenceSession implements InferenceSessionInterface {
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
+ TRACE_FUNC_END();
return new InferenceSession(handler);
}
diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts
new file mode 100644
index 0000000000000..404f7ef8089af
--- /dev/null
+++ b/js/common/lib/trace.ts
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {env} from './env-impl.js';
+
+export const TRACE = (deviceType: string, label: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ // eslint-disable-next-line no-console
+ console.timeStamp(`${deviceType}::ORT::${label}`);
+};
+
+const TRACE_FUNC = (msg: string, extraMsg?: string) => {
+ const stack = new Error().stack?.split(/\r\n|\r|\n/g) || [];
+ let hasTraceFunc = false;
+ for (let i = 0; i < stack.length; i++) {
+ if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) {
+ let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`;
+ if (extraMsg) {
+ label += `::${extraMsg}`;
+ }
+ TRACE('CPU', label);
+ return;
+ }
+ if (stack[i].includes('TRACE_FUNC')) {
+ hasTraceFunc = true;
+ }
+ }
+};
+
+export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ TRACE_FUNC('BEGIN', extraMsg);
+};
+
+export const TRACE_FUNC_END = (extraMsg?: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ TRACE_FUNC('END', extraMsg);
+};
diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts
index 2d123cdb71290..d9f63fec9c492 100644
--- a/js/web/lib/backend-wasm.ts
+++ b/js/web/lib/backend-wasm.ts
@@ -26,6 +26,10 @@ export const initializeFlags = (): void => {
env.wasm.proxy = false;
}
+ if (typeof env.wasm.trace !== 'boolean') {
+ env.wasm.trace = false;
+ }
+
if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 6c3d22352772e..0148f32cdd91b 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Env, Tensor} from 'onnxruntime-common';
+import {Env, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
@@ -263,6 +263,7 @@ export class WebGpuBackend {
run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[],
createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] {
+ TRACE_FUNC_BEGIN(program.name);
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
@@ -387,6 +388,7 @@ export class WebGpuBackend {
artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup,
uniformBufferBinding);
+ TRACE_FUNC_END(program.name);
return outputTensorViews;
}
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index ae5bf68483b46..0d699326366b3 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
+
import {tensorDataTypeEnumToString} from '../../wasm-common';
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';
@@ -35,6 +37,7 @@ export class ProgramManager {
run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[],
inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number],
uniformBufferBinding: GPUBindingResource|undefined): void {
+ TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
@@ -128,11 +131,13 @@ export class ProgramManager {
if (this.backend.pendingDispatchNumber >= 16) {
this.backend.flush();
}
+ TRACE_FUNC_END(buildArtifact.programInfo.name);
}
dispose(): void {
// this.repo.forEach(a => this.glContext.deleteProgram(a.program));
}
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
+ TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
@@ -147,6 +152,7 @@ export class ProgramManager {
const computePipeline = device.createComputePipeline(
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});
+ TRACE_FUNC_END(programInfo.name);
return {programInfo, computePipeline};
}
diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts
index b62287483208a..e17ec37e3e612 100644
--- a/js/web/lib/wasm/session-handler-inference.ts
+++ b/js/web/lib/wasm/session-handler-inference.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {readFile} from 'node:fs/promises';
-import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
+import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
@@ -54,6 +54,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
+ TRACE_FUNC_BEGIN();
let model: Parameters<typeof createSession>[0];
if (typeof pathOrBuffer === 'string') {
@@ -70,6 +71,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
[this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
+ TRACE_FUNC_END();
}
async dispose(): Promise<void> {
@@ -78,6 +80,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise<SessionHandler.ReturnType> {
+ TRACE_FUNC_BEGIN();
const inputArray: Tensor[] = [];
const inputIndices: number[] = [];
Object.entries(feeds).forEach(kvp => {
@@ -115,6 +118,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
for (let i = 0; i < results.length; i++) {
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
+ TRACE_FUNC_END();
return resultMap;
}
From 3b8b9147fa4f8f6348e171a257bbc325744301df Mon Sep 17 00:00:00 2001
From: Jiajie Hu
Date: Thu, 4 Jan 2024 06:15:26 +0800
Subject: [PATCH 05/20] [js/webgpu] Mitigate floating point accuracy issue in
Resize (#18956)
### Description
The patch fixes a floating point accuracy issue in Resize by preferring
integer indices and integer arithmetic where possible.
### Motivation and Context
Model test `test_resize_upsample_sizes_nearest_floor_align_corners` was
observed to be failing on certain platforms. The root cause is the
inaccurate floating point evaluation of 21 / 7 (2.999... vs 3), which
results in the wrong input element to be indexed (floor(2.999...) vs
floor(3)).
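For reference, a small Python sketch of the same whole/fractional split used in the fixed WGSL `align_corners` branch (the function name is illustrative):

```python
def original_coord_align_corners(x_resized: int, length_resized: int, length_original: int) -> float:
    # Compute the whole part with integer arithmetic and only the remainder in
    # floating point, so an index that should map exactly to 3 cannot come out
    # as 2.999... and be floored to the wrong input element.
    if length_resized == 1:
        return 0.0
    numerator = x_resized * (length_original - 1)
    whole = numerator // (length_resized - 1)
    fract = (numerator % (length_resized - 1)) / (length_resized - 1)
    return float(whole) + fract
```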
---
js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 83 ++++++++++++-----------
1 file changed, 45 insertions(+), 38 deletions(-)
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index bea3e8625b41b..d359580904a7b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -110,41 +110,48 @@ const validateInputs =
const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
- `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
- lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
+ `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32,
+ lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
- return 'return xResized / xScale;';
+ return `return ${dType}(xResized) / xScale;`;
case 'pytorch_half_pixel':
- return 'if (lengthResized > 1) { \
- return (xResized + 0.5) / xScale - 0.5; \
- } else { \
- return 0.0; \
- }';
+ return `if (lengthResized > 1) {
+ return (${dType}(xResized) + 0.5) / xScale - 0.5;
+ } else {
+ return 0.0;
+ }`;
case 'tf_half_pixel_for_nn':
- return 'return (xResized + 0.5) / xScale;';
+ return `return (${dType}(xResized) + 0.5) / xScale;`;
case 'align_corners':
- return 'if (lengthResized == 1) { \
- return 0.0; \
- } else { \
- return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
- }';
+ return `if (lengthResized == 1) {
+ return 0.0;
+ } else {
+ // The whole part and the fractional part are calculated separately due to inaccuracy of floating
+ // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
+ // offset-by-one error later in floor().
+ let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
+ let fract =
+ ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
+ return whole + fract;
+ }`;
case 'tf_crop_and_resize':
- return `if (lengthResized > 1) { \
- return roiStart * (lengthOriginal - 1) + \
- (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
- } else { \
- return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
+ return `if (lengthResized > 1) {
+ return roiStart * ${dType}(lengthOriginal - 1) +
+ (${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
+ ${dType}(lengthResized - 1);
+ } else {
+ return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
}`;
case 'half_pixel_symmetric':
- return [
- 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
- 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
- 'return offset + ((xResized + 0.5) / xScale) - 0.5;'
- ].join('\n');
+ return `const outputWidth = xScale * ${dType}(lengthResized);
+ const adjustment = ${dType}(lengthResized) / outputWidth;
+ const center = ${dType}(lengthOriginal) / 2;
+ const offset = center * (1 - adjustment);
+ return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
case 'half_pixel':
- return 'return ((xResized + 0.5) / xScale) - 0.5;';
+ return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
@@ -254,15 +261,15 @@ const calculateOriginalIndicesFromOutputIndices =
output.type.value}, ${outputShape.length}> {
var original_indices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var output_index = ${output.indicesGet('output_indices', 'i')};
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
if (scale == 1.0) {
- original_indices[i] = output_index;
+ original_indices[i] = ${output.type.value}(output_index);
} else {
- var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
- var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
+ var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
}
@@ -276,23 +283,23 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
var input_indices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var output_index = ${output.indicesGet('output_indices', 'i')};
var input_index: u32;
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
if (scale == 1.0) {
- input_index = u32(output_index);
+ input_index = output_index;
} else {
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
- var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
- var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
+ var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
- if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) {
+ if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) {
if (original_idx < 0) {
input_index = 0;
- } else if (original_idx > (input_shape_i - 1)) {
- input_index = u32(input_shape_i) - 1;
+ } else if (original_idx > ${output.type.value}(input_shape_i - 1)) {
+ input_index = input_shape_i - 1;
} else {
input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
}
@@ -391,8 +398,8 @@ const bicubicInterpolation =
fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
output.type.indices}) -> ${dType} {
var output_index = ${output.indicesGet('output_indices', idx)};
- var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]},
- ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
+ var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]},
+ ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
From 8e9188e265622bc811bab735c08135eec6cbd6fc Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Thu, 4 Jan 2024 11:12:48 +1000
Subject: [PATCH 06/20] Add SessionOptions use_deterministic_compute to the C
and C++ APIs. (#18944)
### Description
SessionOptions use_deterministic_compute can be set via the Python API.
Users requested the ability to set it via the C API as well.
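For context, a sketch of the pre-existing Python path, which toggles the same SessionOptions flag that the new `SetDeterministicCompute` C/C++ API writes (the model path is a placeholder):

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.use_deterministic_compute = True  # same flag the new C API setter toggles
# sess = ort.InferenceSession("model.onnx", sess_options=so)  # "model.onnx" is a placeholder
```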
### Motivation and Context
#17416
---
.../onnxruntime/core/session/onnxruntime_c_api.h | 15 ++++++++++++++-
.../core/session/onnxruntime_cxx_api.h | 1 +
.../core/session/onnxruntime_cxx_inline.h | 6 ++++++
onnxruntime/core/session/abi_session_options.cc | 7 +++++++
onnxruntime/core/session/onnxruntime_c_api.cc | 1 +
onnxruntime/core/session/ort_apis.h | 1 +
.../test/shared_lib/test_session_options.cc | 6 ++++++
7 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index dbd5ad41255fa..06fef6bf72cc9 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -29,8 +29,9 @@
*/
#pragma once
-#include
+#include
#include
+#include
#include
/** \brief The API version defined in this header
@@ -4515,6 +4516,18 @@ struct OrtApi {
* \since Version 1.17.
*/
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
+
+ /** \brief Set whether to use deterministic compute.
+ *
+ * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible.
+ * Note that this most likely will have a performance cost.
+ *
+ * \param[in] options
+ * \param[in] value
+ *
+ * \since Version 1.17.
+ */
+ ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 92c25d8688b66..16d9451624533 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -845,6 +845,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl {
SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
+ SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 860a27fc73f79..63e55603736b6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -656,6 +656,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetGraphOptimizationLevel(G
return *this;
}
+template <typename T>
+inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
+  ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
+  return *this;
+}
+
template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index fb314b161f1ad..e2084e9ef4f00 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -293,3 +293,10 @@ ORT_API_STATUS_IMPL(OrtApis::AddExternalInitializers, _In_ OrtSessionOptions* op
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External initializers are not supported in this build");
#endif
}
+
+ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value) {
+ API_IMPL_BEGIN
+ options->value.use_deterministic_compute = value;
+ return nullptr;
+ API_IMPL_END
+}
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 9f8786b727ac1..76a8a778025e1 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2721,6 +2721,7 @@ static constexpr OrtApi ort_api_1_to_17 = {
&OrtApis::ShapeInferContext_SetOutputTypeShape,
&OrtApis::SetSymbolicDimensions,
&OrtApis::ReadOpAttr,
+ &OrtApis::SetDeterministicCompute,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 09c83219ad2c8..c9e4074a1afe2 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -500,5 +500,6 @@ ORT_API_STATUS_IMPL(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferCont
ORT_API_STATUS_IMPL(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info);
ORT_API_STATUS_IMPL(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length);
ORT_API_STATUS_IMPL(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
+ORT_API_STATUS_IMPL(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);
} // namespace OrtApis
diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc
index 64d9803f8bf8e..d706b74f06141 100644
--- a/onnxruntime/test/shared_lib/test_session_options.cc
+++ b/onnxruntime/test/shared_lib/test_session_options.cc
@@ -15,6 +15,12 @@ TEST(CApiTest, session_options_graph_optimization_level) {
options.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
}
+TEST(CApiTest, session_options_deterministic_compute) {
+ // Manual validation currently. Check that SetDeterministicCompute in abi_session_options.cc is hit.
+ Ort::SessionOptions options;
+ options.SetDeterministicCompute(true);
+}
+
#if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) && !defined(ORT_NO_EXCEPTIONS)
TEST(CApiTest, session_options_oversized_affinity_string) {
From 5fade70b5052efae1553e8e3ac0b06a527877ef0 Mon Sep 17 00:00:00 2001
From: JJ <103335846+computerscienceiscool@users.noreply.github.com>
Date: Wed, 3 Jan 2024 17:26:25 -0800
Subject: [PATCH 07/20] Update README.md (#18963)
Fixed a small spelling error.
### Description
Small spelling error fix.
### Motivation and Context
---
.../wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
index 2e6392aada454..50b0841a0200a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
@@ -157,7 +157,7 @@ const createConvTranspose2DOpProgramShaderSource =
}
for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
- let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : '0.0'};
+ let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : `vec4<${dataType}>(0.0)`};
${output.set('batch', 'r', 'c + i', 'd1', 'value')};
}
}`;
@@ -174,7 +174,7 @@ const createConvTranspose2DOpProgramShaderSource =
let wOutChannel = d1 - groupId * ${outputChannelsPerGroup};
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
- var dotProd = 0.0;
+ var dotProd = ${dataType}(0.0);
for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) {
if (wR % dilations.x != 0) {
continue;
@@ -209,7 +209,7 @@ const createConvTranspose2DOpProgramShaderSource =
}
}
}
- let value = dotProd + ${hasBias ? 'bias[d1]' : '0.0'};
+ let value = dotProd + ${hasBias ? 'bias[d1]' : `${dataType}(0.0)`};
${output.setByOffset('global_idx', 'value')};
`;
From b18abaaa2c37e251eb639d740057aaa75821ba96 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 4 Jan 2024 08:06:55 -0800
Subject: [PATCH 11/20] [js/web] wait for threadpool initialization (#18952)
### Description
A replacement of #18683; tries to resolve #18689.
By specifying the "-s PTHREAD_POOL_SIZE" flag in Emscripten, the thread pool is
forced to initialize before the WebAssembly instance becomes available.
---
cmake/onnxruntime_webassembly.cmake | 1 +
js/web/lib/wasm/binding/ort-wasm.d.ts | 1 +
js/web/lib/wasm/wasm-factory.ts | 1 +
3 files changed, 3 insertions(+)
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 9014089cb6112..1dc982aea5f2f 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -281,6 +281,7 @@ else()
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORT_NAME=ortWasmThreaded"
"SHELL:-s DEFAULT_PTHREAD_STACK_SIZE=131072"
+ "SHELL:-s PTHREAD_POOL_SIZE=Module[\\\"numThreads\\\"]"
)
else()
target_link_options(onnxruntime_webassembly PRIVATE
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 00431a4e86d5b..6c55dcc1bfd32 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -111,6 +111,7 @@ export interface OrtWasmModule extends EmscriptenModule {
// #endregion
// #region config
+ numThreads?: number;
mainScriptUrlOrBlob?: string|Blob;
// #endregion
diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts
index 2b7d492cc70ba..81508a253ce8b 100644
--- a/js/web/lib/wasm/wasm-factory.ts
+++ b/js/web/lib/wasm/wasm-factory.ts
@@ -167,6 +167,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise
};
if (!BUILD_DEFS.DISABLE_WASM_THREAD && useThreads) {
+ config.numThreads = numThreads;
if (typeof Blob === 'undefined') {
config.mainScriptUrlOrBlob = path.join(__dirname, 'ort-wasm-threaded.js');
} else {
From 011b562b51de61c81ab8013415c7d40f677d0825 Mon Sep 17 00:00:00 2001
From: Changming Sun
Date: Thu, 4 Jan 2024 10:41:28 -0800
Subject: [PATCH 12/20] Update c# dependencies (#18995)
### Description
Update c# dependencies
---
csharp/ApiDocs/ApiDocs.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj | 2 +-
.../EndToEndTests.Mobile.Automation.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.EndToEndTests.csproj | 4 ++--
.../Microsoft.ML.OnnxRuntime.Tests.Common.csproj | 4 ++--
.../Microsoft.ML.OnnxRuntime.Tests.Devices.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.Tests.Droid.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.Tests.iOS.csproj | 2 +-
csharp/tools/MauiModelTester/MauiModelTester.csproj | 2 +-
.../Microsoft.ML.OnnxRuntime.PerfTool.csproj | 2 +-
11 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/csharp/ApiDocs/ApiDocs.csproj b/csharp/ApiDocs/ApiDocs.csproj
index 994e57913cf47..6081c444ba1af 100644
--- a/csharp/ApiDocs/ApiDocs.csproj
+++ b/csharp/ApiDocs/ApiDocs.csproj
@@ -7,7 +7,7 @@
-
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj
index 3d35de1dfc6aa..feb6bcd46d63e 100644
--- a/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj
+++ b/csharp/sample/Microsoft.ML.OnnxRuntime.FasterRcnnSample/Microsoft.ML.OnnxRuntime.FasterRcnnSample.csproj
@@ -7,7 +7,7 @@
-
+
diff --git a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj
index af8fa611a5010..bedf14680826c 100644
--- a/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj
+++ b/csharp/sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample/Microsoft.ML.OnnxRuntime.ResNet50v2Sample.csproj
@@ -7,7 +7,7 @@
-
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj
index b90929ad6d1c1..7bda34d266295 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Mobile/EndToEndTests.Mobile.Automation/EndToEndTests.Mobile.Automation.csproj
@@ -6,7 +6,7 @@
-
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
index 1c9827c5bac62..5ff924bcf82f3 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/Microsoft.ML.OnnxRuntime.EndToEndTests.csproj
@@ -37,10 +37,10 @@
-
+
-
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj
index ee81ab77432d1..ab27d62c3bf3b 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Microsoft.ML.OnnxRuntime.Tests.Common.csproj
@@ -119,8 +119,8 @@
-
-
+
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj
index 37e83be5e33a1..40f6d453c6a90 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Devices/Microsoft.ML.OnnxRuntime.Tests.Devices.csproj
@@ -11,6 +11,6 @@
-
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj
index 11855032584a3..ef7e0825e919e 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Droid/Microsoft.ML.OnnxRuntime.Tests.Droid.csproj
@@ -134,7 +134,7 @@
5.0.0.2083
-
+
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj
index 352de5db00920..56e65833724f6 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.iOS/Microsoft.ML.OnnxRuntime.Tests.iOS.csproj
@@ -99,7 +99,7 @@
2.4.1
-
+
5.0.0.2083
diff --git a/csharp/tools/MauiModelTester/MauiModelTester.csproj b/csharp/tools/MauiModelTester/MauiModelTester.csproj
index b0a17978328c0..a374c2933ce8f 100644
--- a/csharp/tools/MauiModelTester/MauiModelTester.csproj
+++ b/csharp/tools/MauiModelTester/MauiModelTester.csproj
@@ -51,7 +51,7 @@
-
+
diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj
index 24f0d14ad9903..e0420a6ed0456 100644
--- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj
+++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/Microsoft.ML.OnnxRuntime.PerfTool.csproj
@@ -80,7 +80,7 @@
-
+
From 889b1ef2d1eca39e1216e0ca06684accf5908500 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 4 Jan 2024 20:27:46 +0100
Subject: [PATCH 13/20] Fix schema type constraint for custom operators
(#17497)
### Description
onnxruntime may raise a "type inference failed" error when a custom
operator sets IsHomogeneous to false in its schema. This change makes
sure that the TypeInferenceFunction and the schema type constraints are
aligned to prevent that from happening.
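
As a rough, hypothetical illustration of the flow this affects (not part of the
patch; the library path and model name below are placeholders), a custom-op
library registered through the Python API now gets schemas whose type
constraints follow the kernels' declared input/output types:

```python
# Hypothetical sketch only: "./libmy_custom_ops.so" and "model_with_custom_op.onnx"
# are placeholder names, not files shipped with this change.
import numpy as np
import onnxruntime as ort

sess_opts = ort.SessionOptions()
# The shared library exports RegisterCustomOps; its operators' schemas now carry
# type constraints derived from the kernels instead of allowing every tensor type.
sess_opts.register_custom_ops_library("./libmy_custom_ops.so")

sess = ort.InferenceSession(
    "model_with_custom_op.onnx",
    sess_opts,
    providers=["CPUExecutionProvider"],
)
print(sess.run(None, {"X": np.zeros((2, 2), dtype=np.float32)}))
```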
---------
Co-authored-by: Xavier Dupre
Co-authored-by: Scott McKay
---
cmake/onnxruntime_unittests.cmake | 32 ++++
onnxruntime/core/graph/graph.cc | 8 +-
onnxruntime/core/session/custom_ops.cc | 145 ++++++++++++------
onnxruntime/test/shared_lib/test_inference.cc | 43 ++++++
.../custom_op_local_function.cc | 58 +++++++
.../custom_op_local_function.def | 3 +
.../custom_op_local_function.h | 15 ++
.../custom_op_local_function.lds | 6 +
.../custom_op_test_local_function.py | 47 ++++++
.../custom_ops_type_inference_fails_0.onnx | Bin 0 -> 2086 bytes
.../custom_op_local_function/dummy_gemm.cc | 119 ++++++++++++++
.../custom_op_local_function/dummy_gemm.h | 51 ++++++
12 files changed, 478 insertions(+), 49 deletions(-)
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.cc
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.def
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.h
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.lds
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_op_test_local_function.py
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.cc
create mode 100644 onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.h
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 7c8c70f913dca..ed878e16c546e 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1662,6 +1662,38 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
${ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG})
endif()
+if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_MINIMAL_BUILD_CUSTOM_OPS))
+
+ file(GLOB_RECURSE custom_op_local_function_test_library_src
+ "${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.cc"
+ "${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.h"
+ "${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.cc"
+ "${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.h"
+ )
+
+ onnxruntime_add_shared_library_module(custom_op_local_function ${custom_op_local_function_test_library_src})
+
+ onnxruntime_add_include_to_target(custom_op_local_function onnxruntime_common GTest::gtest GTest::gmock)
+ target_include_directories(custom_op_local_function PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session
+ ${REPO_ROOT}/include/onnxruntime/core/common)
+
+ if(UNIX)
+ if (APPLE)
+ set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip")
+ else()
+ string(CONCAT ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG
+ "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.lds "
+ "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
+ endif()
+ else()
+ set(ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG
+ "-DEF:${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.def")
+ endif()
+
+ set_property(TARGET custom_op_local_function APPEND_STRING PROPERTY LINK_FLAGS
+ ${ONNXRUNTIME_CUSTOM_OP_lOCAL_FUNCTION_TEST_LIB_LINK_FLAG})
+endif()
+
if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD)
set (onnxruntime_logging_apis_test_SRC
${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc)
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 904b4263f4e67..f71b7ecebcf1a 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -2367,8 +2367,14 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
inferred_type = existing_type;
} else {
// This should not happen: indicates incompleteness in ONNX inference.
+ std::stringstream ss;
+ ss << "index=" << operand_index;
+ for (auto it = op_formal_parameter.GetTypes().begin(); it != op_formal_parameter.GetTypes().end(); ++it) {
+ ss << "," << *(*it);
+ }
Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL,
- "Node (" + node_name + ") output arg (" + output_def->Name() + ") type inference failed");
+ "Node (" + node_name + ") Op (" + node.OpType() + ") output arg (" +
+ output_def->Name() + ") type inference failed, inferred types: " + ss.str());
return status;
}
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index b827c28f129b1..eea675eb0193a 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -5,7 +5,10 @@
#pragma warning(disable : 4267)
#endif
+#include
#include
+#include
+#include
#include "core/common/gsl.h"
#include "core/framework/data_types.h"
@@ -755,14 +758,16 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
return KernelCreateInfo(def_builder.Build(), kernel_create_fn);
}
-ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustomOp* op) {
- const size_t input_count = op->GetInputTypeCount(op);
- const size_t output_count = op->GetOutputTypeCount(op);
+ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector<const OrtCustomOp*>& ops) {
+  // The function registers the first schema, assuming all the others are the same except for the type constraints.
+  ORT_ENFORCE(ops.size() > 0, "No kernels to register.");
int undefined = 0;
+ // Creation of the schema for the first kernel in ops.
+ const OrtCustomOp* op = *ops.begin();
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);
- for (size_t i = 0; i < input_count; i++) {
+ auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) {
onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
bool is_homogeneous = true;
int min_arity = 1;
@@ -770,51 +775,79 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom
// The OrtCustomOp interface did not support the methods to query input/output characteristics before
// ORT API version 8. So, query the relevant methods ONLY from API version 8 onwards.
if (op->version >= min_ort_version_with_optional_io_support) {
- const auto characteristic = op->GetInputCharacteristic(op, i);
+ const auto characteristic = is_input ? op->GetInputCharacteristic(op, i) : op->GetOutputCharacteristic(op, i);
// Support for optional and variadic inputs/output was added in versions 8 and 14, respectively.
if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
option = onnx::OpSchema::FormalParameterOption::Optional;
} else if ((op->version >= min_ort_version_with_variadic_io_support) &&
(characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC)) {
- ORT_ENFORCE(i == input_count - 1, "Only the last input to a custom op may be marked variadic.");
+ ORT_ENFORCE(i == count - 1, "Only the last ", (is_input ? "input" : "output"),
+ " to a custom op may be marked variadic.");
option = onnx::OpSchema::FormalParameterOption::Variadic;
- min_arity = op->GetVariadicInputMinArity(op);
-          is_homogeneous = static_cast<bool>(op->GetVariadicInputHomogeneity(op));
+ min_arity = is_input ? op->GetVariadicInputMinArity(op) : op->GetVariadicOutputMinArity(op);
+          is_homogeneous = static_cast<bool>(is_input
+                                                 ? op->GetVariadicInputHomogeneity(op)
+                                                 : op->GetVariadicOutputHomogeneity(op));
}
}
- const auto type = op->GetInputType(op, i);
- if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
- undefined++;
+ // The loop goes through all operators sharing the same schema to build
+ // the minimal type constraints for all of them. All kernels must have
+ // the same number of inputs / outputs among themselves to be able to build
+ // the type constraints. Any kind of incompatibility between a schema and
+ // a kernel is checked by method IsCompatible once the schema is created
+ // by this method.
+      std::unordered_set<ONNXTensorElementDataType> all_types;
+ for (auto o : ops) {
+        ORT_ENFORCE(static_cast<size_t>(i) != (is_input ? o->GetInputTypeCount(o) : o->GetOutputTypeCount(o)),
+                    "Another version of operator '", schema.Name(),
+                    "' has a different number of ", (is_input ? "inputs" : "outputs"),
+                    ". onnxruntime allows the overloading of an operator "
+                    "if all versions have the same number of declared ",
+                    (is_input ? "inputs" : "outputs"), ".");
+ const auto type = is_input ? o->GetInputType(o, i) : o->GetOutputType(o, i);
+ if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
+ // If 'type' is undefined, all types are allowed regardless of what other versions of the same operator
+        // define. In that case, all_types is cleared; that is the convention the code following this loop uses
+        // to declare every type as a possible type.
+ all_types.clear();
+ break;
+ }
+ all_types.insert(type);
}
- std::string input_name = "Input" + std::to_string(i);
-    schema.Input(gsl::narrow_cast<int>(i), input_name, "", input_name, option, is_homogeneous, min_arity);
- // support all types as input here in schema, and handle the type inference in TypeShapeInference func
- schema.TypeConstraint(input_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
- }
- for (size_t i = 0; i < output_count; i++) {
- onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
- bool is_homogeneous = true;
- int min_arity = 1;
-
- // The OrtCustomOp interface did not support the methods to query input/output characteristics before
- // ORT API version 8. So, query the relevant methods ONLY from API version 8 onwards.
- if (op->version >= min_ort_version_with_optional_io_support) {
- const auto characteristic = op->GetOutputCharacteristic(op, i);
+ std::string prefix = is_input ? "Input" : "Output";
+ std::string name = prefix + std::to_string(i);
+ if (is_input) {
+        schema.Input(gsl::narrow_cast<int>(i), name, "", name, option, is_homogeneous, min_arity);
+      } else {
+        schema.Output(gsl::narrow_cast<int>(i), name, "", name, option, is_homogeneous, min_arity);
+ }
- // Support for optional and variadic inputs/output was added in versions 8 and 14, respectively.
- if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
- option = onnx::OpSchema::FormalParameterOption::Optional;
- } else if ((op->version >= min_ort_version_with_variadic_io_support) &&
- (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC)) {
- ORT_ENFORCE(i == output_count - 1, "Only the last output to a custom op may be marked variadic.");
- option = onnx::OpSchema::FormalParameterOption::Variadic;
- min_arity = op->GetVariadicOutputMinArity(op);
-          is_homogeneous = static_cast<bool>(op->GetVariadicOutputHomogeneity(op));
+ if (!all_types.empty()) {
+        // If all_types is not empty, only the types in this container are allowed for this input or output.
+        std::vector<std::string> types;
+        for (auto type : all_types) {
+          const ONNX_NAMESPACE::TypeProto* type_proto =
+              DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->GetTypeProto();
+ types.push_back(*ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto));
}
+ schema.TypeConstraint(name, types, "defined list of types");
+ } else {
+ // all_types is empty. As mentioned in the previous loop, all types are allowed.
+ schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
+ undefined++;
}
+ };
+
+ const size_t input_count = op->GetInputTypeCount(op);
+ for (size_t i = 0; i < input_count; i++) {
+    create_type_constraint(op, static_cast<int>(input_count), static_cast<int>(i), true);
+ }
+
+ const size_t output_count = op->GetOutputTypeCount(op);
+ for (size_t i = 0; i < output_count; i++) {
const auto type = op->GetOutputType(op, i);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) {
@@ -826,11 +859,9 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom
"cannot be inferred without which model loading cannot proceed.");
}
}
- std::string output_name = "Output" + std::to_string(i);
-    schema.Output(gsl::narrow_cast<int>(i), output_name, "", output_name, option, is_homogeneous, min_arity);
- // support all types as input here in schema, and handle the type inference in TypeShapeInference func
- schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
+    create_type_constraint(op, static_cast<int>(output_count), static_cast<int>(i), false);
}
+
schema.SetDomain(domain);
if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) {
schema.SinceVersion(op->GetStartVersion(op));
@@ -905,7 +936,7 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same homogeneity");
- ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
+ ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicOutputMinArity(op),
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same arity");
@@ -994,18 +1025,36 @@ common::Status CreateCustomRegistry(gsl::span op_domai
}
}
+  // domain_kernels aggregates all custom operators by name.
+  std::unordered_map<std::string, std::vector<const OrtCustomOp*>> domain_kernels;
for (const auto* op : domain->custom_ops_) {
// define kernel
- auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
- kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
- ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
- // define schema
- auto schema_map_iter = schema_map.find(op->GetName(op));
- if (schema_map_iter == schema_map.end()) {
- auto schema = CreateSchema(domain->domain_, op);
- schema_map.emplace(schema.Name(), schema);
+ auto it = domain_kernels.find(op->GetName(op));
+ if (it == domain_kernels.end()) {
+ domain_kernels[op->GetName(op)] = {op};
} else {
- ORT_RETURN_IF_ERROR(IsCompatible(schema_map_iter->second, op));
+ domain_kernels[op->GetName(op)].push_back(op);
+ }
+ }
+
+ // Creation of the schemas, one per unique name.
+ for (auto& [name, ops] : domain_kernels) {
+ auto schema = CreateSchema(domain->domain_, ops);
+ // schema.Name() is equal to ops[0]->GetName(ops[0]) and op->GetName(op) is the value
+ // used as a key for dictionary domain_kernels, therefore name == schema.Name().
+ schema_map.emplace(schema.Name(), schema);
+
+    // This loop checks that all custom operators sharing the same name are compatible with the defined schema.
+ for (const auto* op : ops) {
+ // define kernel
+ auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
+ kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
+ ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
+      // If IsCompatible fails, the custom operators registered under the name
+      // 'op->GetName(op)' are not compatible among themselves.
+      // They must have the same number of inputs and outputs and the same
+      // characteristics (optional, variadic, ...); only the types may differ.
+ ORT_RETURN_IF_ERROR(IsCompatible(schema, op));
}
}
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 7dee0bc41a6f3..35c6b308e8fea 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -1264,6 +1264,49 @@ TEST(CApiTest, test_custom_op_get_const_input) {
}
#endif
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
+#if defined(__ANDROID__)
+// Disable on android because custom op libraries are not copied to the emulator.
+TEST(CApiTest, DISABLED_test_custom_op_local_function) {
+#else
+TEST(CApiTest, test_custom_op_local_function) {
+#endif // defined(__ANDROID__)
+ const auto* model_path = TSTR("testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx");
+
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names;
+
+ // input 0 (float type)
+ input_names.emplace_back("X");
+  std::vector<float> input_0_data = {1.0f, 2.0f, 3.0f, 4.0f};
+  std::vector<int64_t> input_0_dims = {2, 2};
+  ort_inputs.emplace_back(
+      Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_0_data.data()),
+                                      input_0_data.size(), input_0_dims.data(), input_0_dims.size()));
+ const char* output_name = "Y";
+
+ const ORTCHAR_T* lib_name;
+#if defined(_WIN32)
+ lib_name = ORT_TSTR("custom_op_local_function.dll");
+#elif defined(__APPLE__)
+ lib_name = ORT_TSTR("libcustom_op_local_function.dylib");
+#else
+  lib_name = ORT_TSTR("./libcustom_op_local_function.so");
+#endif
+
+ Ort::SessionOptions session_opts;
+
+ session_opts.RegisterCustomOpsLibrary(lib_name);
+
+ Ort::Session session(*ort_env, model_path, session_opts);
+ auto default_allocator = std::make_unique();
+
+ session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
+ &output_name, 1);
+}
+#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
+
#if defined(USE_OPENVINO) && (!defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS))
TEST(CApiTest, test_custom_op_openvino_wrapper_library) {
// Tests a custom operator that wraps an OpenVINO MNIST model (.xml and .bin files serialized into node attributes).
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.cc b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.cc
new file mode 100644
index 0000000000000..38eb5d3ca9072
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.cc
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "custom_op_local_function.h"
+
+#include
+#include
+#include
+#include
+
+#include "core/common/common.h"
+#include "core/framework/ortdevice.h"
+#include "core/framework/ortmemoryinfo.h"
+#include "dummy_gemm.h"
+
+static const char* c_OpDomain = "onnx_extented.ortops.tutorial.cpu";
+
+static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
+  static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
+ static std::mutex ort_custom_op_domain_mutex;
+ std::lock_guard lock(ort_custom_op_domain_mutex);
+ ort_custom_op_domain_container.push_back(std::move(domain));
+}
+
+OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
+ const OrtApiBase* api_base) {
+ Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
+ Ort::UnownedSessionOptions session_options(options);
+
+  // An instance remains available until onnxruntime unloads the library.
+ static Cpu::CustomGemmOp c_CustomGemmFloat(
+ "CustomGemmFloat", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+ false);
+ static Cpu::CustomGemmOp c_CustomGemmFloat8E4M3FN(
+ "CustomGemmFloat8E4M3FN", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN,
+ ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+ false);
+ OrtStatus* result = nullptr;
+
+ ORT_TRY {
+ Ort::CustomOpDomain domain{c_OpDomain};
+
+ domain.Add(&c_CustomGemmFloat);
+ domain.Add(&c_CustomGemmFloat8E4M3FN);
+
+ session_options.Add(domain);
+ AddOrtCustomOpDomainToContainer(std::move(domain));
+ }
+ ORT_CATCH(const std::exception& e) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ Ort::Status status{e};
+ result = status.release();
+ });
+ }
+
+ return result;
+}
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.def b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.def
new file mode 100644
index 0000000000000..2bbbe3fe3ccb2
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.def
@@ -0,0 +1,3 @@
+LIBRARY "custom_op_local_function.dll"
+EXPORTS
+ RegisterCustomOps @1
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.h b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.h
new file mode 100644
index 0000000000000..900e47908b588
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "onnxruntime_c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.lds b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.lds
new file mode 100644
index 0000000000000..bb5d118c7ca22
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.lds
@@ -0,0 +1,6 @@
+VERS_1.0.0 {
+ global:
+ RegisterCustomOps;
+ local:
+ *;
+};
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_op_test_local_function.py b/onnxruntime/test/testdata/custom_op_local_function/custom_op_test_local_function.py
new file mode 100644
index 0000000000000..3e353d4142554
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/custom_op_test_local_function.py
@@ -0,0 +1,47 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+import os
+import sys
+import unittest
+
+import numpy as np
+import onnx
+
+from onnxruntime import InferenceSession, SessionOptions
+
+
+class TestOnnxToolsGraph(unittest.TestCase):
+ def test_basic_all(self):
+ if sys.platform.startswith("win"):
+ shared_library = "custom_op_local_function.dll"
+ elif sys.platform.startswith("darwin"):
+ shared_library = "libcustom_op_local_function.dylib"
+ else:
+ shared_library = "./libcustom_op_local_function.so"
+ if not os.path.exists(shared_library):
+ raise FileNotFoundError(f"Unable to find '{shared_library}'")
+
+ filename = "custom_ops_type_inference_fails_0.onnx"
+
+ with open(os.path.join(os.path.dirname(__file__), filename), "rb") as f:
+ onxo = onnx.load(f)
+ d = onxo.opset_import.add()
+ d.domain = "ai.onnx.ml"
+ d.version = 2
+
+ sess_opts = SessionOptions()
+ sess_opts.register_custom_ops_library(shared_library)
+
+ sess = InferenceSession(
+ onxo.SerializeToString(),
+ sess_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ x = np.arange(2**2).reshape((2,) * 2).astype(np.float32)
+ t = np.arange(8).reshape((2, 4)).astype(np.float32)
+ got = sess.run(None, dict(X=x))[0]
+ np.testing.assert_allclose(t, got, atol=1e-5)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/onnxruntime/test/testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx b/onnxruntime/test/testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..8116ec338064567cea06fafe45168567813071ed
GIT binary patch
literal 2086
zcmah}|4!pZ5cb+m;tai-=I<-p}o%9agw8mMoQutf3x%L%zUG1R9cX9>6uiM%wJS!
z0y(DYvEAOFrDHpBoemS`byt71a}_#)?|$8LLhfI)eLrMQY?MLf(fsrckxkl(5MR9z
zfT|Y-jvvAw1k%%>
z@-wDDi=#IO&i7F~FA2va6nX4~$=3U(m73;A+O;e
znpIl3Mn`+B7mIl>rYb~NCHzq9JorLfE=gh=2QYLi7P*=rUoJwSh^Y;@GjHg9#A#}oiS1){#G@YhNA+xC}+`7_?
zIHpRCA=n*DHF^gr3o7^9df}Th7Bh1W&=6l*>L)18nCS{mmT5q4Q>EDp^ztF|dM<1A
z0-=+0#=4##B$*PPWqe$!?6B}&(D99v=_0m8cW`dEqmqNn5_5kr<-+9eCyP+F-EH>t0
z;+$P2;`~paC-vz%tEE-B&3Av%#tkW&
zV{>X&VJs6>Mb>*OvjiyyvLa9g1I9Y|WZ(zkr{e>jRj`VEhjBNIrizj){o(stc`p^_
z-YsPv-l4y@tTo2x=UVQ0rRD`(<*>wPOQmuo6
zv~kjhuWKNCxZHv@7`~%rToHD*BZ@e$uEUK9P@TR%(8rSK&Im-6esVS%&lM0hI(e*@
zhpQY_Hr(QMS?qCK7-^1-Gg7Dx*cX$I?=lZ>C;rV17&vaRoM`)@)47l56J)|;7zc{s
M$%T|n&0SOSFP?~MApigX
literal 0
HcmV?d00001
diff --git a/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.cc b/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.cc
new file mode 100644
index 0000000000000..4591dd89d2e35
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.cc
@@ -0,0 +1,119 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+
+#include "dummy_gemm.h"
+
+#ifndef ORT_ENFORCE
+#define ORT_ENFORCE(cond, ...) \
+ if (!(cond)) ORT_CXX_API_THROW("Initialization failed.", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+#endif
+
+namespace Cpu {
+
+void* CustomGemmOp::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
+  return std::make_unique<CustomGemmKernel>(api, info).release();
+}
+
+const char* CustomGemmOp::GetName() const { return op_name_; }
+
+const char* CustomGemmOp::GetExecutionProviderType() const {
+ return "CPUExecutionProvider";
+}
+
+size_t CustomGemmOp::GetInputTypeCount() const { return 6; }
+
+ONNXTensorElementDataType CustomGemmOp::GetInputType(size_t index) const {
+ switch (index) {
+ case 0: // A
+ case 1: // B
+ return ab_type_;
+ case 2: // C
+ return c_type_;
+ case 3: // scale A
+ case 4: // scale B
+ case 5: // scale Y
+ return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+ default:
+ ORT_CXX_API_THROW("Input index is out of boundary.", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+}
+
+OrtCustomOpInputOutputCharacteristic CustomGemmOp::GetInputCharacteristic(size_t index) const {
+ switch (index) {
+ case 0:
+ case 1:
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+ case 2:
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+ case 3:
+ case 4:
+ case 5:
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+ default:
+ ORT_CXX_API_THROW("Input index is out of boundary.", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+}
+
+size_t CustomGemmOp::GetOutputTypeCount() const { return 1; }
+
+ONNXTensorElementDataType CustomGemmOp::GetOutputType(size_t index) const {
+ // D, scale D
+ switch (index) {
+ case 0:
+ return d_type_;
+ default:
+ ORT_CXX_API_THROW("Output index is out of boundary.", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+}
+
+OrtCustomOpInputOutputCharacteristic CustomGemmOp::GetOutputCharacteristic(size_t index) const {
+ switch (index) {
+ case 0:
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+ default:
+ ORT_CXX_API_THROW("Output index is out of boundary.", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+}
+
+CustomGemmKernel::CustomGemmKernel(const OrtApi&, const OrtKernelInfo*) {}
+
+template <typename TValue>
+ONNXTensorElementDataType GetTypeAndShape(const TValue& input, std::vector<int64_t>& shape, bool swap = false) {
+ auto t = input.GetTensorTypeAndShapeInfo();
+ shape = t.GetShape();
+ ORT_ENFORCE(shape.size() == 2);
+ if (swap) {
+ std::swap(shape[0], shape[1]);
+ }
+ return t.GetElementType();
+}
+
+void CustomGemmKernel::Compute(OrtKernelContext* context) {
+  // The function does nothing related to the Gemm operator. It creates an output with the dimensions
+  // expected by the model used in the unit tests and fills it with consecutive integers starting at zero.
+ Ort::KernelContext ctx(context);
+
+ auto n_inputs = ctx.GetInputCount();
+ ORT_ENFORCE(n_inputs >= 2);
+ Ort::ConstValue input_A = ctx.GetInput(0);
+ Ort::ConstValue input_B = ctx.GetInput(1);
+
+  std::vector<int64_t> shape_A, shape_B;
+ GetTypeAndShape(input_A, shape_A);
+ GetTypeAndShape(input_B, shape_B);
+ ORT_ENFORCE(shape_A.size() == 2);
+ ORT_ENFORCE(shape_B.size() == 2);
+  std::vector<int64_t> dimensions{shape_A[0], shape_B[1]};
+  Ort::UnownedValue Y = ctx.GetOutput(0, dimensions);
+  float* out = Y.GetTensorMutableData<float>();
+  size_t end = static_cast<size_t>(dimensions[0] * dimensions[1]);
+  for (size_t i = static_cast<size_t>(0); i < end; ++i) {
+ out[i] = static_cast(i);
+ }
+}
+
+} // namespace Cpu
diff --git a/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.h b/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.h
new file mode 100644
index 0000000000000..97a8e78cae6a6
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_op_local_function/dummy_gemm.h
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+
+#define ORT_API_MANUAL_INIT
+#include
+#include
+#undef ORT_API_MANUAL_INIT
+
+#include
+
+namespace Cpu {
+
+struct CustomGemmKernel {
+ CustomGemmKernel(const OrtApi& api, const OrtKernelInfo* info);
+ void Compute(OrtKernelContext* context);
+};
+
+struct CustomGemmOp : Ort::CustomOpBase<CustomGemmOp, CustomGemmKernel> {
+  typedef Ort::CustomOpBase<CustomGemmOp, CustomGemmKernel> parent_type;
+ CustomGemmOp(const char* op_name, ONNXTensorElementDataType ab_type,
+ ONNXTensorElementDataType c_type,
+ ONNXTensorElementDataType d_type, bool compute_time_as_output)
+ : parent_type() {
+ op_name_ = op_name;
+ ab_type_ = ab_type;
+ c_type_ = c_type;
+ d_type_ = d_type;
+ compute_time_as_output_ = compute_time_as_output;
+ }
+ void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
+ const char* GetName() const;
+ const char* GetExecutionProviderType() const;
+
+ size_t GetInputTypeCount() const;
+ ONNXTensorElementDataType GetInputType(size_t index) const;
+ OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const;
+
+ size_t GetOutputTypeCount() const;
+ ONNXTensorElementDataType GetOutputType(size_t index) const;
+ OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const;
+
+ private:
+ const char* op_name_;
+ ONNXTensorElementDataType ab_type_;
+ ONNXTensorElementDataType c_type_;
+ ONNXTensorElementDataType d_type_;
+ bool compute_time_as_output_;
+};
+
+} // namespace Cpu
From 658e30eb33f157dc7e7cba0e6ac9bf37178722e1 Mon Sep 17 00:00:00 2001
From: Wei-Sheng Chin
Date: Thu, 4 Jan 2024 12:59:47 -0800
Subject: [PATCH 14/20] Remove DORT since it's in PyTorch main now (#18996)
The main code is removed and the tests are modified to use DORT directly from
PyTorch.
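
A minimal sketch of what the migrated tests rely on, assuming a recent PyTorch
(2.1+) with onnxruntime installed; the "onnxrt" torch.compile backend is how
PyTorch exposes DORT, treated here as an assumption rather than as part of this
change:

```python
# Minimal sketch, assuming torch >= 2.1 with onnxruntime available; the "onnxrt"
# dynamo backend is DORT as shipped with PyTorch.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel()
compiled = torch.compile(model, backend="onnxrt")
print(compiled(torch.randn(2, 4)))
```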
---
cmake/onnxruntime_python.cmake | 7 -
.../python/training/torchdynamo/__init__.py | 4 -
.../training/torchdynamo/ort_backend.py | 729 ------------------
.../training/torchdynamo/register_backend.py | 89 ---
.../test/python/orttraining_test_dort.py | 47 +-
.../orttraining_test_dort_custom_ops.py | 26 +-
setup.py | 1 -
7 files changed, 42 insertions(+), 861 deletions(-)
delete mode 100644 orttraining/orttraining/python/training/torchdynamo/__init__.py
delete mode 100644 orttraining/orttraining/python/training/torchdynamo/ort_backend.py
delete mode 100644 orttraining/orttraining/python/training/torchdynamo/register_backend.py
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 61922961588b2..2e3594f256f65 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -354,9 +354,6 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/optim/*.py"
)
- file(GLOB onnxruntime_python_torchdynamo_srcs CONFIGURE_DEPENDS
- "${ORTTRAINING_SOURCE_DIR}/python/training/torchdynamo/*.py"
- )
file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py"
)
@@ -746,7 +743,6 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental/gradient_graph
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/optim
- COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/torchdynamo
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config
@@ -777,9 +773,6 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_optim_srcs}
$/onnxruntime/training/optim/
- COMMAND ${CMAKE_COMMAND} -E copy
- ${onnxruntime_python_torchdynamo_srcs}
- $/onnxruntime/training/torchdynamo/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_srcs}
$/onnxruntime/training/ortmodule/
diff --git a/orttraining/orttraining/python/training/torchdynamo/__init__.py b/orttraining/orttraining/python/training/torchdynamo/__init__.py
deleted file mode 100644
index 862c45ce31b25..0000000000000
--- a/orttraining/orttraining/python/training/torchdynamo/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py
deleted file mode 100644
index 9bafe39a5c211..0000000000000
--- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py
+++ /dev/null
@@ -1,729 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-
-import dataclasses
-import logging
-from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
-
-import numpy as np
-import onnx
-import torch
-import torch._C
-import torch._ops
-import torch._prims.executor
-import torch.fx
-import torch.onnx
-
-# TODO(wschin,justinchuby): Since the internal APIs are not stable, please
-# contact us if you hit errors.
-import torch.onnx._internal
-import torch.onnx._internal.diagnostics
-import torch.onnx._internal.exporter
-import torch.onnx._internal.fx.decomposition_table
-import torch.onnx._internal.fx.passes
-from torch._subclasses.fake_tensor import FakeTensor
-from torch.fx.passes.fake_tensor_prop import FakeTensorProp
-from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
-from torch.fx.passes.operator_support import OperatorSupport
-from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
-from torch.utils import _pytree
-
-import onnxruntime # type: ignore
-from onnxruntime.capi import _pybind_state as ORTC
-
-_NP_DTYPE = {
- torch.float16: np.float16,
- torch.float32: np.float32,
- torch.float64: np.float64,
- torch.uint8: np.uint8,
- torch.int8: np.int8,
- torch.int16: np.int16,
- torch.int32: np.int32,
- torch.int64: np.longlong,
- torch.bool: np.bool_,
-}
-
-_ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE = {
- 1: torch.float32,
- 2: torch.uint8,
- 3: torch.int8,
- 5: torch.int16,
- 6: torch.int32,
- 7: torch.int64,
- 9: torch.bool,
- 10: torch.float16,
-}
-
-_TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE = {value: key for key, value in _ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE.items()}
-
-
-def _nvtx_range_push(name: str):
- """If PyTorch is installed with CUDA support, this starts NVTX range.
-
- Check torch.cuda.nvtx.range_push's document for more details.
- """
- if torch.cuda.is_available():
- torch.cuda.nvtx.range_push(name)
-
-
-def _nvtx_range_pop():
- """If PyTorch is installed with CUDA support, this terminates NVTX range.
-
- Check torch.cuda.nvtx.range_pop's document for more details.
- """
- if torch.cuda.is_available():
- torch.cuda.nvtx.range_pop()
-
-
-def _get_ort_device_type(device_type: str):
- if device_type == "cuda":
- return ORTC.OrtDevice.cuda() # type: ignore
- if device_type == "cpu":
- return ORTC.OrtDevice.cpu() # type: ignore
- # ort pytorch device is mapped to NPU OrtDevice type
- if device_type == "ort":
- return ORTC.OrtDevice.npu() # type: ignore
- raise ValueError("Unsupported device type: " + device_type)
-
-
-logger = logging.getLogger(__name__)
-# Uncomment the following lines to print out development info.
-# logging.basicConfig(level=logging.INFO)
-# logger.setLevel(logging.INFO)
-
-
-class OrtOperatorSupport(OperatorSupport):
- """
-    Operator support for the ONNXRuntime backend. It has a two-level support decision.
- One is via support_dict and the other one is via extra_support_dict. The logic
- of using support_dict is implemented in OrtOperatorSupport and extra_support_dict
- is used by OperatorSupport.is_node_supported.
- """
-
- def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
- # Use extra_support_dict[op_name] = None to indicate
- # we support op_name with all input types. Otherwise,
- # see support_dict (type: SupportDict) in operator_support.py
- # for specifying supported types.
- super().__init__(extra_support_dict)
- self._support_dict = support_dict
-
- def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool:
- # OperatorSupport.is_node_supported returns True for non-callable nodes.
- # Since ORT can't execute them, we return False here to override the base
- # behavior.
- if node.op not in CALLABLE_NODE_OPS:
- return False
-        # This is the only place to decide if an aten op is supported.
- if node.op == "call_function" and node.target in self._support_dict:
- logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target))
- return True
- logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target))
- # If node.target is not in support_dict, we still want to check if torch.jit.script
- # can convert it to ONNX equivalence. Let's use base mechanism to do this.
- # See extra_support_dict for supported ops.
- if super().is_node_supported(submodules, node):
- logger.info("extra_support_dict supports node.target: %s (type: %s)", node.target, type(node.target))
- return True
-        logger.info("extra_support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target))
- return False
-
-
-def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
- """
-    In torch.fx.Graph, placeholder is a special assignment node. If it's not
- executed in the beginning, it could overwrite values computed by upstream
- nodes.
- """
-
- graph = graph_module.graph
- placeholders = []
- first_not_placeholder = None
- for node in graph.nodes:
- if node.op == "placeholder":
- placeholders.append(node)
- if first_not_placeholder is None and node.op != "placeholder":
- first_not_placeholder = node
- if first_not_placeholder is None:
- return
- for placeholder in placeholders:
- first_not_placeholder.prepend(placeholder)
-
-
-def _replace_to_copy_with_to(fx_module: torch.fx.GraphModule) -> None:
- # aten._to_copy doesn't have exporter so we replace it with aten.to.
- for node in fx_module.graph.nodes:
- if (
- isinstance(node.target, torch._ops.OpOverload)
- and node.target.overloadpacket == torch.ops.aten._to_copy # type: ignore
- ):
- is_default_layout = True
- is_on_same_device = True
- is_cast = True
- are_kwargs_supported = True
- if "layout" in node.kwargs and node.kwargs["layout"] != torch.strided:
- is_default_layout = False
- if "device" in node.kwargs and node.kwargs["device"] != node.args[0].meta["val"].device:
- is_on_same_device = False
- if "dtype" not in node.kwargs:
- is_cast = False
- for kwarg in node.kwargs:
- if kwarg not in ["layout", "device", "dtype"]:
- are_kwargs_supported = False
-
- if len(node.args) == 1 and is_default_layout and is_on_same_device and is_cast and are_kwargs_supported:
- # This aten::_to_copy looks like ONNX Cast, so other kwargs are ignored.
- # This change could lead to invalid FX graph but it doesn't matter, as long as the downstream backend,
- # ONNXRuntime, can execute the exported ONNX graph.
- node.kwargs = {"dtype": node.kwargs["dtype"]}
-
- node.target = torch.ops.aten.to.dtype # type: ignore
- else:
- raise RuntimeError(
- f"aten._to_copy must be replaced with other ONNX-supported aten ops. \
- args={[arg.meta for arg in node.args]}, kwargs={node.kwargs}"
- )
- fx_module.recompile()
-
-
-def _create_onnx_model(onnx_proto):
- return onnx.ModelProto.FromString(onnx_proto)
-
-
-def _create_onnx_session(onnx_proto, eps: Tuple[str, ...], session_options):
- # TODO(wechi): Add more EPs per PyTorch device types.
- # TODO(wechi): enable external allocators.
- return onnxruntime.InferenceSession(onnx_proto, providers=eps, sess_options=session_options)
-
-
-def _infer_ep_from_device(*args) -> Tuple[str, ...]:
- """Return the first valid device (i.e., GPU or CPU) in argument list."""
- eps = []
- for arg in args:
- if hasattr(arg, "device"):
- device = arg.device
- if device.type == "cuda":
- eps.append("CUDAExecutionProvider")
- elif device.type == "cpu":
- eps.append("CPUExecutionProvider")
- return tuple(eps)
-
-
-def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
- placeholders = []
- for node in graph_module.graph.nodes:
- if node.op == "placeholder":
- if hasattr(node, "meta") and "val" in node.meta:
- assert isinstance(node.meta["val"], torch.Tensor)
- placeholders.append(node)
-
-
-def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
- """Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
- for node in graph_module.graph.nodes:
- if node.op == "output":
- # Output node is unique. Let's retrieve output values from
- # this node's input list. And then just return.
- return node.args[0]
- raise ValueError("No output node found in this torch.fx.GraphModule.")
-
-
-def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
-    """Return all valid devices (i.e., GPU or CPU) among the outputs of this torch.fx.GraphModule."""
- flattened_output_args, _ = _pytree.tree_flatten(_extract_graph_module_outputs(graph_module))
- # Output arguments with example value (type: torch.Tensor) in the `graph_module`.
- selected_output_args = [
- output_arg.meta["val"]
- for output_arg in flattened_output_args
- # output_arg must have tensor for its device information.
- # Otherwise, skip it.
- if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
- ]
- return _infer_ep_from_device(*selected_output_args)
-
-
-def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
- """Sort execution providers in eps based on pre-set priority."""
-
- def get_execution_provider_priority(ep: str) -> int:
- if ep == "CPUExecutionProvider":
- # Lowest priority.
- return 2
- if ep == "CUDAExecutionProvider":
- # Higher priority than CPU but lower than
- # other specialized EPs.
- return 1
- # Highest priority.
- return 0
-
- unique_eps = set(eps)
- return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))
-
-
-def _get_onnx_devices(values: Tuple[torch.Tensor, ...]) -> Tuple[ORTC.OrtDevice, ...]: # type: ignore
- assert all(value.device == values[0].device for value in values), "All values must be on the same device."
-
- def _device_id_or_zero(device_id: int) -> int:
- return device_id or 0
-
- devices: Tuple[ORTC.OrtDevice, ...] = tuple( # type: ignore
- ORTC.OrtDevice( # type: ignore
- _get_ort_device_type(value.device.type),
- ORTC.OrtDevice.default_memory(), # type: ignore
- _device_id_or_zero(value.device.index),
- )
- for value in values
- )
- return devices
-
-
-def _get_ortvalues_from_torch_tensors(
- tensors: Tuple[torch.Tensor, ...], devices: Tuple[ORTC.OrtDevice, ...]
-) -> Tuple[torch.Tensor, ...]:
- ortvalues = ORTC.OrtValueVector() # type: ignore
- ortvalues.reserve(len(tensors))
- dtypes = []
- shapes = []
- data_ptrs = []
-
- for tensor in tensors:
- dtypes.append(_NP_DTYPE[tensor.dtype])
- shapes.append(tensor.size())
- data_ptrs.append(tensor.data_ptr())
- ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
- return ortvalues
-
-
-def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
- if tensor.is_sparse:
- raise ValueError("sparse tensor is not yet supported.")
- out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
- return out
-
-
-def _run_onnx_session_with_ortvaluevector(
- sess: onnxruntime.InferenceSession,
- input_names: Tuple[str, ...],
- inputs: Tuple[torch.Tensor, ...],
- input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
- output_names: Tuple[str, ...],
- outputs: Tuple[torch.Tensor, ...],
- output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
- preallocate_output: bool,
-) -> Tuple[torch.Tensor, ...]:
- _nvtx_range_push("contiguous")
- inputs = tuple(a.contiguous() for a in inputs)
- _nvtx_range_pop()
-
- _nvtx_range_push("push_back_batch")
-
- ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)
-
- # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue.
- # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue
- # to torch Tensor transferring the ownership.
- if preallocate_output:
- pth_outputs = tuple(map(lambda t: _to_real_tensor(t) if isinstance(t, FakeTensor) else t, outputs))
- ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
- else:
- ort_outputs = ORTC.OrtValueVector() # type: ignore
- _nvtx_range_pop()
-
- _nvtx_range_push("run_with_ortvaluevector")
- run_options = onnxruntime.RunOptions()
- run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
- sess.run_with_ortvaluevector(run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices)
- _nvtx_range_pop()
-
- if preallocate_output:
- return pth_outputs
- else:
- _nvtx_range_push("after run_with_ortvaluevector")
- pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(ort_outputs) # type: ignore
- _nvtx_range_pop()
- return pth_outputs
-
-
-def _assert_allclose_with_detailed_error_message(
- actual: torch.Tensor, expected: torch.Tensor, rtol: float = 1e-03, atol: float = 1e-04
-):
- diff = actual - expected
- real_atol = torch.max(torch.abs(diff))
- max_value = torch.max(torch.abs(actual), torch.abs(expected))
- max_value[max_value == 0.0] = 1.0
- real_rtol = torch.max(diff / max_value)
- allclose = bool(real_atol <= atol or real_rtol <= rtol)
- if not allclose:
- raise RuntimeError(
- "ONNX output doesn't match baseline output with "
- f"actual rtol={real_rtol} and actual atol={real_atol} "
- f"but expected rtol={rtol} and expected atol={atol}."
- )
-
-
-class OrtExecutionInfoPerSession:
- """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""
-
- def __init__(
- self,
- session: onnxruntime.InferenceSession,
- input_names: Tuple[str, ...],
- input_value_infos: Tuple[onnx.ValueInfoProto, ...],
- output_names: Tuple[str, ...],
- output_value_infos: Tuple[onnx.ValueInfoProto, ...],
- input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
- output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore
- example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
- ):
- # Carrier of ONNX model and its executor.
- self.session: onnxruntime.InferenceSession = session
- # For the ONNX model stored in self.session, self.input_names[i] is the
- # name of the i-th positional input.
- self.input_names: Tuple[str, ...] = input_names
- # self.input_name[i]'s type information is stored in self.input_value_infos[i].
- self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos
- # Similar to self.input_names, but for outputs.
- self.output_names: Tuple[str, ...] = output_names
- # Similar to self.input_value_infos but for outputs.
- self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos
- # For the ONNX model stored in self.session, self.input_devices[i] is the
- # i-th positional input's device.
- self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices # type: ignore
- # Similar to self.input_devices, but for outputs.
- self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices # type: ignore
- # This is the outputs of executing the original torch.fx.GraphModule with example inputs
- # (i.e., args passed into OrtBackend._ort_acclerated_call).
- self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = example_outputs
-
- def is_supported(self, *args):
- # Compare the args and the input schema in ONNX model and
- # return the first match.
- if len(args) != len(self.input_value_infos):
- return False
- for arg, value_info in zip(args, self.input_value_infos):
- if not isinstance(arg, torch.Tensor):
- return False
- onnx_dtype = _TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE[arg.dtype]
- if onnx_dtype != value_info.type.tensor_type.elem_type:
- return False
- for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
- if isinstance(dim, int) and (onnx_dim.dim_value == dim or onnx_dim.dim_param):
- continue
- elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
- continue
- else:
- return False
- return True
-
-
-@dataclasses.dataclass
-class OrtExecutionInfoForAllGraphModules:
- def __init__(self):
- # All sessions (and their related information) created by exporting the same GraphModule
- # with different inputs.
- self.execution_info_per_graph_module: Dict[torch.fx.GraphModule, List[OrtExecutionInfoPerSession]] = {}
-
- def search_reusable_session_execution_info(self, graph_module: torch.fx.GraphModule, *args):
- if graph_module not in self.execution_info_per_graph_module:
- return None
- # All execution information for ONNX models exported from the same `graph_module`
- # with different inputs.
- candidates = self.execution_info_per_graph_module[graph_module]
-
- for candidate in candidates:
- if candidate.is_supported(*args):
- # Returns the first session that accepts this input schema.
- return candidate
- # No reusable session found.
- return None
-
- def cache_session_execution_info(self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession):
- if graph_module not in self.execution_info_per_graph_module:
- self.execution_info_per_graph_module[graph_module] = [info]
- else:
- self.execution_info_per_graph_module[graph_module].append(info)
-
-
-class OrtBackend:
- """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.
-
- The compiler entry point is OrtBackend.compile, which
-    1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
- sub-graphs.
- 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
- 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
- """
-
- def __init__(
- self,
- ep: str = "CPUExecutionProvider",
- preallocate_output: bool = False,
- session_options=None,
- onnx_exporter_options: Optional["torch.onnx.ExportOptions"] = None,
- ):
- # onnx_exporter_options contains information shared between exporter and DORT.
- # For example, they should use the same decomposition table when
- # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
- # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model
- # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
- if onnx_exporter_options is None:
- onnx_exporter_options = torch.onnx.ExportOptions()
- # Convert user-facing option to internal option used by ONNX exporter
- # to access required information.
- # Some useful fields:
- # - Decomposition table for decomposing FX operators in exporter is
- # self.resolved_onnx_exporter_options.decomposition_table.
- # - self.resolved_onnx_exporter_options.onnx_registry records what
- # aten/prim ops are supported by exporter and their exporters (type: callable).
- self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options)
-
- # TODO(wechi): This line must generate result identical to the call of
- # _create_onnx_supports_op_overload_table(...) inside
- # create_onnx_friendly_decomposition_table(...) in
- # torch/onnx/_internal/fx/decomposition_table.py.
- support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
- # This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry.
- self.resolved_onnx_exporter_options.onnx_registry
- ) # type: ignore
-
- extra_support_dict: Dict[str, Any] = {
- "getattr": None,
- "_operator.getitem": None,
- }
-
- self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
- # TODO: this is a naive implementation of cache without proper guard
- self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
-        # Conceptually, this field is a 2-layer dictionary
- # GraphModule 0
- # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
- # ONNX Model 1
- # ...
- # GraphModule 1
- # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
- # ONNX Model 3
- # ...
- # ...
- # , which caches all previous compilation result so that we can reuse them.
- # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
- # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
- # graphs captured by Dynamo and sent to OrtBackend.compile.
- self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()
-
- self._assert_allclose_to_baseline = False
-
- self.ep = ep
- self.session_options = session_options
-
- # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
- # in order to avoid internal allocation of output buffers in InferenceSession.
- # If output ortvalue returned from InferenceSession is allocated internally,
- # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership.
- # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
- # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device.
- # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator,
- # and use the preallocated output buffers for InferenceSession not holding any ownership for them.
- self.preallocate_output = preallocate_output
-
- def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
- cached_execution_info_per_session = self._all_ort_execution_info.search_reusable_session_execution_info(
- graph_module, *args
- )
- if cached_execution_info_per_session:
- onnx_session = cached_execution_info_per_session.session
- input_names = cached_execution_info_per_session.input_names
- output_names = cached_execution_info_per_session.output_names
- input_devices = cached_execution_info_per_session.input_devices
- output_devices = cached_execution_info_per_session.output_devices
- prim_outputs = cached_execution_info_per_session.example_outputs
- else:
-            # It's the first time seeing such a graph. Let's make a new session
- # (type: onnxruntime.InferenceSession) for it.
-
- # TODO(wechi): this is a workaround for pytorch/pytorch#84311.
- _move_placeholder_to_front(graph_module)
- # Generate reference outputs. They are used to indicate output
- # tensors' types and devices when calling ORT.
- #
- # WARNING: The downstream code should not change prim_outputs and
-            # this backend should always produce output with a schema identical to prim_outputs'.
-
- if self.resolved_onnx_exporter_options.dynamic_shapes:
- # No pre-allocation when dynamic shape is enabled.
- self.preallocate_output = False
- extracted_outputs = _extract_graph_module_outputs(graph_module)
-
- def maybe_map_to_meta_val(value):
- if hasattr(value, "meta") and "val" in value.meta:
- # Select outputs with "val" information. Without "val",
- # it's not possible access output_arg.meta["val"].device.
- return value.meta["val"]
- else:
- return value
-
- prim_outputs = _pytree.tree_map(maybe_map_to_meta_val, extracted_outputs)
- else:
- try:
- prim_outputs = FakeTensorProp(graph_module).propagate(*args, **kwargs)
- except Exception:
- logger.info(f"FakeTensorProb failed for {graph_module}")
- # When FakeTensorProp fails, it is not possible to preallocate output buffers
- # because the output shapes are not inferred.
- self.preallocate_output = False
-
- # rethrow FakeTensorProb failure because it is not yet currently handled.
- raise
-
- graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
- self.resolved_onnx_exporter_options.diagnostic_context, graph_module
- ).run()
-
- from torch.onnx._internal.fx import fx_onnx_interpreter
-
- # Create the object to iterate through the nodes in graph one-by-one
- # and calls the corresponding ONNX exporter for each node.
- fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
- diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context
- )
- # Start the per-node exporting process. It's conceptually a for loop
- # scanning through the nodes in the graph.
- exported = fx_interpreter.run(
- fx_graph_module=graph_module,
- onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher,
- op_level_debug=self.resolved_onnx_exporter_options.op_level_debug,
- )
- # Convert the exported result to ONNX ModelProto.
- onnx_proto = exported.to_model_proto(
- opset_version=self.resolved_onnx_exporter_options.onnx_registry.opset_version
- ).SerializeToString()
-
- # Initialize a ORT session to execute this ONNX model.
- # Note that TorchDynamo assumes all inputs/outputs are on the
- # same device, but it's subject to change (very likely with
- # dynamic shape support), so we add execution providers
- # based on the all inputs/outputs plus a default OrtBackend.ep.
- eps_from_args = _infer_ep_from_device(args)
- eps_from_graph_module = _infer_ep_from_graph_module(graph_module)
- if eps_from_args:
- # If user feeds CUDA tensor as input argument,
- # we want to use CUDA EP.
- # Thus, `eps_from_args` (deduced from input arguments)
- # has highest priority.
- selected_eps = _sort_eps((*eps_from_args, self.ep))
- elif eps_from_graph_module:
- # If there is no EP in input arguments, we deduce EP from
- # graph_module's outputs. Those outputs may come from
- # FakeTensorProp or Dynamo's built-in symbolic shape inference.
- selected_eps = _sort_eps((*eps_from_graph_module, self.ep))
- else:
- # No EP found in inputs and outputs, let's use default.
- selected_eps = (self.ep,)
-
- onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options)
- # Cache ORT session. It's reused for the same "graph_module".
- # Generate ONNX model and extract its input and output names.
- onnx_model = _create_onnx_model(onnx_proto)
- # TODO(wechi): ORT session should provide a API to extract
- # input and output names from the underlying model.
- input_names = tuple(input.name for input in onnx_model.graph.input)
- output_names = tuple(output.name for output in onnx_model.graph.output)
- input_devices = _get_onnx_devices(args)
- # Cache devices for inputs and outputs. They are used to invoke
- # ORT session. Output devices indicate where (e.g., GPU or CPU)
- # to store outputs
- if isinstance(prim_outputs, tuple):
- output_devices = _get_onnx_devices(prim_outputs)
- else:
- output_devices = _get_onnx_devices((prim_outputs,))
-
- execution_info_per_session = OrtExecutionInfoPerSession(
- session=onnx_session,
- input_names=input_names,
- input_value_infos=tuple(input for input in onnx_model.graph.input),
- output_names=output_names,
- output_value_infos=tuple(output for output in onnx_model.graph.output),
- input_devices=input_devices,
- output_devices=output_devices,
- example_outputs=prim_outputs,
- )
-
- self._all_ort_execution_info.cache_session_execution_info(graph_module, execution_info_per_session)
-
- if isinstance(prim_outputs, tuple):
- assert all(isinstance(elem, torch.Tensor) for elem in prim_outputs)
- # ORT always returns a tuple of outputs. If the original is a tuple, just returning
- # ORT output is ok.
- _nvtx_range_push("run_onnx_session_with_ortvaluevector")
- onnx_outputs = _run_onnx_session_with_ortvaluevector(
- onnx_session,
- input_names,
- args,
- input_devices,
- output_names,
- prim_outputs,
- output_devices,
- self.preallocate_output,
- )
- _nvtx_range_pop()
- if self._assert_allclose_to_baseline:
- # Compute baseline.
- baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten")
- # Ensure every output tensor is close to the corresponding baseline.
- for onnx_output, baseline_output in zip(onnx_outputs, baseline_outputs):
- _assert_allclose_with_detailed_error_message(onnx_output, baseline_output)
- return onnx_outputs
- else:
- assert isinstance(prim_outputs, torch.Tensor)
- # ORT always returns a tuple of outputs. If the original output is a tensor,
- # ORT output's first element must be extracted and returned. Otherwise, type
- # mismatch may happen in downstream computation.
- onnx_outputs = _run_onnx_session_with_ortvaluevector(
- onnx_session,
- input_names,
- args,
- input_devices,
- output_names,
- (prim_outputs,),
- output_devices,
- self.preallocate_output,
- )
- assert len(onnx_outputs) == 1
- if self._assert_allclose_to_baseline:
- # Compute baseline.
- baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten")
- # Ensure output tensor is close to the corresponding baseline.
- _assert_allclose_with_detailed_error_message(onnx_outputs[0], baseline_outputs)
- return onnx_outputs[0]
-
- def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
- # FX graph based partitioning based on ONNX supported ops.
- if graph_module in self._partitioner_cache:
- partitioned_prim_graph_module = self._partitioner_cache[graph_module]
- else:
- prim_graph_module = graph_module
- # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to.
- _replace_to_copy_with_to(prim_graph_module)
- partitioner = CapabilityBasedPartitioner(
- prim_graph_module, self._supported_ops, allows_single_node_partition=True
- )
- partitioned_prim_graph_module = partitioner.partition_and_fuse()
- self._partitioner_cache[graph_module] = partitioned_prim_graph_module
-
- # Overriding fused_module's __call__() function with ort_acclerated_call()
- # This loop goes through all graph partitions (each of them is an ONNX-representable graph)
- # and override their _wrapped_call function with _ort_accelerated_call.
- # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT.
- for node in partitioned_prim_graph_module.graph.nodes:
- # TODO: use a better way to identify fused submodule
- if node.op == "call_module" and "fused_" in node.name:
- fused_module = getattr(partitioned_prim_graph_module, node.name)
- # self.ort_acclerated_call is responsible for exporting graph to ONNX,
- # creating ORT session, and running ORT session.
- fused_module._wrapped_call = self._ort_acclerated_call
-
- return partitioned_prim_graph_module
-
- def __call__(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
- return self.compile(graph_module, args)
diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py
deleted file mode 100644
index 3a49e85ab836d..0000000000000
--- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# -------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-# --------------------------------------------------------------------------
-
-from functorch.compile import min_cut_rematerialization_partition
-from torch._dynamo.backends.common import aot_autograd
-from torch.onnx._internal.exporter import ExportOptions
-
-from .ort_backend import OrtBackend
-
-
-def make_aot_ort(dynamic: bool = True):
- """Wrap OrtBackend as PyTorch's AOT compiler.
-
- Example usages:
- import torch
- from onnxruntime.training.torchdynamo.register_backend import make_aot_ort
- use_dynamic = True
- local_aot_ort, _ = make_aot_ort(dynamic = use_dynamic)
-
- @torch._dynamo.optimize(local_aot_ort, dynamic=use_dynamic)
- def foo(x: torch.Tensor):
- return torch.sigmoid(x)
-
- x = torch.rand(2, 2, dtype=torch.float)
- torch.testing.assert_close(torch.sigmoid(x), foo(x))
- """
- ort_backend = OrtBackend(onnx_exporter_options=ExportOptions(dynamic_shapes=dynamic))
- return (
- aot_autograd(
- fw_compiler=ort_backend,
- partition_fn=min_cut_rematerialization_partition,
- decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
- ),
- ort_backend,
- )
-
-
-# Wrap ORT as a compiler in Dynamo for training (i.e., when .backward is called).
-#
-# Under the hood, OrtBackend.compile is called inside functorch. See aot_function
-# and aot_module in aot_autograd.py in PyTorch repo for more details. Basically,
-# OrtBackend.compile is mapped to forward graph compiler, fw_compile, and backward
-# graph compiler, bw_compile, in aot_autograd.py.
-#
-# Example usage:
-# import torch
-# from onnxruntime.training.torchdynamo.register_backend import aot_ort
-# model = torch.nn.Linear(2, 2)
-# compiled_model = torch._dynamo.optimize(aot_ort)(model)
-# result = compiled_model(torch.rand(2, 2, dtype=torch.float)
-# result.sum().backward()
-#
-# DEFAULT_BACKEND should be the underlying compiler for ALL graphs if
-# the user uses ORT to accelerate PyTorch via Dynamo.
-# By using a global compiler for all graphs, cached compilation
-# results can be reused when encountering the identical graphs.
-aot_ort, DEFAULT_BACKEND = make_aot_ort(dynamic=False)
-
-# Similar to aot_ort but should be used with
-# torch._dynamo.optimize(dynamic_aot_ort, dynamic=True)
-# to enable dynamic shapes in ONNX graph.
-#
-# Similar to DEFAULT_BACKEND but DEFAULT_DYNAMIC_BACKEND enables dynamic shapes
-# when exporting FX graph to ONNX.
-# Note that this backend must be used with
-# torch._dynamo.optimize(DEFAULT_DYNAMIC_BACKEND, dynamic=True)
-# Without `dynamic=True`, the FX graph only contains static shapes, and results ONNX graph
-# with static shapes.
-dynamic_aot_ort, DEFAULT_DYNAMIC_BACKEND = make_aot_ort(dynamic=True)
-
-# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called).
-#
-# ort is usually faster than aot_ort for inference because the graphs generated by aot_autograd
-# mechanism are very different than the original graphs. Therefore, some ORT's graph transformers
-# are not applicable.
-#
-# Example usage:
-# import torch
-# from onnxruntime.training.torchdynamo.register_backend import ort
-# model = torch.nn.Linear(2, 2)
-# compiled_model = torch._dynamo.optimize(ort)(model)
-ort = DEFAULT_BACKEND
-
-# Similar to ort but should be used with
-# torch._dynamo.optimize(dynamic_ort, dynamic=True)
-# to enable dynamic shapes in ONNX graph.
-dynamic_ort = DEFAULT_DYNAMIC_BACKEND
diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py
index 2a7012787be6e..f0b6b9c5fba28 100644
--- a/orttraining/orttraining/test/python/orttraining_test_dort.py
+++ b/orttraining/orttraining/test/python/orttraining_test_dort.py
@@ -8,9 +8,22 @@
import torch.onnx._internal.exporter
from torch import nn
from torch.nn import functional as F
+from torch.onnx import ExportOptions
+from torch.onnx import _OrtBackend as OrtBackend
+from torch.onnx import _OrtBackendOptions as OrtBackendOptions
from torch.utils import _pytree
-from onnxruntime.training.torchdynamo.register_backend import aot_ort, dynamic_aot_ort, make_aot_ort, ort
+
+def make_local_backend(dynamic: bool = False, use_aot_autograd: bool = False):
+ ort_backend = OrtBackend(
+ options=OrtBackendOptions(
+ export_options=ExportOptions(
+ dynamic_shapes=dynamic,
+ ),
+ use_aot_autograd=use_aot_autograd,
+ )
+ )
+ return ort_backend
class TestTorchDynamoOrt(unittest.TestCase):
@@ -35,9 +48,7 @@ def elementwise_model(tensor_x: torch.Tensor):
tensor_q = tensor_p.sigmoid()
return tensor_q
- @torch._dynamo.optimize(aot_ort)
- def optimized_elementwise_model(tensor_x: torch.Tensor):
- return elementwise_model(tensor_x)
+ optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True)
def run(fun, list_x):
tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_()
@@ -77,9 +88,7 @@ def elementwise_model(tensor_x: torch.Tensor):
# With dynamic_shape=True, Dynamo sends FX graphs with dynamic
# shapes (e.g., batch size is a symbol "batch" instead of a fixed
# number) to OrtBackend.compile(...).
- @torch._dynamo.optimize(dynamic_aot_ort, dynamic=True)
- def optimized_elementwise_model(tensor_x: torch.Tensor):
- return elementwise_model(tensor_x)
+ optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True)
def run(fun, seed: torch.Tensor):
tensor_x = seed.detach().clone().requires_grad_()
@@ -125,8 +134,8 @@ def elementwise_model(tensor_x: torch.Tensor):
tensor_q = tensor_p.sigmoid()
return (tensor_q, (tensor_y, tensor_z))
- local_aot_ort, ort_backend = make_aot_ort(dynamic=True)
- cached = ort_backend._all_ort_execution_info.execution_info_per_graph_module
+ local_backend = make_local_backend(dynamic=True, use_aot_autograd=True)
+ cached = local_backend._all_ort_execution_info.execution_info_per_graph_module
# Before compilation, no graph is generated.
assert len(cached) == 0
@@ -135,7 +144,7 @@ def elementwise_model(tensor_x: torch.Tensor):
# With dynamic_shape=True, Dynamo sends FX graphs with dynamic
# shapes (e.g., batch size is a symbol "batch" instead of a fixed
# number) to OrtBackend.compile(...).
- @torch._dynamo.optimize(local_aot_ort, dynamic=True)
+ @torch._dynamo.optimize(local_backend, dynamic=True)
def optimized_elementwise_model(tensor_x: torch.Tensor):
return elementwise_model(tensor_x)
@@ -207,9 +216,8 @@ def elementwise_model(tensor_x: torch.Tensor):
tensor_q = tensor_p.relu()
return tensor_q
- @torch._dynamo.optimize(ort)
- def optimized_elementwise_model(tensor_x: torch.Tensor):
- return elementwise_model(tensor_x)
+ local_backend = make_local_backend(dynamic=True, use_aot_autograd=False)
+ optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True)
def run(fun, list_x):
tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_()
@@ -237,9 +245,7 @@ def copy_copy_copy(tensor_x: torch.Tensor):
)
return tensor_x1, tensor_x2, tensor_x3
- @torch._dynamo.optimize(aot_ort)
- def optimized_copy_copy_copy(tensor_x: torch.Tensor):
- return copy_copy_copy(tensor_x)
+ optimized_copy_copy_copy = torch.compile(copy_copy_copy, backend="onnxrt")
def run(fun, list_x):
tensor_x = torch.tensor(list_x, dtype=torch.float32)
@@ -265,7 +271,7 @@ def run_no_input_model():
def no_input_model():
return torch.ops.aten.full([2, 3], 1.5)
- @torch._dynamo.optimize(aot_ort)
+ @torch._dynamo.optimize("onnxrt")
def optimized_no_input_model():
return no_input_model()
@@ -291,9 +297,7 @@ def run_no_input_model():
def no_input_model():
return torch.ops.aten.full([2, 3], 1.5, device="cpu")
- @torch._dynamo.optimize(aot_ort)
- def optimized_no_input_model():
- return no_input_model()
+ optimized_no_input_model = torch.compile(no_input_model, backend="onnxrt")
def run(fun):
tensor_x = fun()
@@ -355,7 +359,8 @@ def run(model, tensor_x, tensor_y):
# Baseline.
loss, grads = run(model, tensor_x, tensor_y)
# ORT result.
- compiled_model = torch._dynamo.optimize(aot_ort)(model)
+ local_backend = make_local_backend(dynamic=False, use_aot_autograd=True)
+ compiled_model = torch.compile(model, backend=local_backend, dynamic=False)
loss_new, grads_new = run(compiled_model, tensor_x, tensor_y)
print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).")
diff --git a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py
index c2a6ed504a206..dfc62dba427e5 100644
--- a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py
+++ b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py
@@ -11,9 +11,10 @@
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch.library import Library
+from torch.onnx import _OrtBackend as OrtBackend
+from torch.onnx import _OrtBackendOptions as OrtBackendOptions
import onnxruntime
-from onnxruntime.training.torchdynamo.ort_backend import OrtBackend
# Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne
# in ONNX model executed by DORT.
@@ -112,16 +113,18 @@ def test_export_aten_mul_as_onnx_custom_op_and_run_ort(self):
# In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`.
ort_backend = OrtBackend(
- ep="CPUExecutionProvider",
- session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
- onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
+ OrtBackendOptions(
+ preferred_execution_providers="CPUExecutionProvider",
+ ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
+ export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
+ )
)
# Wrap ORT executor as a Dynamo backend.
aot_ort = aot_autograd(
fw_compiler=ort_backend,
partition_fn=min_cut_rematerialization_partition,
- decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
+ decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table,
)
def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor):
@@ -169,19 +172,22 @@ def bar_impl(self: torch.Tensor) -> torch.Tensor:
# Create executor of ONNX model.
ort_backend = OrtBackend(
- ep="CPUExecutionProvider",
- session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
- onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry),
+ OrtBackendOptions(
+ preferred_execution_providers="CPUExecutionProvider",
+ ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(),
+ export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry),
+ )
)
+
# Allow torch.ops.foo.bar.default to be sent to DORT.
# _support_dict tells Dynamo which ops to sent to DORT.
- ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default)
+ ort_backend._supported_ops._support_dict[torch.ops.foo.bar.default] = None
# Wrap ORT executor as a Dynamo backend.
aot_ort = aot_autograd(
fw_compiler=ort_backend,
partition_fn=min_cut_rematerialization_partition,
- decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table,
+ decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table,
)
def one_foo(tensor_x: torch.Tensor):
diff --git a/setup.py b/setup.py
index 0c2eb19e82c87..685f0612e3762 100644
--- a/setup.py
+++ b/setup.py
@@ -464,7 +464,6 @@ def finalize_options(self):
"onnxruntime.training.experimental",
"onnxruntime.training.experimental.gradient_graph",
"onnxruntime.training.optim",
- "onnxruntime.training.torchdynamo",
"onnxruntime.training.ortmodule",
"onnxruntime.training.ortmodule.experimental",
"onnxruntime.training.ortmodule.experimental.json_config",
From 02b1ff5fa2c41dc026022ca29c9249628f71f026 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Thu, 4 Jan 2024 13:32:48 -0800
Subject: [PATCH 15/20] [QNN EP] Support multithreaded inference of a single
session (#18981)
### Description
- Add mutex to protect QNN API calls for executing a graph and
extracting the corresponding profile data.
- Ensures QNN EP's execute function does not store unnecessary state
(i.e., input and output buffer pointers no longer need to be stored as
class members).
### Motivation and Context
Allow calling `session.Run()` from multiple threads when using QNN EP.
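
For reference, a minimal Python sketch of the usage this change enables: several
threads sharing one `InferenceSession` that targets the QNN EP. The model path,
input name, backend_path value, and thread count below are illustrative
assumptions, not part of this patch.

```python
# Hypothetical sketch: concurrent Run() calls on a single QNN EP session.
# "model.onnx", the input name "input", and "QnnHtp.dll" are placeholder assumptions.
import threading

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=[("QNNExecutionProvider", {"backend_path": "QnnHtp.dll"})],
)

def worker() -> None:
    x = np.random.rand(1, 3, 2).astype(np.float32)
    # With this patch, each Run() serializes the underlying graphExecute call
    # on the new per-graph mutex instead of racing on shared buffer state.
    session.run(None, {"input": x})

threads = [threading.Thread(target=worker) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```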
---
.../core/providers/qnn/builder/qnn_def.cc | 9 +
.../core/providers/qnn/builder/qnn_def.h | 1 +
.../core/providers/qnn/builder/qnn_model.cc | 107 ++++++----
.../core/providers/qnn/builder/qnn_model.h | 19 +-
.../test/providers/qnn/qnn_basic_test.cc | 194 +++++++++++++++++-
.../azure-pipelines/linux-qnn-ci-pipeline.yml | 8 +-
.../win-qnn-arm64-ci-pipeline.yml | 6 +-
.../azure-pipelines/win-qnn-ci-pipeline.yml | 4 +-
8 files changed, 292 insertions(+), 56 deletions(-)
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc
index a77ac16cf624b..55e72670a6971 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc
@@ -89,6 +89,15 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector
}
}
+void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size) {
+ if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) {
+ qnn_tensor.v1.clientBuf.data = buf_data;
+ qnn_tensor.v1.clientBuf.dataSize = buf_size;
+ } else {
+ ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version);
+ }
+}
+
void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size) {
if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) {
qnn_tensor.v1.clientBuf.dataSize = client_buf_size;
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h
index f6a3b1bd360ec..c202f2bf79c57 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_def.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h
@@ -100,6 +100,7 @@ void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector& dime
void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type);
void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf);
void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf);
+void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size);
void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size);
void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data);
void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index fd3a95b5f1f78..869d9326d9232 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -166,14 +166,14 @@ Status QnnModel::FinalizeGraphs() {
Status QnnModel::SetupQnnInputOutput() {
LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name();
- auto result = SetupTensors(qnn_inputs_, graph_info_->InputTensors());
+ auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());
if (Status::OK() != result) {
LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
}
- result = SetupTensors(qnn_outputs_, graph_info_->OutputTensors(), false);
+ result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
if (Status::OK() != result) {
LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
@@ -186,8 +186,8 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs";
const size_t num_inputs = context.GetInputCount();
const size_t num_outputs = context.GetOutputCount();
- ORT_RETURN_IF_NOT(qnn_inputs_.size() <= num_inputs, "Inconsistent input sizes");
- ORT_RETURN_IF_NOT(qnn_outputs_.size() == num_outputs, "Inconsistent output sizes");
+ ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes");
+ ORT_RETURN_IF_NOT(qnn_output_infos_.size() == num_outputs, "Inconsistent output sizes");
using namespace qnn::utils;
auto TensorDataSize = [&](auto ort_tensor) -> size_t {
@@ -198,49 +198,67 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) {
return element_size * length;
};
- for (auto& qnn_input_tensor : qnn_inputs_) {
- const std::string& model_input_name(GetQnnTensorName(qnn_input_tensor));
- auto index = GetOrtInputIndex(model_input_name);
- LOGS(logger_, VERBOSE) << "model_input = " << model_input_name << " index = " << index;
- auto ort_input_tensor = context.GetInput(index);
- auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_input_tensor).dataSize;
+ std::vector<Qnn_Tensor_t> qnn_inputs;
+ qnn_inputs.reserve(qnn_input_infos_.size());
+
+ for (const auto& qnn_input_info : qnn_input_infos_) {
+ LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName()
+ << " index = " << qnn_input_info.ort_index;
+ auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index);
auto ort_tensor_size = TensorDataSize(ort_input_tensor);
- LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_tensor_size << "Ort tensor size: " << ort_tensor_size;
- ORT_ENFORCE(qnn_tensor_size == ort_tensor_size,
+ LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size
+ << "Ort tensor size: " << ort_tensor_size;
+ ORT_ENFORCE(qnn_input_info.tensor_byte_size == ort_tensor_size,
"ORT Tensor data size does not match QNN tensor data size.");
- SetQnnTensorClientBufData(qnn_input_tensor,
- const_cast(ort_input_tensor.GetTensorData()));
+
+ qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor());
+ SetQnnTensorClientBuf(qnn_inputs.back(),
+ const_cast<void*>(ort_input_tensor.GetTensorData<void>()), qnn_input_info.tensor_byte_size);
}
- for (auto& qnn_output_tensor : qnn_outputs_) {
- const std::string& model_output_name(GetQnnTensorName(qnn_output_tensor));
- auto index = GetOutputIndex(model_output_name);
- LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << index;
- const auto& output_info = GetOutputInfo(model_output_name);
- const std::vector& output_shape = output_info->shape_;
- auto output_tensor = context.GetOutput(index, output_shape.data(), output_shape.size());
- auto qnn_tensor_size = GetQnnTensorClientBuf(qnn_output_tensor).dataSize;
- auto ort_tensor_size = TensorDataSize(output_tensor);
- LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_tensor_size << "Ort tensor size: " << ort_tensor_size;
- ORT_ENFORCE(qnn_tensor_size == ort_tensor_size,
+ std::vector<Qnn_Tensor_t> qnn_outputs;
+ qnn_outputs.reserve(qnn_output_infos_.size());
+
+ for (auto& qnn_output_info : qnn_output_infos_) {
+ const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName();
+ LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index;
+ const auto& ort_output_info = GetOutputInfo(model_output_name);
+ const std::vector& output_shape = ort_output_info->shape_;
+ auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size());
+ auto ort_tensor_size = TensorDataSize(ort_output_tensor);
+ LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size
+ << "Ort tensor size: " << ort_tensor_size;
+ ORT_ENFORCE(qnn_output_info.tensor_byte_size == ort_tensor_size,
"ORT Tensor data size does not match QNN tensor data size");
- SetQnnTensorClientBufData(qnn_output_tensor,
- const_cast(output_tensor.GetTensorData()));
+
+ qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor());
+ SetQnnTensorClientBuf(qnn_outputs.back(),
+ const_cast<void*>(ort_output_tensor.GetTensorData<void>()), qnn_output_info.tensor_byte_size);
}
LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name();
auto qnn_interface = qnn_backend_manager_->GetQnnInterface();
auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle();
Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR;
- execute_status = qnn_interface.graphExecute(graph_info_->Graph(),
- qnn_inputs_.data(),
- static_cast(qnn_inputs_.size()),
- qnn_outputs_.data(),
- static_cast(qnn_outputs_.size()),
- profile_backend_handle,
- nullptr);
- ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo());
+ {
+ // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run()
+ // from multiple threads.
+ std::lock_guard<OrtMutex> lock(graph_exec_mutex_);
+ execute_status = qnn_interface.graphExecute(graph_info_->Graph(),
+ qnn_inputs.data(),
+ static_cast<uint32_t>(qnn_inputs.size()),
+ qnn_outputs.data(),
+ static_cast<uint32_t>(qnn_outputs.size()),
+ profile_backend_handle,
+ nullptr);
+
+ // NOTE: This function returns immediately when profiling is disabled.
+ // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes
+ // and not in production. We can improve synchronization for event profiling if it becomes an issue.
+ ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo());
+ }
+
if (QNN_GRAPH_NO_ERROR != execute_status) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN graph execute error. Error code: ", execute_status);
}
@@ -262,14 +280,13 @@ Status QnnModel::GetQnnTensorDataLength(const std::vector& dims,
return Status::OK();
}
-// Setup details for Qnn_Tensor_t for execution
-// based on information in QnnTensorWrapper
-Status QnnModel::SetupTensors(std::vector<Qnn_Tensor_t>& qnn_tensors,
+// Setup information for Qnn inputs/outputs used during execution.
+Status QnnModel::SetupTensors(std::vector<QnnTensorInfo>& qnn_tensor_infos,
const std::vector<QnnTensorWrapper>& tensor_wrappers,
bool is_input) {
size_t tensor_count = tensor_wrappers.size();
ORT_RETURN_IF(0 == tensor_count, "Zero tensor size!");
- qnn_tensors.resize(tensor_count);
+ qnn_tensor_infos.resize(tensor_count);
for (auto& tensor_wrapper : tensor_wrappers) {
size_t length = 0;
@@ -277,10 +294,14 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensors,
ORT_RETURN_IF_ERROR(GetQnnTensorDataLength(tensor_wrapper.GetTensorDims(),
tensor_wrapper.GetTensorDataType(),
length));
- auto tensor_name = tensor_wrapper.GetName();
- auto index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name);
- qnn_tensors[index] = tensor_wrapper.GetQnnTensor();
- SetQnnTensorClientBufSize(qnn_tensors[index], static_cast<uint32_t>(length));
+ const auto& tensor_name = tensor_wrapper.GetName();
+ auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name);
+ auto ort_index = is_input ? GetOrtInputIndex(tensor_name) : qnn_index;
+
+ QnnTensorInfo& qnn_tensor_info = qnn_tensor_infos[qnn_index];
+ qnn_tensor_info.tensor_wrapper = &tensor_wrapper;
+ qnn_tensor_info.tensor_byte_size = static_cast<uint32_t>(length);
+ qnn_tensor_info.ort_index = ort_index;
}
return Status::OK();
}
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h
index de4f872f73ccf..d0dd091cb1688 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h
@@ -3,8 +3,11 @@
#pragma once
+#include
+
#include "core/common/status.h"
#include "core/graph/graph_viewer.h"
+#include "core/platform/ort_mutex.h"
#include "core/providers/qnn/builder/qnn_def.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_backend_manager.h"
@@ -14,6 +17,12 @@
namespace onnxruntime {
namespace qnn {
+struct QnnTensorInfo {
+ const QnnTensorWrapper* tensor_wrapper = nullptr;
+ uint32_t tensor_byte_size = 0;
+ size_t ort_index = 0;
+};
+
class QnnModel {
public:
QnnModel(const logging::Logger& logger,
@@ -103,7 +112,8 @@ class QnnModel {
Qnn_DataType_t data_type,
size_t& data_length) const;
- Status SetupTensors(std::vector<Qnn_Tensor_t>& tensors, const std::vector<QnnTensorWrapper>& tensor_wrappers, bool is_input = true);
+ Status SetupTensors(std::vector<QnnTensorInfo>& tensors, const std::vector<QnnTensorWrapper>& tensor_wrappers,
+ bool is_input = true);
QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
@@ -126,9 +136,12 @@ class QnnModel {
std::vector<std::string> output_names_;
std::unordered_map inputs_info_;
std::unordered_map outputs_info_;
- std::vector<Qnn_Tensor_t> qnn_inputs_;
- std::vector<Qnn_Tensor_t> qnn_outputs_;
+ std::vector<QnnTensorInfo> qnn_input_infos_;
+ std::vector<QnnTensorInfo> qnn_output_infos_;
QnnBackendType qnn_backend_type_ = QnnBackendType::CPU;
+
+ // Mutex acquired during graph execution to support multi-threaded inference of a single session.
+ OrtMutex graph_exec_mutex_;
};
} // namespace qnn
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 391d7bebc9589..f9064cad3fe12 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include
#include
+#include
+#include
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
@@ -287,8 +288,199 @@ TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) {
EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin"));
}
+struct ModelAndBuilder {
+ ModelAndBuilder(Graph& graph) : builder(graph) {}
+ std::string model_data;
+ ModelTestBuilder builder;
+};
+
+// Creates a model in memory. Input feeds and output names can be accessed from result.builder.
+static void CreateModelInMemory(std::unique_ptr<ModelAndBuilder>& result,
+ const GetTestModelFn& model_build_fn,
+ const std::string& model_name,
+ int opset_version = 18) {
+ const std::unordered_map<std::string, int> domain_to_version = {{"", opset_version}, {kMSDomain, 1}};
+ auto& logging_manager = DefaultLoggingManager();
+
+ // Create float model and serialize it to a string.
+ onnxruntime::Model model(model_name, false, ModelMetaData(), PathString(),
+ IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
+ logging_manager.DefaultLogger());
+ result = std::make_unique<ModelAndBuilder>(model.MainGraph());
+ model_build_fn(result->builder);
+ result->builder.SetGraphOutputs();
+ ASSERT_STATUS_OK(model.MainGraph().Resolve());
+ model.ToProto().SerializeToString(&result->model_data);
+}
+
+// Runs a session and verifies the outputs. Can be run by individual threads.
+static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds,
+ const std::vector<std::string>& output_names,
+ const std::vector<std::vector<int64_t>>& output_shapes,
+ const std::vector<std::vector<float>>& expected_values) {
+ std::vector<OrtValue> fetches;
+ auto status = session.Run(run_options, feeds, output_names, &fetches);
+ ASSERT_TRUE(status.IsOK());
+
+ for (size_t i = 0; i < fetches.size(); i++) {
+ auto& tensor = fetches[i].Get<Tensor>();
+ TensorShape expected_shape(output_shapes[i]);
+ ASSERT_EQ(expected_shape, tensor.Shape());
+
+ gsl::span<const float> actual = tensor.DataAsSpan<float>();
+ gsl::span<const float> expected(expected_values[i].data(), expected_values[i].size());
+ ASSERT_EQ(expected, actual);
+ }
+}
+
+// Returns a function that builds a float32 model that adds 3 tensors.
+static GetTestModelFn F32BuildAdd3Tensors(const TestInputDef<float>& input0_def,
+ const TestInputDef<float>& input1_def,
+ const TestInputDef<float>& input2_def) {
+ return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) {
+ NodeArg* input0 = MakeTestInput(builder, input0_def);
+ NodeArg* input1 = MakeTestInput(builder, input1_def);
+ NodeArg* input2 = MakeTestInput(builder, input1_def);
+
+ auto* add0_out = builder.MakeIntermediate();
+ builder.AddNode("Add", {input0, input1}, {add0_out});
+
+ auto* output = builder.MakeOutput();
+ builder.AddNode("Add", {add0_out, input2}, {output});
+ };
+}
+
+// Tests running a single session in multiple threads on the CPU backend.
+TEST_F(QnnCPUBackendTests, MultithreadSessionRun) {
+ std::unique_ptr<ModelAndBuilder> model;
+ std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ std::vector<int64_t> shape = {1, 3, 2};
+ std::vector<std::vector<int64_t>> output_shapes = {shape};
+ std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+ CreateModelInMemory(model,
+ F32BuildAdd3Tensors(TestInputDef<float>(shape, false, input_data),
+ TestInputDef<float>(shape, false, input_data),
+ TestInputDef<float>(shape, false, input_data)),
+ "add3.f32");
+
+ SessionOptions session_opts;
+ session_opts.session_logid = "logger0";
+
+ RunOptions run_opts;
+ run_opts.run_tag = session_opts.session_logid;
+
+ InferenceSession session_obj{session_opts, GetEnvironment()};
+ onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+ options["backend_path"] = "QnnCpu.dll";
+#else
+ options["backend_path"] = "libQnnCpu.so";
+#endif
+
+ auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+ EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+ auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+ ASSERT_TRUE(status.IsOK());
+ status = session_obj.Initialize();
+ ASSERT_TRUE(status.IsOK());
+
+ std::vector<std::thread> threads;
+ constexpr int num_threads = 5;
+
+ for (int i = 0; i < num_threads; i++) {
+ threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+ model->builder.feeds_, model->builder.output_names_,
+ output_shapes, output_values));
+ }
+
+ for (auto& th : threads) {
+ th.join();
+ }
+}
+
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+// Returns a function that builds a QDQ model that adds 3 tensors. Forces all scales and zero-points to be (1.0f, 0),
+// so it is only accurate when using non-fractional positive inputs.
+template
+static GetTestModelFn QDQBuildAdd3Tensors(const TestInputDef& input0_def,
+ const TestInputDef& input1_def,
+ const TestInputDef& input2_def) {
+ return [input0_def, input1_def, input2_def](ModelTestBuilder& builder) {
+ NodeArg* input0 = MakeTestInput(builder, input0_def);
+ NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, 1.0f, 0);
+ NodeArg* input1 = MakeTestInput(builder, input1_def);
+ NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, 1.0f, 0);
+ NodeArg* input2 = MakeTestInput(builder, input1_def);
+ NodeArg* input2_after_qdq = AddQDQNodePair(builder, input2, 1.0f, 0);
+
+ auto* add0_out = builder.MakeIntermediate();
+ builder.AddNode("Add", {input0_after_qdq, input1_after_qdq}, {add0_out});
+
+ auto* add0_out_dq = AddQDQNodePair(builder, add0_out, 1.0f, 0);
+
+ auto* add1_out = builder.MakeIntermediate();
+ builder.AddNode("Add", {add0_out_dq, input2_after_qdq}, {add1_out});
+
+ // op_output -> Q -> DQ -> output
+ AddQDQNodePairWithOutputAsGraphOutput(builder, add1_out, 1.0f, 0);
+ };
+}
+
+// Tests running a single session in multiple threads on the HTP backend.
+TEST_F(QnnHTPBackendTests, MultithreadSessionRun) {
+ std::unique_ptr<ModelAndBuilder> model;
+ std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ std::vector<int64_t> shape = {1, 3, 2};
+ std::vector<std::vector<int64_t>> output_shapes = {shape};
+ std::vector<std::vector<float>> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}};
+
+ CreateModelInMemory(model,
+ QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data),
+ TestInputDef(shape, false, input_data),
+ TestInputDef(shape, false, input_data)),
+ "add3.qdq");
+
+ SessionOptions session_opts;
+ session_opts.session_logid = "logger0";
+
+ RunOptions run_opts;
+ run_opts.run_tag = session_opts.session_logid;
+
+ InferenceSession session_obj{session_opts, GetEnvironment()};
+ onnxruntime::ProviderOptions options;
+
+#if defined(_WIN32)
+ options["backend_path"] = "QnnHtp.dll";
+#else
+ options["backend_path"] = "libQnnHtp.so";
+#endif
+
+ auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts);
+ EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK());
+
+ auto status = session_obj.Load(model->model_data.data(), static_cast<int>(model->model_data.size()));
+ ASSERT_TRUE(status.IsOK());
+ status = session_obj.Initialize();
+ ASSERT_TRUE(status.IsOK());
+
+ std::vector<std::thread> threads;
+ constexpr int num_threads = 5;
+
+ for (int i = 0; i < num_threads; i++) {
+ threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts,
+ model->builder.feeds_, model->builder.output_names_,
+ output_shapes, output_values));
+ }
+
+ for (auto& th : threads) {
+ th.join();
+ }
+}
+
// Test shape inference of QDQ NHWC Resize operator (opset 18) that uses
// the sizes input. Use the QNN HTP backend.
TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) {
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
index 07e69ff496720..d286c4f3a46fe 100644
--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
@@ -86,7 +86,7 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
- -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
+ -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
cmake/external/onnx/onnx/backend/test/data/node
- task: CmdLine@2
@@ -94,7 +94,7 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
- -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
+ -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnCpu.so" \
/data/float32_models
- task: CmdLine@2
@@ -102,7 +102,7 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
- -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
+ -v -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
/data/qdq_models
- task: CmdLine@2
@@ -110,5 +110,5 @@ jobs:
inputs:
script: |
./build/Release/onnx_test_runner -e qnn \
- -v -f -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
+ -v -f -j 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \
/data/qdq_models/mobilenetv2-1.0_add_transpose_quant
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 5e35cbfed6692..6dc428d6606af 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -84,17 +84,17 @@ jobs:
displayName: 'Run unit tests'
- script: |
- .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
+ .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
displayName: 'Run ONNX Tests'
- script: |
- .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models
+ .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnCpu.dll" C:\data\float32_models
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
displayName: 'Run float32 model tests'
- script: |
- .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models
+ .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\aarch64-windows-msvc\QnnHtp.dll" C:\data\qdq_models
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
displayName: 'Run QDQ model tests'
enabled: false
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index 65b2924c8be60..fbec572fd346c 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -88,11 +88,11 @@ jobs:
displayName: 'Run unit tests'
- script: |
- .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
+ .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
displayName: 'Run ONNX Tests'
- script: |
- .\$(BuildConfig)\onnx_test_runner -j 1 -c 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
+ .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QNN_SDK_ROOT)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models
workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)'
displayName: 'Run float32 model tests'
From e10a8ae31feba949b682f2451268c0dc68589ba3 Mon Sep 17 00:00:00 2001
From: liqun Fu
Date: Thu, 4 Jan 2024 17:41:01 -0800
Subject: [PATCH 16/20] reduce max/min 20 (#17805)
### Description
ReduceMax and ReduceMin were updated in ONNX opset 20; implement the new opset-20 kernels in ORT.
### Motivation and Context
This is needed for the ORT 1.17.0 release.
---------
Signed-off-by: Liqun Fu
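
As a quick illustration (not part of the patch), a hedged Python sketch that
exercises the new opset-20 `ReduceMax` support for `tensor(bool)` on the CPU EP.
Tensor names and shapes are arbitrary, and it assumes onnx/onnxruntime builds
that already ship opset 20.

```python
# Hypothetical sketch: ReduceMax over a bool tensor, valid only from opset 20 on.
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node("ReduceMax", ["x", "axes"], ["y"], keepdims=1)
graph = helper.make_graph(
    [node],
    "reduce_max_bool",
    [helper.make_tensor_value_info("x", TensorProto.BOOL, [2, 3])],
    [helper.make_tensor_value_info("y", TensorProto.BOOL, [2, 1])],
    initializer=[helper.make_tensor("axes", TensorProto.INT64, [1], [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
x = np.array([[True, False, False], [False, False, False]])
# Max over bools behaves like a logical OR along the reduced axis.
print(sess.run(None, {"x": x})[0])  # -> [[ True], [False]]
```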
---
docs/OperatorKernels.md | 6 +-
.../providers/cpu/cpu_execution_provider.cc | 100 +++--
.../cpu/reduction/reduction_kernel_base.h | 40 ++
.../providers/cpu/reduction/reduction_ops.cc | 101 ++++-
.../providers/cpu/reduction/reduction_ops.h | 175 +++++---
.../providers/cuda/reduction/reduction_ops.h | 2 +-
onnxruntime/test/onnx/TestCase.cc | 2 +-
.../cpu/reduction/reduction_ops_test.cc | 398 +++++++++++++++++-
.../onnx_backend_test_series_filters.jsonc | 55 ++-
9 files changed, 737 insertions(+), 142 deletions(-)
create mode 100644 onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e401baae2d803..f985cf10ded60 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -278,7 +278,8 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
-|ReduceMax|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* reduced:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ReduceMax|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* reduced:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
@@ -287,7 +288,8 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(int32)|
-|ReduceMin|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* reduced:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reduced:**T**|18+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|ReduceMin|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* reduced:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reduced:**T**|20+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
+|||[18, 19]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||[13, 17]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint8)|
|||11|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 1390f60243174..f60c7ddac5c05 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -850,21 +850,21 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceLogSumExp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceLogSumExp);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMax);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, double, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, float, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, double, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int32_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int64_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, int8_t, ReduceMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, 19, uint8_t, ReduceMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, float, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 18, int64_t, ReduceProd);
@@ -960,6 +960,20 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, bool, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, bool, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int32_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int64_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, int8_t, ReduceMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, uint8_t, ReduceMin);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
@@ -2263,36 +2277,36 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
ReduceLogSumExp)>,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo