diff --git a/src/webnn/native/ops/Binary.cpp b/src/webnn/native/ops/Binary.cpp index 82bb8b785..b1e3ef360 100644 --- a/src/webnn/native/ops/Binary.cpp +++ b/src/webnn/native/ops/Binary.cpp @@ -54,17 +54,19 @@ namespace webnn::native::op { } outputShape = {1}; } - if (rankA == 2 && rankB == 1) { - if (inputShapeA[1] != inputShapeB[0]) { + if (rankA >= 2 && rankB == 1) { + if (inputShapeA[rankA - 1] != inputShapeB[0]) { return DAWN_VALIDATION_ERROR("The input shapes are incompatible."); } - outputShape = {inputShapeA[0], 1}; + outputShape = std::move(inputShapeA); + outputShape[rankA - 1] = 1; } - if (rankA == 1 && rankB == 2) { - if (inputShapeA[0] != inputShapeB[0]) { + if (rankA == 1 && rankB >= 2) { + if (inputShapeA[0] != inputShapeB[rankB - 2]) { return DAWN_VALIDATION_ERROR("The input shapes are incompatible."); } - outputShape = {1, inputShapeB[1]}; + outputShape = std::move(inputShapeB); + outputShape[rankB - 2] = 1; } if (rankA >= 2 && rankB >= 2) { if (inputShapeA[rankA - 1] != inputShapeB[rankB - 2]) {