[WIP] Stable Diffusion 3.x and Flux Optimization #22986

Draft · wants to merge 21 commits into base: main
61 changes: 25 additions & 36 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);

if (data.bias == nullptr) {
assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.gemm_buffer, qkv, 3));
data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
data.qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));

// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;

int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
}
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For the fused causal kernel, use format 1 since we need K and V to update the present state;
// at the same time, we update gemm_buffer (BxSx3xNxH) with bias, which is used as input for the fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
data.qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_flash_or_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal
? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
: AttentionQkvFormat::Q_K_V_BNSH));

// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;

int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
// format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
// format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
return Status::OK();
}
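For reference, the three target layouts named in the comments above are plain permutations of the fused GEMM output. Below is a minimal NumPy sketch of those permutations (toy shapes, illustrative only; it ignores the bias add and rotary embedding that LaunchAddBiasTranspose also applies):

```python
import numpy as np

B, S, N, H = 2, 4, 3, 8  # batch, sequence, heads, head size (toy values)
gemm_buffer = np.random.rand(B, S, 3, N, H).astype(np.float32)  # BxSx3xNxH

# format 1 (Q_K_V_BNSH): 3xBxNxSxH, used by the unfused kernel
fmt1 = gemm_buffer.transpose(2, 0, 3, 1, 4)

# format 2 (QKV_BSN3H): BxSxNx3xH, used by fused TRT attention
fmt2 = gemm_buffer.transpose(0, 1, 3, 2, 4)

# format 3 (Q_K_V_BSNH): 3xBxSxNxH, used by flash / memory efficient attention
fmt3 = gemm_buffer.transpose(2, 0, 1, 3, 4)

assert fmt1.shape == (3, B, N, S, H)
assert fmt2.shape == (B, S, N, 3, H)
assert fmt3.shape == (3, B, S, N, H)
```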

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // broadcast stride for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/print_tensor_statistics_utils.h
@@ -30,7 +30,7 @@ void PrintFloatStats(const T* data, size_t count) {
size_t zero = 0;
size_t subnormal = 0;
for (size_t i = 0; i < count; i++) {
switch (my_fpclassify(*data)) {
switch (my_fpclassify(data[i])) {
case FP_INFINITE:
inf++;
break;
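Note on the fix above: `my_fpclassify(*data)` classified the first element on every iteration, so the inf/NaN/zero/subnormal counters only ever reflected `data[0]`; `data[i]` classifies each element. A rough NumPy equivalent of the intended per-element statistics (illustrative sketch, not the ORT helper):

```python
import numpy as np

def float_stats(data: np.ndarray) -> dict:
    """Count special floating-point values the way the fixed loop does (per element)."""
    tiny = np.finfo(data.dtype).tiny  # smallest positive normal value
    return {
        "inf": int(np.isinf(data).sum()),
        "nan": int(np.isnan(data).sum()),
        "zero": int((data == 0).sum()),
        # subnormal: finite, nonzero, and smaller in magnitude than the smallest normal
        "subnormal": int(((np.abs(data) < tiny) & (data != 0) & np.isfinite(data)).sum()),
    }
```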
33 changes: 25 additions & 8 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -44,19 +44,36 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;

int broadcast = 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
// Handle a special case for MMDiT where scale and bias need to be broadcast.
// X shape is (B, S, D), scale and bias shape is (B, 1, D), and we store S as the broadcast stride.
if (x_num_dims == 3 && axis == 2 && n2 > 1 &&
scale->Shape().NumDimensions() == x_num_dims &&
scale->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
scale->Shape().GetDims()[1] == 1 &&
scale->Shape().GetDims()[2] == x_shape.GetDims()[2] &&
bias->Shape().NumDimensions() == x_num_dims &&
bias->Shape().GetDims()[0] == x_shape.GetDims()[0] &&
bias->Shape().GetDims()[1] == 1 &&
bias->Shape().GetDims()[2] == x_shape.GetDims()[2]) {
broadcast = static_cast<int>(x_shape.GetDims()[1]);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
}

// Outputs
@@ -65,7 +82,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
@@ -94,7 +111,7 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
X_data, n1, n2, epsilon_, scale_data, bias_data, broadcast);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
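The special case above lets scale/bias of shape (B, 1, D) apply to X of shape (B, S, D) by passing S down as a broadcast stride. A minimal NumPy sketch of the intended semantics (toy shapes, not the CUDA implementation):

```python
import numpy as np

B, S, D = 2, 3, 4
x = np.random.rand(B, S, D).astype(np.float32)
gamma = np.random.rand(B, 1, D).astype(np.float32)  # per-batch scale, as in MMDiT
beta = np.random.rand(B, 1, D).astype(np.float32)   # per-batch bias
epsilon = 1e-5

mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
# NumPy broadcasting expands (B, 1, D) over the S dimension automatically;
# the CUDA kernel achieves the same by striding with broadcast = S.
y = (x - mean) / np.sqrt(var + epsilon) * gamma + beta
assert y.shape == (B, S, D)
```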
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm(
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta,
int broadcast,
const T* __restrict__ skip,
const T* __restrict__ bias,
T* __restrict__ skip_input_bias_add_output) {
@@ -366,8 +367,13 @@ __global__ void cuApplyLayerNorm(
curr += static_cast<U>(skip_vals[i]);
}

U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
// The ONNX LayerNormalization operator supports broadcast:
// gamma and beta should be unidirectionally broadcastable to tensor x.
// Here we support a special case for transformer models where x is (B, S, D) and gamma/beta is (B, 1, D).
int index = (broadcast > 0) ? ((i1 / broadcast) * n2 + i) : i;
U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1;
U beta_i = (beta != nullptr) ? (U)beta[index] : (U)0;

if (simplified) {
ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
} else {
@@ -409,6 +415,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast,
const T* skip,
const T* bias,
T* skip_input_bias_add_output) {
@@ -442,15 +449,15 @@ void HostApplyLayerNorm(
input,
n1, n2,
U(epsilon),
gamma, beta,
gamma, beta, broadcast,
skip, bias, skip_input_bias_add_output);
}

#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
double epsilon, const V* gamma, const V* beta, const T* skip, \
const T* bias, T* skip_input_bias_add_output);
double epsilon, const V* gamma, const V* beta, int broadcast, \
const T* skip, const T* bias, T* skip_input_bias_add_output);

LAYERNORM_LINEAR_IMPL(float, float, float, true)
LAYERNORM_LINEAR_IMPL(half, float, half, true)
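Inside the kernel, i1 enumerates the flattened (B, S) rows, so with broadcast = S the expression (i1 / broadcast) * n2 + i addresses gamma/beta per batch rather than per row. A quick check of that mapping against NumPy broadcasting (toy shapes, hypothetical variable names):

```python
import numpy as np

B, S, D = 2, 3, 4          # n1 = B * S rows, n2 = D elements per row
broadcast = S              # stride passed down from the (B, 1, D) special case
gamma = np.random.rand(B, 1, D).astype(np.float32)
gamma_flat = gamma.reshape(-1)                    # memory layout seen by the kernel
expanded = np.broadcast_to(gamma, (B, S, D)).reshape(B * S, D)

for i1 in range(B * S):                           # one kernel block per row
    for i in range(D):
        index = (i1 // broadcast) * D + i if broadcast > 0 else i
        assert gamma_flat[index] == expanded[i1, i]
```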
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
@@ -41,6 +41,7 @@ void HostApplyLayerNorm(
double epsilon,
const V* gamma,
const V* beta,
int broadcast = 0, // broadcast stride for gamma/beta
const T* skip = nullptr,
const T* bias = nullptr,
T* skip_input_bias_add_output = nullptr);
12 changes: 10 additions & 2 deletions onnxruntime/python/tools/transformers/compare_bert_results.py
@@ -37,16 +37,23 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
# Validate the output of baseline and treatment, to make sure the results are similar.
diff_count = 0
max_abs_diff = 0
max_diff_percentage = 0
case_passed = True
for test_case_id, results in enumerate(baseline_results):
case_passed = True
for i in range(len(results)):
treatment_output = treatment_results[test_case_id][i]
abs_diff = np.amax(np.abs(treatment_output - results[i]))
abs_diff_tensor = np.abs(treatment_output - results[i])
abs_diff = np.amax(abs_diff_tensor)
if verbose and abs_diff > atol:
print("abs_diff", abs_diff)
print("treatment", treatment_output)
print("baseline", results[i])

count_exceeding = np.sum(abs_diff_tensor > atol)
total_elements = abs_diff_tensor.size
percentage_exceeding = (count_exceeding / total_elements) * 100
max_diff_percentage = max(max_diff_percentage, percentage_exceeding)

max_abs_diff = max(max_abs_diff, abs_diff)
if not np.allclose(results[i].tolist(), treatment_output.tolist(), rtol=rtol, atol=atol):
if case_passed:
@@ -66,6 +73,7 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-1, atol=1e-3):
)

print(f"maximum absolute difference={max_abs_diff}")
print(f"maximum percentage of elements that exceeds atol={atol} is {max_diff_percentage:.3f}%")
return max_abs_diff, case_passed
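The new summary line reports the largest share (over all test cases and outputs) of elements whose absolute difference exceeds atol. A standalone sketch of the same computation on toy arrays (illustrative values):

```python
import numpy as np

atol = 1e-3
baseline = np.array([0.10, 0.20, 0.30, 0.40], dtype=np.float32)
treatment = np.array([0.10, 0.21, 0.30, 0.40], dtype=np.float32)

abs_diff = np.abs(treatment - baseline)
percentage_exceeding = np.sum(abs_diff > atol) / abs_diff.size * 100
print(f"{percentage_exceeding:.3f}% of elements exceed atol={atol}")  # 25.000%
```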

