Skip to content

Commit

Permalink
Fixed computed output shape of matmul op
Browse files Browse the repository at this point in the history
  • Loading branch information
mingmingtasd committed Nov 24, 2022
1 parent 756228c commit b57cde9
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/webnn/native/ops/Binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ namespace webnn::native::op {
}
outputShape = {1};
}
if (rankA == 2 && rankB == 1) {
if (inputShapeA[1] != inputShapeB[0]) {
if (rankA >= 2 && rankB == 1) {
if (inputShapeA[rankA - 1] != inputShapeB[0]) {
return DAWN_VALIDATION_ERROR("The input shapes are incompatible.");
}
outputShape = {inputShapeA[0], 1};
outputShape = std::move(inputShapeA);
outputShape[rankA - 1] = 1;
}
if (rankA == 1 && rankB == 2) {
if (inputShapeA[0] != inputShapeB[0]) {
if (rankA == 1 && rankB >= 2) {
if (inputShapeA[0] != inputShapeB[rankB - 2]) {
return DAWN_VALIDATION_ERROR("The input shapes are incompatible.");
}
outputShape = {1, inputShapeB[1]};
outputShape = std::move(inputShapeB);
outputShape[rankB - 2] = 1;
}
if (rankA >= 2 && rankB >= 2) {
if (inputShapeA[rankA - 1] != inputShapeB[rankB - 2]) {
Expand Down

0 comments on commit b57cde9

Please sign in to comment.