Merge WindowsAI to main #18983

Merged: 25 commits, Jan 5, 2024

Commits
c3d96a7
Update DML version to 1.13.0 (#18978)
jeffbloo Jan 3, 2024
9bbe425
Register LPpool18 and AvgPool 19 (#16880)
zhangxiang1993 Jul 28, 2023
9ff5e3b
Add QLinearConcat for DML EP (#16971) (#18268)
raoanag Nov 3, 2023
cb7f28a
Register Resize for INT8 and UINT8 (#18252)
raoanag Nov 3, 2023
dcfff10
Enable QLinearAveragePooling DML EP (#17384) (#18240)
raoanag Nov 6, 2023
d5f3aae
Utilize DML constant input graph node (#18267)
raoanag Nov 18, 2023
531e875
Avoid command list reset in common case of re-used command list execu…
raoanag Nov 18, 2023
a1000a0
Enable GEMM activation fusions on MCDM (#18372)
raoanag Nov 18, 2023
5c28334
Filter activation fusions on MCDM (#18371)
raoanag Nov 18, 2023
613fdce
Create ring buffer for re-used command lists (#18368)
raoanag Nov 20, 2023
7f9e6c4
readd npu enumeration (#18437) (#18518)
raoanag Nov 20, 2023
e8209ce
CP 7fd1ce95a4e4f3c2b6152dfc2b1807a983ef45e5 (#18560)
smk2007 Nov 22, 2023
623d957
register resize with uint8/int8 support (#18647)
zhangxiang1993 Dec 1, 2023
c1ec3c3
User/chrila/fix dml dx12 warning (#18746)
chrilaMSFT Dec 9, 2023
107d749
[DirectML EP] Add DML EP registration for Col2Im (#17786)
smk2007 Dec 9, 2023
d2f7a5b
Cherry pick fix constant pow (#18785)
Jamather Dec 12, 2023
b2f81c8
Hide Col2Im registration behind DML_TARGET_VERSION 6300 (#18829)
smk2007 Dec 14, 2023
bdaeebd
Fix bug in DML EP ExecuteCommandList fast path and simplify design (#…
jeffbloo Dec 18, 2023
70d3f68
De-duplicate 1D scale and zero point tensors to scalars in DML kernel…
tbqh Jan 2, 2024
ee60e3a
Limit size of constant nodes creates by DML EP following deduplicatio…
jeffbloo Jan 2, 2024
56fcea9
Enable QDQ quantization for DML EP (#18367)
raoanag Jan 3, 2024
70a6f81
Port attention query fix from b2768bbf2347b4ea564f2a937f9f48987620ddf0
jeffbloo Jan 4, 2024
f4ad940
Disable MatMul QDQ selector on DML EP until MatMulIntegerToFloat is r…
jeffbloo Jan 4, 2024
8ea3e68
Update ContribOperators.md
jeffbloo Jan 4, 2024
7401b66
Update OperatorKernels.md
jeffbloo Jan 4, 2024
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
- <package id="Microsoft.AI.DirectML" version="1.12.1" targetFramework="native" />
+ <package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
- <package id="Microsoft.AI.DirectML" version="1.12.1" targetFramework="native" />
+ <package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
6 changes: 3 additions & 3 deletions cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
@@ -72,12 +72,11 @@ else()
if (dml_EXTERNAL_PROJECT)
set(dml_preset_config $<IF:$<CONFIG:Debug>,debug,release>)
set(dml_preset_name ${onnxruntime_target_platform}-win-redist-${dml_preset_config})
- target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1)
include(ExternalProject)
ExternalProject_Add(
directml_repo
GIT_REPOSITORY https://dev.azure.com/microsoft/WindowsAI/_git/DirectML
- GIT_TAG d460f0f46967bea878786f1bed69487692c779bf
+ GIT_TAG a5312f72c51864b4d705ac62d25d08bcd88c4fb1
GIT_SHALLOW OFF # not allowed when GIT_TAG is a commit SHA, which is preferred (it's stable, unlike branches)
GIT_PROGRESS ON
BUILD_IN_SOURCE ON
@@ -94,6 +93,7 @@ else()
target_link_libraries(DirectML INTERFACE ${directml_install_path}/lib/DirectML.lib)
add_dependencies(DirectML directml_repo-install)
include_directories(BEFORE ${directml_install_path}/include)
+ target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1)
else()
include_directories(BEFORE ${dml_INCLUDE_DIR})
set(DML_PACKAGE_DIR ${dml_INCLUDE_DIR}/..)
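
The `DML_TARGET_VERSION_USE_LATEST=1` compile definition in this hunk is a preprocessor switch for code that includes the DirectML headers. As a rough illustration of the kind of version gate named in commit b2f81c8 ("Hide Col2Im registration behind DML_TARGET_VERSION 6300"), the sketch below uses a hypothetical `DEMO_TARGET_VERSION` macro, a hypothetical `RegisteredOperators` helper, and assumes the version is written as the hex value `0x6300`. It does not reproduce the DirectML header or the EP's actual registration table.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a target-version macro. A build system would
// normally supply it, for example through target_compile_definitions.
#ifndef DEMO_TARGET_VERSION
#define DEMO_TARGET_VERSION 0x6300
#endif

// Returns the operator names registered for the configured target version.
inline std::vector<std::string> RegisteredOperators() {
  std::vector<std::string> ops = {"Resize", "AveragePool"};
#if DEMO_TARGET_VERSION >= 0x6300
  // Compiled in only when the target version is new enough.
  ops.push_back("Col2Im");
#endif
  return ops;
}
```

Building with a lower `DEMO_TARGET_VERSION` drops the gated entry at compile time, so no runtime check is needed.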
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
+ <dt><tt>rotary_embedding_dim</tt> : int</dt>
+ <dd>Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
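
The attribute defaults documented in this hunk can be sketched as a small resolver. This is only an illustration: `AttentionParams` and `ResolveAttentionParams` are hypothetical names, and the choice to range-check `rotary_embedding_dim` only when it is given explicitly is an assumption, not something the documentation states.

```cpp
#include <cassert>
#include <cmath>
#include <optional>
#include <stdexcept>

// Hypothetical container for the two resolved attributes.
struct AttentionParams {
  int rotary_embedding_dim;
  float scale;
};

// Applies the documented defaults: rotary_embedding_dim falls back to
// head_size, and scale falls back to 1/sqrt(head_size).
inline AttentionParams ResolveAttentionParams(int head_size,
                                              std::optional<int> rotary_embedding_dim,
                                              std::optional<float> scale) {
  if (head_size <= 0) {
    throw std::invalid_argument("head_size must be positive");
  }
  // Assumption: only an explicitly supplied value is checked against 32/64/128.
  if (rotary_embedding_dim.has_value()) {
    const int dim = *rotary_embedding_dim;
    if (dim != 32 && dim != 64 && dim != 128) {
      throw std::invalid_argument("rotary_embedding_dim must be 32, 64 or 128");
    }
  }
  const int rotary = rotary_embedding_dim.value_or(head_size);
  const float resolved_scale =
      scale.value_or(1.0f / std::sqrt(static_cast<float>(head_size)));
  return AttentionParams{rotary, resolved_scale};
}
```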
15 changes: 10 additions & 5 deletions docs/OperatorKernels.md
@@ -903,7 +903,8 @@ Do not modify directly.*
|Asinh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
|Atan|*in* input:**T**<br> *out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Atanh|*in* input:**T**<br> *out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
- |AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
+ |AveragePool|*in* X:**T**<br> *out* Y:**T**|19+|**T** = tensor(float), tensor(float16)|
+ |||11+|**T** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
|BatchNormalization|*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* input_mean:**U**<br> *in* input_var:**U**<br> *out* Y:**T**<br> *out* running_mean:**U**<br> *out* running_var:**U**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *in* mean:**T**<br> *in* var:**T**<br> *out* Y:**T**<br> *out* mean:**T**<br> *out* var:**T**<br> *out* saved_mean:**T**<br> *out* saved_var:**T**<br><br>or<br><br>*in* X:**T**<br> *in* scale:**T1**<br> *in* B:**T1**<br> *in* input_mean:**T2**<br> *in* input_var:**T2**<br> *out* Y:**T**<br> *out* running_mean:**T2**<br> *out* running_var:**T2**|15+|**T** = tensor(float), tensor(float16)|
@@ -951,7 +952,7 @@ Do not modify directly.*
|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||7+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Dropout|*in* data:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**<br> *out* mask:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**<br> *out* mask:**T1**|7+|**T** = tensor(float), tensor(float16)|
- |DynamicQuantizeLinear|*in* x:**T1**<br> *out* y:**T2**<br> *out* y_scale:**tensor(float)**<br> *out* y_zero_point:**T2**|11+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
+ |DynamicQuantizeLinear|*in* x:**T1**<br> *out* y:**T2**<br> *out* y_scale:**tensor(float)**<br> *out* y_zero_point:**T2**|11+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|Einsum|*in* Inputs:**T**<br> *out* Output:**T**|12+|**T** = tensor(float), tensor(float16)|
|Elu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(bool)|
@@ -1030,7 +1031,8 @@ Do not modify directly.*
|||11+|**T** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)|
|LpNormalization|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
- |LpPool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
+ |LpPool|*in* X:**T**<br> *out* Y:**T**|18+|**T** = tensor(float), tensor(float16)|
+ |||11+|**T** = tensor(float), tensor(float16)|
|||2+|**T** = tensor(float), tensor(float16)|
|MatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||9+|**T** = tensor(float), tensor(float16)|
@@ -1145,8 +1147,8 @@ Do not modify directly.*
|Reshape|*in* data:**T**<br> *in* shape:**tensor(int64)**<br> *out* reshaped:**T**<br><br>or<br><br>*in* data:**T**<br> *out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
- |Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
- |||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
+ |Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
+ |||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
@@ -1247,6 +1249,9 @@ Do not modify directly.*
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
+ |QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
+ |QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|
+ |QLinearGlobalAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
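
For readers unfamiliar with the `QLinearAveragePool` signature listed in this hunk (`x_scale`, `x_zero_point`, `y_scale`, `y_zero_point`), here is a minimal sketch of the usual quantized-average arithmetic for a single pooling window: dequantize, average, requantize with saturation. `QuantizedAverage` is a hypothetical helper limited to `uint8`; it says nothing about how the DML kernel computes the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Averages one pooling window of quantized uint8 values and requantizes it.
inline std::uint8_t QuantizedAverage(const std::vector<std::uint8_t>& window,
                                     float x_scale, std::uint8_t x_zero_point,
                                     float y_scale, std::uint8_t y_zero_point) {
  // Dequantize each element: real = x_scale * (q - x_zero_point).
  float sum = 0.0f;
  for (std::uint8_t q : window) {
    sum += x_scale * static_cast<float>(static_cast<int>(q) - static_cast<int>(x_zero_point));
  }
  const float mean = window.empty() ? 0.0f : sum / static_cast<float>(window.size());
  // Requantize: q = round(real / y_scale) + y_zero_point, saturated to [0, 255].
  const float requantized = std::nearbyint(mean / y_scale) + static_cast<float>(y_zero_point);
  return static_cast<std::uint8_t>(std::clamp(requantized, 0.0f, 255.0f));
}
```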
onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -105,8 +105,8 @@
std::unique_ptr<Action> action = std::make_unique<QDQ::UnaryReplaceWithQLinear>(kMSDomain);

#if !defined(ORT_MINIMAL_BUILD)
- // TODO: Enable 16-bit types in selector when unary QLinear* ops support 16-bit.
- std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>();
+ std::vector<const char*> providers = {kCpuExecutionProvider};
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::UnarySelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"AveragePool", {}},
{"LeakyRelu", {}},
@@ -123,20 +123,43 @@
void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
// 4 nodes. 2 x DQ for inputs, target, Q
// Replace with internal QLinear version of operator. Delete all original nodes.
- const std::string action_name{"2DQ"};
- std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
+ {
+ const std::string action_name{"2DQ"};
+ std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);

#if !defined(ORT_MINIMAL_BUILD)
- // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit.
- std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>();
- qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
- {{"Add", {}},
- {"Mul", {}}},
- std::move(selector),
- std::move(action));
+ // TODO: Enable 16-bit types in selector when binary QLinear* ops support 16-bit.

[cpplint] onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:131: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
+ std::vector<const char*> providers = {kCpuExecutionProvider};
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>(providers);
+ qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
+ {{"Add", {}},
+ {"Mul", {}}},
+ std::move(selector),
+ std::move(action));

#else
- qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
+ qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
#endif
+ }
+
+ #ifdef USE_DML
+ {
+ const std::string action_name{"2DQ_DML"};
+ std::unique_ptr<Action> action = std::make_unique<QDQ::BinaryReplaceWithQLinear>(kMSDomain);
+
+ #if !defined(ORT_MINIMAL_BUILD)
+ std::vector<const char*> providers = {kDmlExecutionProvider};
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::BinarySelector>(providers);
+
+ qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
+ {{"Add", {}}},
+ std::move(selector),
+ std::move(action));
+
+ #else
+ #error "ORT_MINIMAL_BUILD and USE_DML are not expected simultaneously. This would require RegisterAction to be called here."

[cpplint] onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:160: Lines should be <= 120 characters long [whitespace/line_length] [2]
+ #endif
+ }
+ #endif
}
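
The two registrations in `BinaryOpQDQRules` differ only in action name, operator set, and provider list: "2DQ" covers Add and Mul on the CPU EP, while "2DQ_DML" covers only Add on the DML EP. A simplified model of that lookup is sketched below. `Rule`, `BinaryRules`, `FindAction`, and the two string constants are hypothetical stand-ins written for this illustration and are not the ONNX Runtime `SelectorActionRegistry` API.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Stand-in strings for the execution provider identifiers used in the diff.
constexpr const char* kCpuEp = "CPUExecutionProvider";
constexpr const char* kDmlEp = "DmlExecutionProvider";

// Hypothetical, flattened view of one selector/action registration.
struct Rule {
  std::string action_name;
  std::vector<std::string> op_types;
  std::vector<std::string> providers;
};

// Mirrors the two registrations in the hunk above.
inline std::vector<Rule> BinaryRules() {
  return {
      Rule{"2DQ", {"Add", "Mul"}, {kCpuEp}},
      Rule{"2DQ_DML", {"Add"}, {kDmlEp}},
  };
}

// Returns the name of the first rule matching both op type and provider.
inline std::optional<std::string> FindAction(const std::vector<Rule>& rules,
                                             const std::string& op_type,
                                             const std::string& node_ep) {
  for (const Rule& rule : rules) {
    const bool op_ok = std::find(rule.op_types.begin(), rule.op_types.end(),
                                 op_type) != rule.op_types.end();
    const bool ep_ok = std::find(rule.providers.begin(), rule.providers.end(),
                                 node_ep) != rule.providers.end();
    if (op_ok && ep_ok) {
      return rule.action_name;
    }
  }
  return std::nullopt;
}
```

In this model a `Mul` node assigned to the DML EP matches no rule, which corresponds to the diff registering only `Add` under "2DQ_DML".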

@@ -193,7 +216,8 @@

#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
- std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::MatMulSelector>(is_int8_allowed);
+ std::vector<const char*> providers = {kCpuExecutionProvider};
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::MatMulSelector>(providers, is_int8_allowed);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
std::move(selector),
@@ -214,8 +238,8 @@
std::unique_ptr<Action> action = std::make_unique<QDQ::GemmReplaceWithQuant>();

#if !defined(ORT_MINIMAL_BUILD)
- // TODO: Enable 16-bit types in selector when QGemm supports 16-bit.
- std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
+ std::vector<const char*> providers = {kCpuExecutionProvider};
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Gemm", {}}},
std::move(selector),
@@ -235,8 +259,9 @@
std::unique_ptr<Action> action = std::make_unique<QDQ::WhereReplaceWithQLinear>();

#if !defined(ORT_MINIMAL_BUILD)
- // TODO: Enable 16-bit types in selector when QLinearWhere supports 16-bit.
- std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>();
+
+ std::vector<const char*> providers = {kCpuExecutionProvider};

[cpplint] onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:263: Add #include <vector> for vector<> [build/include_what_you_use] [4]
+ std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::WhereSelector>(providers);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Where", {}}},
std::move(selector),
@@ -271,8 +296,8 @@
"QDQSelectorActionTransformer",
CreateSelectorActionRegistry(is_int8_allowed),
apply_context,
- // this transformer is only compatible with the CPU EP
- {kCpuExecutionProvider}} {
+ // this transformer is only compatible with the CPU and DML EP
+ {kCpuExecutionProvider, kDmlExecutionProvider}} {
}

} // namespace onnxruntime
@@ -91,6 +91,13 @@ std::optional<NodeGroup> NodeGroupSelector::GetQDQSelection(const GraphViewer& g
}

std::optional<NodesToOptimizeIndices> BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const {
+ const std::string_view node_ep = node.GetExecutionProviderType();
+
+ if (!compatible_providers_.empty() &&
+ std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) {
+ return std::nullopt;
+ }
+
const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node);
if (!qdq_group.has_value()) {
return std::nullopt;
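
The early return added to `BaseSelector::Select` amounts to a small predicate: an empty provider list accepts every node, otherwise the node's execution provider must appear in the list. A standalone sketch of that check follows. `IsProviderCompatible` is a hypothetical free function, and storing the provider names as `std::vector<std::string>` is an assumption made for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <string_view>
#include <vector>

// Returns true when a selector restricted to `compatible_providers` may
// consider a node assigned to `node_ep`. An empty list means "no restriction".
inline bool IsProviderCompatible(const std::vector<std::string>& compatible_providers,
                                 std::string_view node_ep) {
  if (compatible_providers.empty()) {
    return true;
  }
  return std::find(compatible_providers.begin(), compatible_providers.end(),
                   node_ep) != compatible_providers.end();
}
```

When the predicate is false, the selector in the diff returns `std::nullopt` before looking for a QDQ group at all.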