conv fuse #288
base: master
Conversation
@@ -559,6 +561,7 @@ class GraphImpl implements Graph, Graph.Transformer {
     // apply common transform
     this.removeAllIdentityNodes();
     this.removeAllDropoutNodes();
+    this.fuseConvActivationNodes();
If we apply fusion here, wouldn't it fail on all backends except webgl?
I'm a little surprised that the BrowserStack CI passed. Did I miss anything?
That's my guess too. I submitted this PR to check how the CI goes. I'll look into the results tomorrow.
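For context, one way to keep the common transform path backend-agnostic would be to gate the webgl-only fusion pass on the active backend. This is only an illustrative sketch; the `backendHint` parameter and the wiring below are assumptions, not the actual onnxjs API.

```typescript
// Hypothetical sketch: restrict the webgl-only fusion pass to the webgl
// backend while keeping the common transforms backend-agnostic. The
// backendHint parameter and this wiring are illustrative assumptions.
interface FusableGraph {
  removeAllIdentityNodes(): void;
  removeAllDropoutNodes(): void;
  fuseConvActivationNodes(): void;
}

function applyTransforms(graph: FusableGraph, backendHint: string): void {
  // These transforms are safe for every backend.
  graph.removeAllIdentityNodes();
  graph.removeAllDropoutNodes();

  // Only the webgl backend knows how to execute a fused Conv+activation node,
  // so the fusion pass is gated on the backend in this sketch.
  if (backendHint === 'webgl') {
    graph.fuseConvActivationNodes();
  }
}
```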
    }
  }

+  fuseConvActivationNodes() {
Technically, it's not enough to traverse the graph only once. Consider a graph with pattern Conv+Relu+Sigmoid. After one iteration, it's transformed to Conv+Sigmoid (as Relu is fused). Ideally, it should further fuse Conv and Sigmoid. We may want to keep running the transformer until nothing can be transformed. It's fine to leave it for future work as we don't have immediate requirements for that.
Theoretically you are right: we need to keep checking until no further fusing can take place. In practice, I am not sure a model would chain multiple activations in a row; the following activations seem redundant once the signal is already 'activated' by the first one. Maybe there is a use case, but I guess it shouldn't be very common.
The implementation wouldn't be hard. We would just need to insert several internal attributes, one per fused activation function, and loop through them in conv's shader generation. For the sake of simplicity, I'll keep it as it is for now, and if we do see a need in the future, we can add this logic in.
You're right that the example I used above was artificial. What I suggested is the standard approach for graph-level optimization, which is also used by ORT. In other common situations, like MatMul+Add+Add (=> Gemm+Add => Gemm), this approach is needed. But our current fusion transformer covers only two activations, so we don't have to do it now.
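To illustrate the fixpoint approach discussed above, here is a minimal sketch of running a set of passes until a full sweep changes nothing. The pass names and the boolean "did anything change" return convention are assumptions for illustration, not the actual onnxjs transformer interface.

```typescript
// Hypothetical sketch of a fixpoint transformation loop: re-run every pass
// until a full sweep makes no change. The pass names and boolean return
// convention are assumptions for illustration.
interface TransformableGraph {
  fuseConvActivationNodes(): boolean;
  fuseMatMulAddToGemm(): boolean;
}

type TransformPass = (graph: TransformableGraph) => boolean;

function transformUntilFixpoint(
    graph: TransformableGraph, passes: TransformPass[], maxSweeps = 10): void {
  for (let sweep = 0; sweep < maxSweeps; sweep++) {
    let changed = false;
    for (const pass of passes) {
      // Keep sweeping as long as at least one pass rewrote the graph,
      // so chains like MatMul+Add+Add collapse step by step.
      changed = pass(graph) || changed;
    }
    if (!changed) {
      return;  // fixpoint reached: no pass can transform the graph further
    }
  }
}

// Usage sketch:
// transformUntilFixpoint(graph, [g => g.fuseConvActivationNodes(), g => g.fuseMatMulAddToGemm()]);
```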
Force-pushed from fc4e05c to ff67621
@@ -321,16 +324,20 @@ export class WebGLUnpackedConv extends Conv {
     const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
     const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
     const sharedDim = im2colLayout.shape[3];
-    const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
+    const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
Just for my education, why can't blend co-exist with fusion?
Is there any indication of the perf impact of disabling blend while enabling activation?
This PR fuses Conv with the activation functions that follow it in the graph.
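As a rough illustration of what this fusion means at the shader level, the activation can be applied to the accumulated convolution value before the output is written. The snippet table, `buildFusedConvShaderBody`, and `computeConvSum` below are illustrative assumptions, not the actual WebGLUnpackedConv code.

```typescript
// Hypothetical sketch: append the fused activation to the convolution
// result inside the generated GLSL. The snippet table and helper names
// are illustrative assumptions, not the actual onnxjs shader generator.
const activationGlslSnippets: {[name: string]: string} = {
  Relu: 'value = max(value, 0.0);',
  Sigmoid: 'value = 1.0 / (1.0 + exp(-value));',
};

function buildFusedConvShaderBody(activation?: string): string {
  // Fall back to a no-op when no activation was fused into this Conv node.
  const activationCode =
      (activation && activationGlslSnippets[activation]) || '';
  return `
    float value = computeConvSum();  // placeholder for the real im2col dot product
    ${activationCode}
    gl_FragColor = vec4(value);
  `;
}
```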