diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index fce334f719239..3dc6e768e9262 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -272,12 +272,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider }; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 1eb091b9b8ce1..3c0f49f3d2d49 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -2191,6 +2191,11 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> { using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT> +{ + using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC; +}; // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. @@ -2539,6 +2544,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); #pragma warning(push) #pragma warning(disable: 4063) @@ -2700,7 +2707,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; - case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 71e808160163d..9080e8efedab9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1726,7 +1726,7 @@ using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; -using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulIntegerToFloatHelper; +using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 0d5dab35826c1..c7f2ec89fb817 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -39,12 +39,8 @@ static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, cons for (int64_t n = 0; n < N; n++) { float sum = 0.0f; for (int64_t k = 0; k < K; k++) { - float A_dequantized = has_zp ? - (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : - A_data[m * K + k] * A_scale[0]; - float B_dequantized = has_zp ? - (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : - B_data[k * N + n] * B_scale[n]; + float A_dequantized = has_zp ? (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0]; + float B_dequantized = has_zp ? (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n]; sum += A_dequantized * B_dequantized; } @@ -81,8 +77,7 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant, std::vector tmp_B_data; tmp_B_data = random.Uniform(B_dims, - (constexpr(std::is_same_v)) ? - std::numeric_limits::lowest()/2 : std::numeric_limits::lowest(), + (constexpr(std::is_same_v)) ? std::numeric_limits::lowest() / 2 : std::numeric_limits::lowest(), std::numeric_limits::max() / 2); std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType { return static_cast(v); @@ -148,7 +143,6 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant, } else { test.Run(); } - } template @@ -161,13 +155,6 @@ void RunMatMulIntegerToFloatTest() { ); TestMatMulIntegerToFloat( - A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ true, /*is_matrix_b_constant*/ false, /*per_column*/ HasZeroPoint, /*has_zp*/