Skip to content

Commit

Permalink
Introduce new optimizer MatMul + BatchNormalization (#17915)
Browse files Browse the repository at this point in the history
Introduce new ORT L1 optimizer under RewriteRule category to fuse MatMul
+ BatchNormalization node. This optimizer look for a specific pattern
observed in one of the impacting customer models and fuse the Matmul and
Batchnormalization node into a Gemm node. For details on the pattern
matching and fusion please refer to the comment section of
`matmul_bn_fusion.cc`.

To visualize, this optimizer will replace following subgraph to a Gemm
node.
<pre>
               MatMul                  GEMM
                 |                       |
              Reshape ^     --->      Reshape ^
                 |                       |
            Transpose ^             Transpose ^
                 |
       BatchNormalization
Note: ^ means there can be >=0 occurrence(s) of that node.
Few example fusable pattern:
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM ->
Reshape -> Transpose
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM ->
Reshape -> Reshape
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization --->
GEMM -> Reshape -> Transpose -> Reshape
* - MatMul -> BatchNormalization ---> GEMM
</pre>

Note: This optimizer may evolve in the future to be more generic in
terms of the pattern matching.

- Why is this change required? What problem does it solve?
One of the user of ORT+DML ep needs this to better target the model to
DML. But this transformation applies more broadly, so added L1
optimizer.
<!-- - If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
sumitsays authored and raoanag committed Nov 9, 2023
1 parent b2768bb commit caa0cc1
Showing 0 changed files with 0 additions and 0 deletions.

0 comments on commit caa0cc1

Please sign in to comment.