Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/WindowsAI' into user/linneamay/r…
Browse files Browse the repository at this point in the history
…esize-18
  • Loading branch information
Linnea May committed Jan 4, 2024
2 parents 6e9bbb3 + 7401b66 commit fd374bf
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i

#if !defined(ORT_MINIMAL_BUILD)
// TODO: Enable 16-bit types in selector when QLinearMatMul and MatMulInteger support 16-bit.
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::MatMulSelector>(is_int8_allowed);
std::vector<const char*> providers = {kCpuExecutionProvider};
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::MatMulSelector>(providers, is_int8_allowed);
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"MatMul", {}}},
std::move(selector),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,10 @@ class WhereSelector : public BaseSelector {
// 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not
class MatMulSelector : public BaseSelector {
public:
MatMulSelector(bool int8_allowed, bool allow_16bit = false)
MatMulSelector(gsl::span<const char*> compatible_providers, bool int8_allowed, bool allow_16bit = false)
: BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true,
allow_16bit)) {}
allow_16bit),
compatible_providers) {}
};

// Input: DQ nodes for A, B and optional C
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,7 +888,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_

constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING,
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
13,
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},

{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
Expand Down

0 comments on commit fd374bf

Please sign in to comment.