Update DML EP to accept broadcasted tensor of size 1 to match CPU #19081

Merged 5 commits on Jan 11, 2024
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp
@@ -558,7 +558,9 @@
    {
        ML_CHECK_VALID_ARGUMENT(axis < outputShapeDimCount);
        uint32_t broadcastAxisLength = outputShape[axis];
-       ML_CHECK_VALID_ARGUMENT(inputTensorShape[0] == broadcastAxisLength);
+       ML_CHECK_VALID_ARGUMENT((inputTensorShape[0] == broadcastAxisLength) ||
+           // treat as broadcast dimension to match CPU behavior
+           (inputTensorShape[0] == 1));

Check warning on line 563 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp

GitHub Actions / cpplint: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp:563: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
inputTensorShape.insert(inputTensorShape.begin(), axis, 1);
inputTensorShape.insert(inputTensorShape.end(), outputShapeDimCount - 1 - axis, 1);
}
10 changes: 10 additions & 0 deletions onnxruntime/test/contrib_ops/quantize_ops_test.cc
@@ -76,6 +76,16 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int32_cpu) {
test.Run();
}

TEST(DequantizeLinearOpTest, DequantizeLinearOpTest_BroadcastTensorOfOne) {
OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);

test.AddInput<int32_t>("x", {4}, {-30, -3, 100, 127});
test.AddInput<float>("x_scale", {1}, {2.0f}, true);
test.AddInput<int32_t>("zero_point", {1}, {0}, true);
test.AddOutput<float>("y", {4}, {-60.f, -6.f, 200.f, 254.f});
test.Run();
}

#ifdef USE_CUDA
TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_half_uint8) {
OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain);
10 changes: 10 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
@@ -47,6 +47,16 @@ TEST(DequantizeLinearOpTest, Int32) {
test.Run();
}

TEST(DequantizeLinearOpTest_BroadcastTensor, Int32) {
OpTester test("DequantizeLinear", 13);
test.AddInput<int32_t>("x", {4}, {-30, -3, 100, 127});
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("x_scale", {1}, {2.0f});
test.AddInput<int32_t>("x_zero_point", {1}, {0});
test.AddOutput<float>("y", {4}, {-60.f, -6.f, 200.f, 254.f});
test.Run();
}

// 2d inputs
TEST(DequantizeLinearOpTest, 2D) {
OpTester test("DequantizeLinear", 10);