forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add fusions for OpenAI CLIP (microsoft#20721)
### Description

This PR adds fusions for [OpenAI's CLIP model](https://huggingface.co/openai/clip-vit-large-patch14-336). Here is an example of how to run the ORT transformer optimizer for the linked CLIP model.

```
$ git clone https://github.com/microsoft/onnxruntime
$ cd onnxruntime/onnxruntime/python/tools/transformers
$ python3 optimizer.py --input /path/to/model.onnx --output /path/to/model_opt.onnx --model_type clip --num_heads 16 --hidden_size 1024 --use_external_data_format --opt_level 0
```

### Motivation and Context

This PR helps optimize multi-modal models that use CLIP for the vision encoder.
- Loading branch information
1 parent
5d07291
commit ca22a5a
Showing
6 changed files
with
167 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
|
||
import logging | ||
|
||
from fusion_base import Fusion | ||
from onnx import helper | ||
from onnx_model import OnnxModel | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FusionQuickGelu(Fusion):
    """Fuse the ``x * Sigmoid(alpha * x)`` subgraph into one ``QuickGelu`` node.

    The pass is anchored on ``Mul`` nodes. When the pattern below is found, the
    two ``Mul`` nodes and the ``Sigmoid`` node are scheduled for removal and a
    single ``QuickGelu`` node in the ``com.microsoft`` domain is scheduled for
    insertion in their place.
    """

    # Multiplier expected on the first Mul (the "B = ~1.702" in the diagram
    # below) and the absolute tolerance accepted when matching it.
    _EXPECTED_ALPHA = 1.7021484375
    _ALPHA_TOLERANCE = 1e-3

    def __init__(self, model: OnnxModel):
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the subgraph that ends at ``node`` into ``QuickGelu``.

        Args:
            node: Candidate node; it must be the final ``Mul`` of the pattern.
            input_name_to_nodes: Mapping consulted to find the consumers of an
                intermediate tensor. Assumed to map a tensor name to the list
                of nodes that read it -- TODO confirm against the caller.
            output_name_to_node: Unused by this fusion; part of the common
                ``fuse`` signature.

        Returns:
            None. On success the fusion is recorded on ``self``
            (``nodes_to_remove``, ``nodes_to_add``, ``node_name_to_graph_name``
            and the fusion counter); on any mismatch the graph is left as is.
        """
        # Fuse the following subgraph to `QuickGelu`
        #
        #       root_input
        #       /        \
        #      |         Mul ----+
        #      |    (B = ~1.702) |
        #       \         |      |
        #        \     Sigmoid   |---- `QuickGelu`
        #         \      /       |
        #          \    /        |
        #            Mul     ----+
        #             |
        #        root_output

        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        # The multiplier must be a constant. A missing value (None) or a
        # non-scalar tensor means this is not the QuickGelu pattern, so skip
        # the match rather than crash the whole optimization pass.
        alpha_tensor = self.model.get_constant_value(first_mul_node.input[1])
        if alpha_tensor is None:
            logger.debug("fuse_quickgelu: first Mul node's second input is not a constant")
            return
        try:
            # float() guarantees the `alpha` attribute is emitted as a float
            # even when the constant is stored with an integer dtype.
            approximation_value = float(alpha_tensor.item())
        except (TypeError, ValueError):
            logger.debug("fuse_quickgelu: approximation constant is not a scalar number")
            return

        if abs(approximation_value - self._EXPECTED_ALPHA) >= self._ALPHA_TOLERANCE:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        # The intermediate nodes are about to be removed, so their outputs must
        # not feed anything outside the matched subgraph.
        for producer in (first_mul_node, sigmoid_node):
            consumers = input_name_to_nodes.get(producer.output[0], [])
            if len(consumers) != 1:
                logger.debug("fuse_quickgelu: intermediate output is used outside the subgraph")
                return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters