From e6cc581ae0734806320b5d99b3e7a3fb5f7116a5 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Sun, 29 Oct 2023 04:13:35 -0700
Subject: [PATCH 1/2] Change concat for Expand in attention mask reshaping

---
 .../tools/transformers/fusion_attention.py    | 32 ++++++++++++-------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index c1b241aa1a5ec..5ea9338385e55 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -224,23 +224,31 @@ def reshape_add_qk(self, add_qk: str):
         # B = batch size, N = num heads, S = source sequence length, T = target sequence length
         mask_output_name = add_qk + "_mask"
 
-        # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists
-        concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
-        if len(concat_node) == 1:
+        # Check if expand node for (B,1,S,T) --> (B,N,S,T) already exists
+        expand_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add))
+        if len(expand_node) == 1:
             return mask_output_name
-        assert len(concat_node) == 0
-        concat_node_name = self.model.create_node_name("Concat")
-        concat_add_qk_fp32 = helper.make_node(
-            "Concat",
-            inputs=[add_qk for _ in range(self.num_heads)],
+        assert len(expand_node) == 0
+        expand_node_name = self.model.create_node_name("Expand")
+
+        expand_add_qk_shape = self.add_initializer(
+            name="expand_add_qk_shape",
+            data_type=TensorProto.INT64,
+            dims=[4],
+            vals=[1, self.num_heads, 1, 1],
+            raw=False,
+        )
+
+        expand_add_qk_fp32 = helper.make_node(
+            "Expand",
+            inputs=[add_qk, expand_add_qk_shape.name],
             outputs=[mask_output_name],
-            name=concat_node_name,
-            axis=1,
+            name=expand_node_name,
         )
 
         # Add new node to graph
-        self.nodes_to_add.append(concat_add_qk_fp32)
-        self.node_name_to_graph_name[concat_node_name] = self.this_graph_name
+        self.nodes_to_add.append(expand_add_qk_fp32)
+        self.node_name_to_graph_name[expand_node_name] = self.this_graph_name
 
         return mask_output_name
 

From 164d3a8010cbbb9fd6c23784252eda138d89c6d6 Mon Sep 17 00:00:00 2001
From: Patrice Vignola
Date: Sun, 29 Oct 2023 11:24:24 -0700
Subject: [PATCH 2/2] Make initializer name unique

---
 onnxruntime/python/tools/transformers/fusion_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index 5ea9338385e55..6e988deb8167d 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -233,7 +233,7 @@ def reshape_add_qk(self, add_qk: str):
         expand_node_name = self.model.create_node_name("Expand")
 
         expand_add_qk_shape = self.add_initializer(
-            name="expand_add_qk_shape",
+            name=mask_output_name + "_shape",
             data_type=TensorProto.INT64,
             dims=[4],
             vals=[1, self.num_heads, 1, 1],