Fix weight tensors in transformers optimizer not saved to external data (#17427)

Some initializers are added without the raw=True flag, so those tensors
cannot be saved to external data. If those tensors exceed 2GB in total,
the optimized model cannot be saved due to the protobuf size limit.

This change saves the attention weights and biases as raw data.

Note: using raw data for shape tensors is optional since they are tiny.
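
For context, here is a minimal sketch (not part of this commit) of the distinction the fix relies on: `onnx.helper.make_tensor` stores values in typed fields such as `float_data` unless `raw=True` is passed, and ONNX's external-data conversion (`onnx.external_data_helper.convert_model_to_external_data`) only offloads tensors whose bytes live in `raw_data`. The `self.add_initializer(..., raw=...)` helper used in the diff below presumably wraps this choice. Tensor names and shapes in the sketch are made-up examples.

```python
# Minimal sketch, assuming standard onnx APIs; names and shapes are illustrative.
import numpy as np
from onnx import TensorProto, helper

weights = np.zeros((4, 4), dtype=np.float32)

# Typed storage: values go into the float_data field and stay inside the
# model protobuf, which counts toward the 2GB protobuf size limit.
typed_tensor = helper.make_tensor(
    name="w_typed",
    data_type=TensorProto.FLOAT,
    dims=list(weights.shape),
    vals=weights.flatten().tolist(),
)

# Raw storage: values go into the raw_data field, which external-data saving
# (onnx.external_data_helper.convert_model_to_external_data) can move to a
# side file instead of keeping it in the protobuf.
raw_tensor = helper.make_tensor(
    name="w_raw",
    data_type=TensorProto.FLOAT,
    dims=list(weights.shape),
    vals=weights.tobytes(),
    raw=True,
)

print(len(typed_tensor.float_data), len(typed_tensor.raw_data))  # 16 0
print(len(raw_tensor.float_data), len(raw_tensor.raw_data))      # 0 64
```

An attention weight created from typed `vals` therefore stays inside the protobuf even when external-data output is requested, which is how the optimized model can exceed the 2GB limit.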

### Motivation and Context
#17212
#15349
tianleiwu authored Sep 6, 2023
1 parent 2629cb8 commit e8b8d0d
Showing 18 changed files with 161 additions and 159 deletions.
79 changes: 24 additions & 55 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -78,14 +78,7 @@ def process_mask(self, input: str) -> str:
         # ReduceSum-13: axes is moved from attribute to input
         axes_name = "ort_const_1_reduce_sum_axes"
         if self.model.get_initializer(axes_name) is None:
-            self.model.add_initializer(
-                helper.make_tensor(
-                    name=axes_name,
-                    data_type=TensorProto.INT64,
-                    dims=[1],
-                    vals=[1],
-                )
-            )
+            self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False)
         mask_index_node = helper.make_node(
             "ReduceSum",
             inputs=[input_name, axes_name],
@@ -428,19 +421,12 @@ def create_combined_qkv_bias(
         qkv_bias_dim = 3 * np.prod(qb.shape)

         bias_name = name_prefix + "_qkv_bias"
-        bias = helper.make_tensor(
+        self.add_initializer(
             name=bias_name,
-            data_type=TensorProto.FLOAT,
+            data_type=q_bias.data_type,
             dims=[qkv_bias_dim],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias,
         )

-        # Convert bias to FP16 if model is using FP16
-        if q_bias.data_type == 10:
-            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
-
-        self.model.add_initializer(bias, self.this_graph_name)
-
         return bias_name

     def create_packed_qkv_matmul_node(
@@ -488,13 +474,13 @@ def create_packed_qkv_matmul_node(

         qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
         qkv_weight_name = matmul_node_name + "_qkv_weight"
-        weight = helper.make_tensor(
+
+        self.add_initializer(
             name=qkv_weight_name,
-            data_type=TensorProto.FLOAT,
+            data_type=q_weight.data_type,
             dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
-            vals=qkv_weight.flatten().tolist(),
+            vals=qkv_weight,
         )
-        self.model.add_initializer(weight, self.this_graph_name)

         # Created packed QKV MatMul with output (B, S, 3*D)
         # Output is of the form:
@@ -519,23 +505,15 @@

         # Create Slice nodes to access Q, K, V
         q_slice_name = matmul_node_name + "_q_start_index"
-        q_start_tensor = helper.make_tensor(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0])
+        self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
         k_slice_name = matmul_node_name + "_k_start_index"
-        k_start_tensor = helper.make_tensor(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d])
+        self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
         v_slice_name = matmul_node_name + "_v_start_index"
-        v_start_tensor = helper.make_tensor(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d])
+        self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
         end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
-        end_of_qkv_tensor = helper.make_tensor(
-            name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d]
-        )
+        self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
         qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
-        qkv_axis_tensor = helper.make_tensor(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1])
-
-        self.model.add_initializer(q_start_tensor, self.this_graph_name)
-        self.model.add_initializer(k_start_tensor, self.this_graph_name)
-        self.model.add_initializer(v_start_tensor, self.this_graph_name)
-        self.model.add_initializer(end_of_qkv_tensor, self.this_graph_name)
-        self.model.add_initializer(qkv_axis_tensor, self.this_graph_name)
+        self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)

         q_slice_output = matmul_node_name + "_q_out"
         q_slice = helper.make_node(
@@ -823,7 +801,6 @@ def create_attention_node(
         assert q_bias_shape == k_bias_shape == qw_out_size
         assert v_bias_shape == vw_out_size

-        qkv_bias_dim = 0
         if is_qkv_diff_dims:
             qkv_bias = np.concatenate((qb, kb, vb), axis=0)
             qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
@@ -834,29 +811,20 @@
         attention_node_name = self.model.create_node_name("Attention")

         if not self.use_multi_head_attention:
-            weight = helper.make_tensor(
+            self.add_initializer(
                 name=attention_node_name + "_qkv_weight",
-                data_type=TensorProto.FLOAT,
+                data_type=q_weight.data_type,
                 dims=[qw_in_size, qkv_weight_dim],
-                vals=qkv_weight.flatten().tolist(),
+                vals=qkv_weight,
             )
-
-            # Sometimes weights and bias are stored in fp16
-            if q_weight.data_type == 10:
-                weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
-            self.model.add_initializer(weight, self.this_graph_name)

-        bias = None
         if has_bias:
-            bias = helper.make_tensor(
+            self.add_initializer(
                 name=attention_node_name + "_qkv_bias",
-                data_type=TensorProto.FLOAT,
+                data_type=q_bias.data_type,
                 dims=[qkv_bias_dim],
-                vals=qkv_bias.flatten().tolist(),
+                vals=qkv_bias,
             )
-            if q_bias.data_type == 10:
-                bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
-            self.model.add_initializer(bias, self.this_graph_name)

         # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
         if self.use_multi_head_attention:
@@ -1198,14 +1166,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         if einsum_node is not None:
             unique_index = einsum_node.input[0]
             new_edge = "edge_modified_" + unique_index
-            shape_tensor = helper.make_tensor(
+
+            shape_tensor = self.add_initializer(
                 name="shape_modified_tensor" + unique_index,
                 data_type=TensorProto.INT64,
                 dims=[4],
-                vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]).tobytes(),
-                raw=True,
+                vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]),
+                raw=False,
             )
-            self.model.add_initializer(shape_tensor, self.this_graph_name)

             self.model.add_node(
                 helper.make_node(
                     "Reshape",
