
Consider adopting new broadcasting rules #590

Closed
a-sully opened this issue Feb 28, 2024 · 2 comments · Fixed by #743

@a-sully
Contributor

a-sully commented Feb 28, 2024

Reviving a discussion from #534, which defined shape broadcasting but didn't touch on the question of what WebNN's shape broadcasting rules should be.

WebNN currently specifies two kinds of broadcasting rules: unidirectional and bidirectional.

Of the popular ML frameworks, ONNX (which WebNN is largely based on) appears to be an outlier in making a distinction between "unidirectional" and "multidirectional" broadcasting. This distinction is not made by:

The "unidirectional broadcastable" constraint of some ONNX ops (e.g. prelu()) requires workarounds when exporting from other formats to ONNX - like in this example of using TVM to export PyTorch to ONNX: pytorch/pytorch#70570 (comment).

What should we do?

Option 1: Adopt Numpy's broadcasting rules

Rationale: Numpy's broadcasting rules are a standard across the industry. It seems reasonable for them to be what we expose to the web.

Outcome: "bidirectional broadcasting" will be the only type of broadcasting exposed to the web. The user agent must ensure that the constraints of the underlying framework, such as unidirectional broadcasting for ONNX (@fdwr has suggested that this is trivial) and the lack of inferred broadcasting specifications for XLA (more on that below), are satisfied.
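For reference, the NumPy-style (bidirectional) shape rule fits in a few lines. A minimal Python sketch over plain shape tuples; the function name is made up for illustration and is not a WebNN or NumPy API:

```python
from itertools import zip_longest


def bidirectional_broadcast_shape(a, b):
    """Output shape of broadcasting shapes a and b under NumPy-style rules.

    Shapes are right-aligned; missing leading dimensions act as size 1, and
    each aligned pair of sizes must be equal or contain a 1.
    """
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
    return tuple(reversed(out))


print(bidirectional_broadcast_shape((2, 3, 1, 5), (1, 1, 4, 5)))  # (2, 3, 4, 5)
print(bidirectional_broadcast_shape((3, 4), (5, 1, 4)))            # (5, 3, 4)
```

Under Option 1 this would be the only shape rule the web sees; mapping it onto a backend with stricter rules would be the user agent's job.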

Option 2: Adopt XLA's broadcasting rules

Rationale: The XLA Principles apply to WebNN, too:

The XLA language is as strict and explicit as possible, avoiding implicit "magical" features. Such features might make some computations slightly easier to define, but at the cost of more assumptions baked into user code that will be difficult to change in the long term. If necessary, implicit magical features can be added in client-level wrappers... With regard to broadcasting, XLA requires explicit broadcasting specifications on operations between arrays of different ranks. This is different from NumPy, which infers the specification when possible.

Outcome: Both "unidirectional broadcasting" and "bidirectional broadcasting" concepts would be removed from the WebNN spec. To facilitate explicit broadcasts, something like StableHLO's broadcast_in_dim op would need to be added to WebNN.
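To make Option 2 concrete: with an explicit broadcast, the caller states which result dimension each operand dimension lands on, and nothing is inferred from rank. A minimal Python sketch of the shape check and index mapping, assuming StableHLO-like semantics (one unique target dimension per operand dimension, each of size 1 or an exact match). The function names are hypothetical, and details such as any ordering requirement on broadcast_dimensions are not modeled:

```python
def check_broadcast_in_dim(operand_shape, result_shape, broadcast_dimensions):
    """Validate an explicit broadcast in the style of StableHLO's broadcast_in_dim.

    broadcast_dimensions[i] names the result dimension that operand dimension i
    maps to; nothing is inferred from rank or alignment.
    """
    if len(broadcast_dimensions) != len(operand_shape):
        raise ValueError("need exactly one mapping per operand dimension")
    if len(set(broadcast_dimensions)) != len(broadcast_dimensions):
        raise ValueError("broadcast_dimensions must be unique")
    for i, d in enumerate(broadcast_dimensions):
        if not 0 <= d < len(result_shape):
            raise ValueError(f"result dimension {d} is out of range")
        if operand_shape[i] not in (1, result_shape[d]):
            raise ValueError(
                f"operand dim {i} (size {operand_shape[i]}) cannot map to "
                f"result dim {d} (size {result_shape[d]})"
            )


def operand_index(result_index, operand_shape, broadcast_dimensions):
    """Index of the operand element that feeds a given result element."""
    return tuple(
        0 if operand_shape[i] == 1 else result_index[d]
        for i, d in enumerate(broadcast_dimensions)
    )


# The same [3] operand becomes the columns or the rows of a [3, 3] result,
# depending only on the explicit mapping.
check_broadcast_in_dim((3,), (3, 3), (0,))
check_broadcast_in_dim((3,), (3, 3), (1,))
print(operand_index((2, 0), (3,), (0,)))  # (2,)
print(operand_index((2, 0), (3,), (1,)))  # (0,)
```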

Option 3: Keep the status quo

Rationale: It's the status quo

Outcome: No action needed regarding the current spec. However, all models ported to WebNN will need to abide by this "unidirectionally broadcastable" constraint, which is specific to ONNX.

@fdwr
Collaborator

fdwr commented Mar 5, 2024

Yo Austin - if the spec is unclear, then yeah, it should be made so. Before my thoughts, let's first break down broadcasting into its three parts:

  • (1) rank alignment, which is a reshape of multiple inputs to the rank of the biggest input, filling in the new shape with 1's for the newly inserted dimensions. These are typically inserted on the leading edge, with existing dimensions right-aligned (e.g. 2D to 4D, [3,4] -> [1,1,3,4]), but more generically you can find rare occurrences of left alignment and even middle axis alignment, such as with instanceNormalization's scale and bias in the decomposition algorithm (e.g. [7] -> [1,7,1,1]). For elementwise ops though (add, mul, div...), right-aligned ranks are the norm.
  • (2) new output size computation, which takes the maximum size along each corresponding dimension, either unidirectionally or bidirectionally (e.g. [2,3,1,5] and [1,1,4,5] -> [2,3,4,5]).
  • (3) dimension expansion (à la expand), which is the real work, a repetition of elements across the new dimension sizes.
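The three steps can be sketched in Python on flat row-major lists. This is a minimal illustration, not any backend's implementation; `align_rank`, `output_shape` and `expand` are hypothetical helpers (the last one only loosely mirrors WebNN's expand, and assumes the ranks already match):

```python
from itertools import product


def align_rank(shape, rank, axes=None):
    """Step 1: insert size-1 dims so that `shape` has `rank` dims.

    By default existing dims are right-aligned. `axes` (a hypothetical
    parameter) places each existing dim at an explicit output axis, which
    covers left and middle alignment too.
    """
    axes = list(range(rank - len(shape), rank)) if axes is None else list(axes)
    if len(axes) != len(shape):
        raise ValueError("need one axis per existing dimension")
    out = [1] * rank
    for size, axis in zip(shape, axes):
        out[axis] = size
    return out


def output_shape(*shapes):
    """Step 2: for each dim, the single non-1 size (or 1); assumes equal ranks."""
    out = []
    for sizes in zip(*shapes):
        non_one = {s for s in sizes if s != 1}
        if len(non_one) > 1:
            raise ValueError(f"conflicting sizes {sorted(non_one)}")
        out.append(non_one.pop() if non_one else 1)
    return out


def expand(data, shape, new_shape):
    """Step 3: repeat the row-major `data` of `shape` out to `new_shape`.

    Size-1 dims get stride 0, so the same element is re-read along them.
    """
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(0 if size == 1 else step)
        step *= size
    strides.reverse()
    return [
        data[sum(i * s for i, s in zip(index, strides))]
        for index in product(*(range(n) for n in new_shape))
    ]


# add([1, 2, 3], [[10], [20]]) broken into the three steps.
a_shape = align_rank([3], 2)          # [1, 3]
b_shape = align_rank([2, 1], 2)       # [2, 1]
out = output_shape(a_shape, b_shape)  # [2, 3]
a = expand([1, 2, 3], a_shape, out)
b = expand([10, 20], b_shape, out)
print([x + y for x, y in zip(a, b)])  # [11, 12, 13, 21, 22, 23]
print(align_rank([7], 4, axes=[1]))   # [1, 7, 1, 1]
```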

For XLA, step 1 does not happen because it expects the ranks to already match. Step 2 uses bidirectional broadcasting for the elementwise operators, and XLA's BroadcastInDim uses unidirectional broadcasting from the input shape to the expected output shape, even if they don't say it by name.

For NumPy, all 3 steps happen: step 1 is right-aligned (though there are likely cases of middle-aligned broadcasts too given axes, at least internally for processing), and step 2 is bidirectional, except in the case of its own broadcasting operator broadcast_to, which uses unidirectional broadcasting even if they don't say it by name (e.g. numpy.broadcast_to([1, 2, 3], (3, 3)) works, while numpy.broadcast_to([1, 2, 3], (3, 1)) fails because the input shape is not unidirectionally broadcastable to the output shape).
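Side by side, the two rules differ only in which operand is allowed to supply the size-1 dimension. A minimal Python sketch (hypothetical helper names, shapes as tuples) reproducing the broadcast_to example above:

```python
def is_unidirectionally_broadcastable(src, dst):
    """Can `src` be broadcast to exactly `dst` (the numpy.broadcast_to rule)?

    Right-align the shapes; every src dim must be 1 or equal to the matching
    dst dim, and src may not have more dims than dst. The result is always dst.
    """
    if len(src) > len(dst):
        return False
    return all(s == 1 or s == d for s, d in zip(reversed(src), reversed(dst)))


def is_bidirectionally_broadcastable(a, b):
    """Either side may supply the 1, so the result can be bigger than both."""
    return all(x == y or 1 in (x, y) for x, y in zip(reversed(a), reversed(b)))


print(is_unidirectionally_broadcastable((3,), (3, 3)))  # True
print(is_unidirectionally_broadcastable((3,), (3, 1)))  # False
print(is_bidirectionally_broadcastable((3,), (3, 1)))   # True, giving (3, 3)
```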

And my thinking so far:

  • broadcasting should be kept in WebNN because of the significant perf benefits of not needing to expand to large temporaries first (e.g. mul(a3RowVector, b3ColVector) rather than mul(expand(reshape(a3RowVector, [3,1]), [3,3]), expand(reshape(b3ColVector, [1,3]), [3,3]))).
  • any WebNN operators that use broadcasting should be clear that they do, rather than something that implicitly happens for any operator (which is already the case AFAICS in WebNN). For WebNN, broadcasting is generally limited to the elementwise operators (add, prelu, where, div...).
  • I'm receptive to additional WebNN API simplifications/constraints such as requiring the ranks to already be resolved by the caller before reaching WebNN (so two input tensors into add with different ranks would need to be trivially reshaped to the same rank first by the client). Note at least a few backends already require consistent ranks on inputs, including at least a few XNNPack operators and nearly all DirectML ops; and if Apple BNNS and MPS also have this consistent-same-rank requirement, then it could ease backend adoption (plus reduce testing complexity and spec'ese) if the caller took care of the right-alignment/left-alignment/middle-alignment rank coercion before calling WebNN. All backends should treat reshapes as light adjustments of the tensor description without copies of the actual tensor data, making reshapes basically free. Though, here's another case of balancing backend complexity vs front-end caller complexity ⚖. I'd like more info on BNNS and MPS behavior first (maybe Phillis knows?).
  • currently unidirectional broadcasting is used by only a few operators anyway (expand, GEMM's C tensor, and prelu):
    • For parameterized rectified linear unit, it's reasonable to change it from unidirectional broadcasting to bidirectional broadcasting, to ease porting from PyTorch (and really I was quite surprised to see that, because I naturally thought prelu already was bidirectionally broadcast). Review welcome 😉.
    • For expand, it's functionally very similar to XLA's BroadcastInDim ("...expanding existing dimensions with size 1...The dimensions of operand must have size 1 or be the same size as the dimension in the output shape they are mapped to"), except that BroadcastInDim combines both a reshape and an expand into a single operator (BroadcastInDim -> expand(reshape(input, coercedRankShape), expandedShape)). Btw, we used to have more reshape-family operators proposed for WebNN (squeeze, unsqueeze, flattenTo2D), until realizing that (a) they were all just little variations of reshape which the client can resolve as higher layer policy, (b) there may be other reshaping variants we don't even know about, and (c) increasing the API surface here didn't actually bring any hardware benefit, because WebNN's real benefit is about the accelerated backends (the "real" operators, more so than just fiddling with dimensions in a tensor description).
    • For GEMM's C tensor, bidirectional broadcasting wouldn't make any sense, and broadcasting for this common case brings performance benefits of avoiding the sizeable intermediate.
    • Thus, I still think unidirectional broadcasting as a concept is useful, even if only WebNN and ONNX have it for a few operators, whereas NumPy and XLA have it as an unnamed concept for a few special operators.
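For the GEMM case, a minimal pure-Python sketch of why unidirectional is the natural rule for C: the output shape comes from A and B alone, and C is read through zero strides rather than expanded into an M x N temporary. The `gemm` helper below is hypothetical (nested lists, no alpha/beta or transposes) and is not WebNN's API:

```python
def gemm(a, b, c, c_shape):
    """Return a @ b + c for nested-list matrices a (M x K) and b (K x N).

    c is a flat row-major list whose shape c_shape must be unidirectionally
    broadcastable to [M, N]. The output is always [M, N]: A and B fix the
    output shape, so C is never allowed to grow it.
    """
    m, k, n = len(a), len(b), len(b[0])
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions of a and b differ")
    # Right-align c_shape to rank 2, then check it against [M, N].
    shape = [1] * (2 - len(c_shape)) + list(c_shape)
    if len(shape) != 2 or any(s not in (1, t) for s, t in zip(shape, (m, n))):
        raise ValueError(f"C shape {c_shape} cannot broadcast to {[m, n]}")
    # Row-major strides, with 0 on broadcast dims so C is never expanded.
    row_stride = shape[1] if shape[0] != 1 else 0
    col_stride = 1 if shape[1] != 1 else 0
    return [
        [
            sum(a[i][p] * b[p][j] for p in range(k))
            + c[i * row_stride + j * col_stride]
            for j in range(n)
        ]
        for i in range(m)
    ]


a = [[1, 2], [3, 4]]
identity = [[1, 0], [0, 1]]
print(gemm(a, identity, [10, 20], (2,)))    # [[11, 22], [13, 24]]
print(gemm(a, identity, [10, 20], (2, 1)))  # [[11, 12], [23, 24]]
```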

So, I'm partial to an option 4:

Option 4

  • adjust prelu to bidirectional
  • keep unidirectional for rare cases (expand and GEMM)
  • research more backends to potentially add restrictions that inputs must have the same rank.

@huningxin?

@a-sully
Contributor Author

a-sully commented Apr 4, 2024

Thank you @fdwr for the very thorough response! I think your Option 4 proposal makes sense, with one addendum

My primary motivations for filing this issue were:

  • bring extra scrutiny to our uses of (especially unidirectional) broadcasting, as well as where we currently implicitly broadcast and should make this broadcasting explicit, and
  • open a discussion about what is expected to be handled by WebNN's implementation vs by its callers

It seems that I've been successful in that latter point :)

here's another case of balancing backend complexity vs front-end caller complexity ⚖

In the spirit of https://github.com/extensibleweb/manifesto I'm generally in favor of pushing complexity to callers (e.g. "This leads to better performance with less implementation effort"). In this case, I didn't expect that we'd actually adopt XLA's broadcasting rules for WebNN, though I figured it was worth calling it out as the option furthest toward the "caller complexity" end of things :P


As for the follow-up question... Regarding:

any WebNN operators that use broadcasting should be clear that they do, rather than something that implicitly happens for any operator

I agree! Related to that:

you can find rare occurrences of left alignment and even middle axis alignment, such as with instanceNormalization's scale and bias in the decomposition algorithm (e.g. [7] -> [1,7,1,1])
and then:

Is this middle axis alignment perhaps only relevant when using NCHW layout? If we were using NHWC layout, would [7] broadcast to [1, 1, 1, 7]?
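A quick sketch of the two cases, assuming the per-channel parameter is always aligned with the channel axis of the given layout (the helper name and the layout strings are only for illustration):

```python
def align_per_channel(param_shape, rank, layout):
    """Hypothetical helper: target shape for a 1-D per-channel parameter
    (e.g. instanceNormalization's scale or bias) on an input of `rank` dims."""
    channel_axis = {"nchw": 1, "nhwc": rank - 1}[layout]
    out = [1] * rank
    out[channel_axis] = param_shape[0]
    return out


print(align_per_channel([7], 4, "nchw"))  # [1, 7, 1, 1]  (middle alignment)
print(align_per_channel([7], 4, "nhwc"))  # [1, 1, 1, 7]  (same as right alignment)
```

If that assumption holds, the NHWC case is just ordinary right alignment, and the middle alignment only comes from channels-first layouts.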

Regardless, the spec of instanceNormalization doesn't say anything about broadcasting. Let's add a fourth action item to Option 4?

  • specify broadcasts for all operators which currently implicitly broadcast
