Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Stable Diffusion 3.x and Flux Optimization #22986

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // broadcast stride for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
Expand Down
33 changes: 25 additions & 8 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,36 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;

int broadcast = 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
// Handle a special case for MMDit where scale and bias need broadcast.
// X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as broadcast stride.
if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
scale->Shape().NumDimensions() == x_num_dims &&
scale->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
scale->Shape().GetDims()[1] == 1 &&
scale->Shape().GetDims()[2] == x_shape.GetDims()[2] &&
bias->Shape().NumDimensions() == x_num_dims &&
bias->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
bias->Shape().GetDims()[1] == 1 &&
bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) {
broadcast = static_cast<int>(x_shape.GetDims()[1]);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
}

// Outputs
Expand All @@ -65,7 +82,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
Expand Down Expand Up @@ -94,7 +111,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
Expand Down
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm(
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta,
int broadcast,
const T* __restrict__ skip,
const T* __restrict__ bias,
T* __restrict__ skip_input_bias_add_output) {
Expand Down Expand Up @@ -366,8 +367,13 @@ __global__ void cuApplyLayerNorm(
curr += static_cast<U>(skip_vals[i]);
}

U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
// onnx operator LayerNormalization support broadcast.
// gamma and beta should be unidirectional broadcastable to tensor x.
// Here we support a special case for transformer models that x is (B, S, D) and gamma/beta is (B, 1, D)
int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i;
U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0;

if (simplified) {
ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
} else {
Expand Down Expand Up @@ -409,6 +415,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast,
const T* skip,
const T* bias,
T* skip_input_bias_add_output) {
Expand Down Expand Up @@ -442,15 +449,15 @@ void HostApplyLayerNorm(
input,
n1, n2,
U(epsilon),
gamma, beta,
gamma, beta, broadcast,
skip, bias, skip_input_bias_add_output);
}

#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
double epsilon, const V* gamma, const V* beta, const T* skip, \
const T* bias, T* skip_input_bias_add_output);
double epsilon, const V* gamma, const V* beta, int broadcast, \
const T* skip, const T* bias, T* skip_input_bias_add_output);

LAYERNORM_LINEAR_IMPL(float, float, float, true)
LAYERNORM_LINEAR_IMPL(half, float, half, true)
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast = 0, // broadcast stride for gamma/beta
const T* skip = nullptr,
const T* bias = nullptr,
T* skip_input_bias_add_output = nullptr);
Expand Down
39 changes: 0 additions & 39 deletions onnxruntime/python/tools/transformers/fusion_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,45 +399,6 @@ def split_kv(self, present_k_name: str, present_v_name: str, kv_node: str):
self.node_name_to_graph_name[gather_k_name] = self.this_graph_name
self.node_name_to_graph_name[gather_v_name] = self.this_graph_name

def transpose_kv(self, past_k: str, past_v: str):
"""Transpose past_k and past_v from (B,N,P,H) to (B,P,N,H)

Args:
past_k (str): name of past K value of shape (B,N,P,H)
past_v (str): name of past V value of shape (B,N,P,H)

Returns:
past_k_transpose (str): name of past K value of shape (B,P,N,H)
past_v_transpose (str): name of past V value of shape (B,P,N,H)
"""
past_k_transpose = (past_k + "_transposed").replace(".", "_")
past_v_transpose = (past_v + "_transposed").replace(".", "_")
transpose_k_name = self.model.create_node_name("Transpose")
transpose_v_name = self.model.create_node_name("Transpose")

transpose_k = helper.make_node(
"Transpose",
inputs=[past_k],
outputs=[past_k_transpose],
name=transpose_k_name,
perm=[0, 2, 1, 3],
)
transpose_v = helper.make_node(
"Transpose",
inputs=[past_v],
outputs=[past_v_transpose],
name=transpose_v_name,
perm=[0, 2, 1, 3],
)

# Add reshape nodes to graph
self.nodes_to_add.append(transpose_k)
self.nodes_to_add.append(transpose_v)
self.node_name_to_graph_name[transpose_k_name] = self.this_graph_name
self.node_name_to_graph_name[transpose_v_name] = self.this_graph_name

return past_k_transpose, past_v_transpose

def create_combined_qkv_bias(
self,
q_add: NodeProto,
Expand Down
122 changes: 122 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_fastgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
if self.fuse_3(tanh_node, input_name_to_nodes, output_name_to_node):
return

if self.fuse_4(tanh_node, input_name_to_nodes, output_name_to_node):
return

def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
"""
Fuse Gelu with tanh into one node:
Expand Down Expand Up @@ -358,3 +361,122 @@
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True

def fuse_4(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""
This pattern is from stable diffusion 3.5 model.
Fuse Gelu with tanh into one node:
+-----------------+------------------+
| | |
| v v
[root] ==> Mul --> Mul --> Mul -----> Add --> Mul --> Tanh --> Add -----> Mul --> Mul -->
| (A=0.0447) (A=0.7978) (A=1) ^ (A=0.5)
| |
+-------------------------------------------------------------------------+
Note that constant input for Add and Mul could be first or second input.
"""
if tanh_node.output[0] not in input_name_to_nodes:
return

children = input_name_to_nodes[tanh_node.output[0]]
if len(children) != 1 or children[0].op_type != "Add":
return
add_after_tanh = children[0]

if not self.model.has_constant_input(add_after_tanh, 1.0):
return

if add_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[add_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_after_tanh = children[0]

if mul_after_tanh.output[0] not in input_name_to_nodes:
return
children = input_name_to_nodes[mul_after_tanh.output[0]]
if len(children) != 1 or children[0].op_type != "Mul":
return
mul_half = children[0]
if not self.model.has_constant_input(mul_half, 0.5):
return

root_input = mul_after_tanh.input[0 if mul_after_tanh.input[1] == add_after_tanh.output[0] else 1]

mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
if mul_before_tanh is None:
return

i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
if i < 0:
return

add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
if add_before_tanh is None:
return

if add_before_tanh.input[0] == root_input:
another = 1
elif add_before_tanh.input[1] == root_input:
another = 0
else:
return

mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", another, output_name_to_node)
if mul_after_pow is None:
return

i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
if i < 0:
return

mul = self.model.match_parent(mul_after_pow, "Mul", 0 if i == 1 else 1, output_name_to_node)
if mul is None:
return

if mul.input[0] == root_input:
another = 1
elif mul.input[1] == root_input:
another = 0
else:
return

mul2 = self.model.match_parent(mul, "Mul", another, output_name_to_node)
if mul2 is None:
return

if mul2.input[0] != root_input or mul2.input[1] != root_input:
return

subgraph_nodes = [
mul2,
mul,
mul_after_pow,
add_before_tanh,
mul_before_tanh,
tanh_node,
add_after_tanh,
mul_after_tanh,
mul_half,
]

if not self.model.is_safe_to_fuse_nodes(
subgraph_nodes,
[mul_half.output[0]],
input_name_to_nodes,
output_name_to_node,
):
return

self.nodes_to_remove.extend(subgraph_nodes)
fused_node = helper.make_node(
"FastGelu",
inputs=[root_input],
outputs=mul_half.output,
name=self.model.create_node_name("FastGelu"),
)
fused_node.domain = "com.microsoft"
self.nodes_to_add.append(fused_node)
self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
return True
Loading
Loading