cuda layernorm support broadcast
tianleiwu committed Dec 15, 2024
1 parent c7317cb commit 2f5b9b9
Showing 9 changed files with 121 additions and 129 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // broadcast stride for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
33 changes: 25 additions & 8 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -44,19 +44,36 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;

int broadcast = 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
// Handle a special case for MMDiT where scale and bias need broadcast.
// X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as the broadcast stride.
if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
scale->Shape().NumDimensions() == x_num_dims &&
scale->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
scale->Shape().GetDims()[1] == 1 &&
scale->Shape().GetDims()[2] == x_shape.GetDims()[2] &&
bias->Shape().NumDimensions() == x_num_dims &&
bias->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
bias->Shape().GetDims()[1] == 1 &&
bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) {
broadcast = static_cast<int>(x_shape.GetDims()[1]);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
}

// Outputs
@@ -65,7 +82,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
@@ -94,7 +111,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
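Note: the broadcast branch above only triggers when X is rank 3, the normalization axis is the last dimension, and scale/bias both have shape (B, 1, D) matching X's batch and hidden sizes; the stored broadcast value is then the sequence length S. A simplified NumPy sketch of that eligibility rule and of the reference semantics (the helper names below are illustrative, not part of ONNX Runtime):

```python
import numpy as np

def layernorm_broadcast_stride(x_shape, scale_shape, bias_shape=None):
    # Simplified mirror of the eligibility check in layer_norm.cc:
    # returns S when x is (B, S, D) and scale/bias are (B, 1, D), else 0 (no broadcast).
    if len(x_shape) == 3 and len(scale_shape) == 3:
        b, s, d = x_shape
        shapes_ok = tuple(scale_shape) == (b, 1, d) and (
            bias_shape is None or tuple(bias_shape) == (b, 1, d)
        )
        if shapes_ok and d > 1:
            return s
    return 0

def layernorm_reference(x, gamma, beta, eps=1e-5):
    # Reference LayerNormalization over the last axis; NumPy broadcasting
    # handles gamma/beta of shape (B, 1, D) automatically.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

B, S, D = 2, 4, 8
x = np.random.randn(B, S, D).astype(np.float32)
gamma = np.random.randn(B, 1, D).astype(np.float32)
beta = np.random.randn(B, 1, D).astype(np.float32)
assert layernorm_broadcast_stride(x.shape, gamma.shape, beta.shape) == S
print(layernorm_reference(x, gamma, beta).shape)  # (2, 4, 8)
```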
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm(
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta,
int broadcast,
const T* __restrict__ skip,
const T* __restrict__ bias,
T* __restrict__ skip_input_bias_add_output) {
@@ -366,8 +367,13 @@
curr += static_cast<U>(skip_vals[i]);
}

U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
// The ONNX LayerNormalization operator supports broadcast:
// gamma and beta should be unidirectionally broadcastable to tensor x.
// Here we support a special case for transformer models where x is (B, S, D) and gamma/beta is (B, 1, D).
int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i;
U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0;

if (simplified) {
ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
} else {
@@ -409,6 +415,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast,
const T* skip,
const T* bias,
T* skip_input_bias_add_output) {
@@ -442,15 +449,15 @@ void HostApplyLayerNorm(
input,
n1, n2,
U(epsilon),
gamma, beta,
gamma, beta, broadcast,
skip, bias, skip_input_bias_add_output);
}

#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
double epsilon, const V* gamma, const V* beta, const T* skip, \
const T* bias, T* skip_input_bias_add_output);
double epsilon, const V* gamma, const V* beta, int broadcast, \
const T* skip, const T* bias, T* skip_input_bias_add_output);

LAYERNORM_LINEAR_IMPL(float, float, float, true)
LAYERNORM_LINEAR_IMPL(half, float, half, true)
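In the kernel, rows are indexed by i1 in [0, n1) with n1 = B * S and n2 = D, so when broadcast == S the expression (i1 / broadcast) recovers the batch index, and (i1 / broadcast) * n2 + i addresses element i of that batch's gamma/beta row. A small NumPy check of that arithmetic (variable names follow the kernel; the check itself is illustrative):

```python
import numpy as np

# Check of the index arithmetic in cuApplyLayerNorm: with n1 = B * S rows,
# n2 = D columns, and broadcast = S, index = (i1 // broadcast) * n2 + i
# picks gamma[b, 0, i] for a row belonging to batch b.
B, S, D = 3, 5, 7
broadcast, n1, n2 = S, B * S, D
gamma = np.random.randn(B, 1, D).astype(np.float32)
gamma_flat = gamma.reshape(-1)  # contiguous (B, 1, D) layout seen by the kernel

for i1 in range(n1):
    b = i1 // S  # batch index of this row
    for i in range(n2):
        index = (i1 // broadcast) * n2 + i if broadcast > 0 else i
        assert gamma_flat[index] == gamma[b, 0, i]
print("index mapping matches NumPy broadcasting of (B, 1, D) over (B, S, D)")
```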
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
@@ -41,6 +41,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast = 0, // broadcast stride for gamma/beta
const T* skip = nullptr,
const T* bias = nullptr,
T* skip_input_bias_add_output = nullptr);
10 changes: 7 additions & 3 deletions onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -13,7 +13,7 @@


class FusionLayerNormalization(Fusion):
def __init__(self, model: OnnxModel, check_constant_and_dimension:bool=True):
def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True):
super().__init__(model, "LayerNormalization", "ReduceMean")
self.check_constant_and_dimension = check_constant_and_dimension

@@ -132,11 +132,15 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):

node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
weight_input, 1, "layernorm weight"
):
return

bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
bias_input, 1, "layernorm bias"
):
return

self.nodes_to_remove.extend(subgraph_nodes)
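The new check_constant_and_dimension flag lets callers skip the constant/1-D weight check, which appears intended for the broadcast (B, 1, D) scale/bias case handled by the CUDA kernel above. A hedged usage sketch, assuming the Fusion.apply() entry point from fusion_base and a placeholder model path:

```python
import onnx
from fusion_layernorm import FusionLayerNormalization
from onnx_model import OnnxModel

# "model.onnx" is a hypothetical path; the flag value is the only part specific to this change.
model = OnnxModel(onnx.load("model.onnx"))
# Allow fusion even when weight/bias are not constant 1-D initializers,
# e.g. (B, 1, D) scale/bias that the CUDA LayerNorm kernel can now broadcast.
fusion = FusionLayerNormalization(model, check_constant_and_dimension=False)
fusion.apply()
```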
89 changes: 44 additions & 45 deletions onnxruntime/python/tools/transformers/fusion_mha_mmdit.py
@@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Optional, Tuple
from typing import Optional

import numpy as np
from fusion_base import Fusion
@@ -13,13 +13,14 @@

logger = getLogger(__name__)


class FusionMultiHeadAttentionMMDit(Fusion):
"""
Fuse MultiHeadAttention for Multimodal Diffusion Transformer (MMDiT).
"""

def __init__(self, model: OnnxModel):
super().__init__(model, fused_op_type = "MultiHeadAttention", search_op_types = ["Softmax"])
super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"])

def get_num_heads(self, node: NodeProto, output_name_to_node) -> int:
"""
@@ -46,7 +47,8 @@ def get_num_heads(self, node: NodeProto, output_name_to_node) -> int:
node,
["SimplifiedLayerNormalization", "Transpose", "Reshape", "Add"],
[0, 0, 0, 0],
output_name_to_node=output_name_to_node)
output_name_to_node=output_name_to_node,
)

num_heads = 0
if k_proj_nodes:
@@ -101,11 +103,7 @@ def get_num_heads_with_concat(self, transpose_k: NodeProto, output_name_to_node)
|
Transpose(perm=0,1,3,2)
"""
nodes = self.model.match_parent_path(
transpose_k,
["Concat"],
[0],
output_name_to_node=output_name_to_node)
nodes = self.model.match_parent_path(transpose_k, ["Concat"], [0], output_name_to_node=output_name_to_node)

return self.get_num_heads(nodes[0], output_name_to_node) if nodes else 0

@@ -155,37 +153,35 @@ def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_

return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD")



def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> Optional[str]:
"""
Before:
MatMul MatMul .. [-1] [24] ..
| | | | / /
Add Concat Add Concat
| / | /
Reshape Reshape
| |
Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3)
| |
SimplifiedLayerNorm SimplifiedLayerNorm
| /
Concat(axis=2)
|
Mul
After:
MatMul MatMul .. [-1] [24] ..
| | | | / /
Add Concat Add Concat
| / | /
Reshape Reshape
| |
Before:
MatMul MatMul .. [-1] [24] ..
| | | | / /
Add Concat Add Concat
| / | /
Reshape Reshape
| |
Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3)
| |
SimplifiedLayerNorm SimplifiedLayerNorm
| /
Concat(axis=1)
|
Reshape (shape=[0, 0, -1])
| /
Concat(axis=2)
|
Mul
After:
MatMul MatMul .. [-1] [24] ..
| | | | / /
Add Concat Add Concat
| / | /
Reshape Reshape
| |
SimplifiedLayerNorm SimplifiedLayerNorm
| /
Concat(axis=1)
|
Reshape (shape=[0, 0, -1])
"""

path = self.model.match_parent_path(
@@ -204,7 +200,6 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -
concat,
["SimplifiedLayerNormalization", "Transpose"],
[1, 0],

)
if path is None:
return None
@@ -232,7 +227,6 @@ def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -

return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD")


def create_multihead_attention_node(
self,
q: str,
@@ -284,10 +278,9 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
if self.model.find_graph_output(softmax.output[0]):
return

nodes = self.model.match_child_path(softmax,
["MatMul", "Transpose", "Reshape"],
[(0, 0), (0, 0), (0, 0)],
input_name_to_nodes)
nodes = self.model.match_child_path(
softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes
)
if nodes is None:
return

@@ -334,21 +327,27 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
# |
# v
# -- Transpose (perm=[0,2,1,3]) -> Concat -> (v)
transpose_1 = self.model.match_parent(concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node)
transpose_1 = self.model.match_parent(
concat, "Transpose", input_index=0, output_name_to_node=output_name_to_node
)
if transpose_1 is None:
return
if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]):
return

transpose_2 = self.model.match_parent(concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node)
transpose_2 = self.model.match_parent(
concat, "Transpose", input_index=1, output_name_to_node=output_name_to_node
)
if transpose_2 is None:
return
if not FusionUtils.check_node_attribute(transpose_2, "perm", [0, 2, 1, 3]):
return
else:
# Match v path like:
# -- Transpose (perm=[0,2,1,3]) -> (v)
transpose_1 = self.model.match_parent(matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node)
transpose_1 = self.model.match_parent(
matmul_s_v, "Transpose", input_index=1, output_name_to_node=output_name_to_node
)
if transpose_1 is None:
return
if not FusionUtils.check_node_attribute(transpose_1, "perm", [0, 2, 1, 3]):
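For reference, the adjust_query_from_bnsh_to_bsd rewrite relies on a layout equivalence: concatenating the per-branch query tensors in BSNH along the sequence axis and flattening heads gives the same BSD tensor as the original Transpose/Concat(axis=2) path. A NumPy sketch of that equivalence (shapes below are made up for illustration):

```python
import numpy as np

# Layout equivalence behind adjust_query_from_bnsh_to_bsd (illustrative check).
B, N, H = 2, 24, 64
S1, S2 = 5, 3
q_a = np.random.randn(B, S1, N, H).astype(np.float32)  # BSNH, before any Transpose
q_b = np.random.randn(B, S2, N, H).astype(np.float32)

# Original graph: Transpose(perm=0,2,1,3) to BNSH, Concat on axis 2 (sequence),
# then back to BSNH and flatten heads to get the BSD layout.
bnsh = np.concatenate([q_a.transpose(0, 2, 1, 3), q_b.transpose(0, 2, 1, 3)], axis=2)
bsd_from_bnsh = bnsh.transpose(0, 2, 1, 3).reshape(B, S1 + S2, N * H)

# Rewritten graph: Concat on axis 1 in BSNH, then Reshape(shape=[0, 0, -1]).
bsd_direct = np.concatenate([q_a, q_b], axis=1).reshape(B, S1 + S2, N * H)

assert np.array_equal(bsd_from_bnsh, bsd_direct)
print("BSNH concat(axis=1) + reshape equals the BNSH concat(axis=2) path in BSD layout")
```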
18 changes: 10 additions & 8 deletions onnxruntime/python/tools/transformers/onnx_model.py
@@ -242,13 +242,11 @@ def get_children(self, node, input_name_to_nodes=None, output_index=None):
if output_index < len(node.output):
output = node.output[output_index]
if output in input_name_to_nodes:
for node in input_name_to_nodes[output]:
children.append(node)
children = list(input_name_to_nodes[output])
else:
for output in node.output:
if output in input_name_to_nodes:
for node in input_name_to_nodes[output]:
children.append(node) # noqa: PERF402
children.extend(input_name_to_nodes[output])

return children

@@ -444,7 +442,7 @@ def match_child_path(
self,
node,
child_op_types,
edges:Optional[List[Tuple[int, int]]]=None,
edges: Optional[List[Tuple[int, int]]] = None,
input_name_to_nodes=None,
exclude=[], # noqa: B006
):
@@ -465,10 +463,12 @@
if edges is not None:
assert len(edges) == len(child_op_types)
for edge in edges:
assert isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int)
assert (
isinstance(edge, tuple) and len(edge) == 2 and isinstance(edge[0], int) and isinstance(edge[1], int)
)

if input_name_to_nodes is None:
input_name_to_nodes = self.input_name_to_nodes()
input_name_to_nodes = self.input_name_to_nodes()

current_node = node
matched_children = []
Expand All @@ -478,7 +478,9 @@ def match_child_path(
if edges is None:
children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes)
else:
children_nodes = self.get_children(current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0])
children_nodes = self.get_children(
current_node, input_name_to_nodes=input_name_to_nodes, output_index=edges[i][0]
)

for child in children_nodes:
if child.op_type == op_type and child not in exclude:
(The remaining 2 changed files are not shown.)