Resolve conflicts

raoanag committed Feb 22, 2024
1 parent b40a236 commit 2e63289
Showing 4 changed files with 18 additions and 24 deletions.
12 changes: 6 additions & 6 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -272,12 +272,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
 onnxruntime::kCudaExecutionProvider,
 onnxruntime::kRocmExecutionProvider,
 onnxruntime::kDmlExecutionProvider};
-const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider,
-                                                                      onnxruntime::kCudaExecutionProvider,
-                                                                      onnxruntime::kRocmExecutionProvider,
-                                                                      onnxruntime::kAclExecutionProvider,
-                                                                      onnxruntime::kArmNNExecutionProvider,
-                                                                      onnxruntime::kJsExecutionProvider };
+const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
+                                                                         onnxruntime::kCudaExecutionProvider,
+                                                                         onnxruntime::kRocmExecutionProvider,
+                                                                         onnxruntime::kAclExecutionProvider,
+                                                                         onnxruntime::kArmNNExecutionProvider,
+                                                                         onnxruntime::kJsExecutionProvider};
 const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
                                                       onnxruntime::kDmlExecutionProvider};
 #ifdef MLAS_TARGET_AMD64_IX86
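The rename tracks the set's contents: kJsExecutionProvider is a member, so the variable is now cpu_cuda_rocm_acl_armnn_js_eps. For readers outside the codebase, here is a minimal sketch of the EP-gating pattern such sets serve in GenerateTransformers; std::unordered_set stands in for InlinedHashSet, and the Transformer type, pass name, and EP strings are illustrative stand-ins, not the real registration code.

// Sketch only: a transformer registered with an EP set runs only on nodes
// assigned to one of the listed execution providers.
#include <iostream>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_set>
#include <vector>

struct Transformer {
  std::string name;
  std::unordered_set<std::string_view> compatible_eps;  // EPs the pass may target
};

int main() {
  const std::unordered_set<std::string_view> cpu_cuda_rocm_acl_armnn_js_eps = {
      "CPUExecutionProvider", "CUDAExecutionProvider", "ROCMExecutionProvider",
      "ACLExecutionProvider", "ArmNNExecutionProvider", "JsExecutionProvider"};

  std::vector<std::unique_ptr<Transformer>> transformers;
  transformers.push_back(std::make_unique<Transformer>(
      Transformer{"SomeFusionPass", cpu_cuda_rocm_acl_armnn_js_eps}));

  // A node assigned to the JS EP is now eligible for passes gated on this set.
  const std::string_view node_ep = "JsExecutionProvider";
  for (const auto& t : transformers) {
    if (t->compatible_eps.count(node_ep) != 0) {
      std::cout << t->name << " applies to " << node_ep << "\n";
    }
  }
}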
@@ -2191,6 +2191,11 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
 {
     using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC;
 };
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
+{
+    using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
+};
 
 // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as
 // the first argument.
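For context on the pattern: OperatorTypeTraits maps each DML_OPERATOR_TYPE enum value to its descriptor struct at compile time, so the new specialization lets generic code name DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC from the enum alone. A self-contained sketch of the same technique, with hypothetical enum and descriptor types in place of the DirectML ones:

// Enum-to-descriptor traits sketch (hypothetical types, C++17).
#include <type_traits>

enum class OpType { Identity, MatMulIntegerToFloat };

struct IdentityDesc { int input; };
struct MatMulIntegerToFloatDesc { int a; int b; int scale; };

template <OpType T>
struct OperatorTypeTraits;  // primary template intentionally undefined

template <>
struct OperatorTypeTraits<OpType::Identity> {
  using DescType = IdentityDesc;
};

template <>
struct OperatorTypeTraits<OpType::MatMulIntegerToFloat> {
  using DescType = MatMulIntegerToFloatDesc;
};

// Generic code recovers the descriptor type from the enum value alone:
static_assert(std::is_same_v<OperatorTypeTraits<OpType::MatMulIntegerToFloat>::DescType,
                             MatMulIntegerToFloatDesc>);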
@@ -2539,6 +2544,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward<Ts>(args)...);
     case DML_OPERATOR_ACTIVATION_GELU:
         return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
+        return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
 
 #pragma warning(push)
 #pragma warning(disable: 4063)
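The visitor switch complements the traits: given a runtime DML_OPERATOR_TYPE, it invokes the functor with a default-constructed descriptor of the matching static type, so the new case routes DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT to DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}. A minimal sketch of that dispatch, again with hypothetical enum and descriptor types:

// Runtime-enum to static-type dispatch sketch (hypothetical types, C++17).
#include <cstdio>
#include <functional>
#include <stdexcept>
#include <type_traits>
#include <utility>

enum class OpType { Identity, MatMulIntegerToFloat };
struct IdentityDesc {};
struct MatMulIntegerToFloatDesc {};

template <typename Visitor, typename... Ts>
auto OperatorTypeVisitor(OpType type, Visitor&& visitor, Ts&&... args) {
  switch (type) {
    case OpType::Identity:
      return std::invoke(std::forward<Visitor>(visitor), IdentityDesc{},
                         std::forward<Ts>(args)...);
    case OpType::MatMulIntegerToFloat:
      return std::invoke(std::forward<Visitor>(visitor),
                         MatMulIntegerToFloatDesc{}, std::forward<Ts>(args)...);
  }
  throw std::invalid_argument("unknown OpType");
}

int main() {
  OperatorTypeVisitor(OpType::MatMulIntegerToFloat, [](auto desc) {
    // The static type of `desc` identifies the operator at compile time.
    if constexpr (std::is_same_v<decltype(desc), MatMulIntegerToFloatDesc>) {
      std::puts("dispatched with MatMulIntegerToFloatDesc");
    }
  });
}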
@@ -2700,7 +2707,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
     case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
     case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
     case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
-    case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
+    case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
     default:
         assert(false);
         return "<unknown>";
@@ -1726,7 +1726,7 @@ using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_MatMul = MatMulHelper;
 using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
-using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulIntegerToFloatHelper;
+using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper;
 using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
 using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
 using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;
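The new aliasing reflects that MatMulIntegerToFloat produces the same output shape as a plain MatMul, so the generic MatMulHelper suffices. As a reminder of the rule it implements, here is a deliberately simplified 2-D version; the real helper also handles batched and broadcast inputs, which this sketch omits.

// 2-D shape-inference sketch for MatMul-style ops: [M, K] x [K, N] -> [M, N].
#include <array>
#include <cassert>
#include <cstdint>

std::array<int64_t, 2> InferMatMulShape(std::array<int64_t, 2> a,
                                        std::array<int64_t, 2> b) {
  assert(a[1] == b[0] && "inner dimensions must match");
  return {a[0], b[1]};
}

int main() {
  auto out = InferMatMulShape({4, 128}, {128, 12});
  assert(out[0] == 4 && out[1] == 12);
}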
19 changes: 3 additions & 16 deletions onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc
@@ -39,12 +39,8 @@ static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, cons
     for (int64_t n = 0; n < N; n++) {
       float sum = 0.0f;
       for (int64_t k = 0; k < K; k++) {
-        float A_dequantized = has_zp ?
-                              (static_cast<int>(A_data[m * K + k]) - static_cast<int>(A_zero_point[0])) * A_scale[0] :
-                              A_data[m * K + k] * A_scale[0];
-        float B_dequantized = has_zp ?
-                              (static_cast<int>(B_data[k * N + n]) - static_cast<int>(B_zero_point[n])) * B_scale[n] :
-                              B_data[k * N + n] * B_scale[n];
+        float A_dequantized = has_zp ? (static_cast<int>(A_data[m * K + k]) - static_cast<int>(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0];
+        float B_dequantized = has_zp ? (static_cast<int>(B_data[k * N + n]) - static_cast<int>(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n];
 
         sum += A_dequantized * B_dequantized;
       }
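Reformatting aside, the reference math here dequantizes each element as (q - zero_point) * scale and then accumulates an ordinary float matmul: Y[m,n] = sum over k of (A[m,k] - za) * sa * (B[k,n] - zb[n]) * sb[n], with per-tensor quantization for A and per-column for B. A standalone version of the same reference kernel, assuming uint8 inputs and zero points always present (the test also covers int8 and the no-zero-point path):

// Standalone reference kernel mirroring CalculateMatMulIntegerToFloat.
#include <cstdint>
#include <vector>

std::vector<float> MatMulIntegerToFloatRef(
    const std::vector<uint8_t>& A, const std::vector<uint8_t>& B,
    float a_scale, uint8_t a_zp,
    const std::vector<float>& b_scales, const std::vector<uint8_t>& b_zps,
    int64_t M, int64_t N, int64_t K) {
  std::vector<float> Y(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float sum = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        // Dequantize, then accumulate in float.
        const float a = (static_cast<int>(A[m * K + k]) - a_zp) * a_scale;
        const float b = (static_cast<int>(B[k * N + n]) - b_zps[n]) * b_scales[n];
        sum += a * b;
      }
      Y[m * N + n] = sum;
    }
  }
  return Y;
}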
@@ -81,8 +77,7 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant,

   std::vector<WType> tmp_B_data;
   tmp_B_data = random.Uniform<WType>(B_dims,
-                                     (constexpr(std::is_same_v<WType, int8_t>)) ?
-                                         std::numeric_limits<int8_t>::lowest()/2 : std::numeric_limits<uint8_t>::lowest(),
+                                     (constexpr(std::is_same_v<WType, int8_t>)) ? std::numeric_limits<int8_t>::lowest() / 2 : std::numeric_limits<uint8_t>::lowest(),
                                      std::numeric_limits<WType>::max() / 2);
   std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType {
     return static_cast<WType>(v);
@@ -148,7 +143,6 @@ void TestMatMulIntegerToFloat(bool is_matrix_b_constant,
   } else {
     test.Run();
   }
-
 }
 
 template <typename IType, typename WType, typename OType, bool HasZeroPoint, bool HasBias>
@@ -161,13 +155,6 @@ void RunMatMulIntegerToFloatTest() {
   );
 
   TestMatMulIntegerToFloat<IType, WType, OType>(
-      A_dims,
-      B_dims,
-      model_path,
-      true,  /*is_matrix_b_constant*/
-      false, /*per_column*/
-      HasZeroPoint, /*has_zp*/
-      HasBias /*has_bias*/
       true,  /*is_matrix_b_constant*/
       false, /*per_column*/
       HasZeroPoint, /*has_zp*/
