[CUDA] Add use_tf32 provider option (for FP32 GEMM) (#19357)
[TF32](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/)
can boost FP32 performance on GPUs with SM >= 80. However, TF32 sometimes causes
accuracy loss, or needs to be disabled for testing purposes. TF32 can also be
disabled by setting the environment variable `NVIDIA_TF32_OVERRIDE=0`, but an
environment variable affects other applications in the same environment and
offers no finer-grained control (for example, one session using TF32 and
another not). A provider option addresses both concerns.
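
As a usage sketch (not part of this change's diff), the option can be set per session through the CUDA execution provider options of the C/C++ API. The `use_tf32` key matches the field added below; the model path, logging tag, and omitted error handling are placeholders:

```cpp
// Minimal sketch: disable TF32 for a single session via CUDA EP options (C API).
// "model.onnx" is a placeholder path; exception handling is omitted.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "tf32_demo");
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // "use_tf32" is the provider option added by this commit; "0" disables TF32
  // for FP32 GEMMs in this session only. Other sessions are unaffected.
  const char* keys[] = {"device_id", "use_tf32"};
  const char* values[] = {"0", "0"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 2));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
      session_options, cuda_options));
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  api.ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}
```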

This change adds a provider option `use_tf32`. When `use_tf32 = 0`, TF32 is
disabled for float MatMul/GEMM in cuBLAS. The option also applies to the
MatMulNBits, Attention, LongformerAttention, PackedAttention, and
PackedMultiHeadAttention operators whenever a float GEMM is used internally.
Note that it does not affect other data types; for example, FP8 GEMM may still
use TF32 for accumulation.

Previously, `cublasGemmStridedBatchedHelper` did not use TF32 during inference.
TF32 is now enabled by default, so FP32 transformer models may see a speedup on
SM >= 80.
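
For context on how such a flag typically reaches cuBLAS, a float GEMM helper can select the cuBLAS compute type from it. The sketch below only illustrates the idea; the actual `cublasGemmStridedBatchedHelper` in onnxruntime has additional parameters and per-type specializations:

```cpp
// Simplified illustration (not onnxruntime's exact helper): honor use_tf32 by
// choosing the cuBLAS compute type for an FP32 strided-batched GEMM.
#include <cublas_v2.h>

cublasStatus_t GemmStridedBatchedFloat(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const float* alpha,
    const float* A, int lda, long long stride_a,
    const float* B, int ldb, long long stride_b,
    const float* beta, float* C, int ldc, long long stride_c,
    int batch_count, bool use_tf32) {
  // CUBLAS_COMPUTE_32F_FAST_TF32 lets tensor cores round FP32 inputs to TF32
  // (10-bit mantissa) while still accumulating in FP32; CUBLAS_COMPUTE_32F
  // keeps full-precision FP32 multiplies.
  const cublasComputeType_t compute_type =
      use_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;

  return cublasGemmStridedBatchedEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_32F, lda, stride_a,
      B, CUDA_R_32F, ldb, stride_b,
      beta, C, CUDA_R_32F, ldc, stride_c,
      batch_count, compute_type, CUBLAS_GEMM_DEFAULT);
}
```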

A follow-up PR will enable this option for cuDNN Conv.

### Motivation and Context

#15407
#19288
tianleiwu authored Feb 6, 2024
1 parent c4b49fb commit bedf0ee
Showing 36 changed files with 245 additions and 139 deletions.
2 changes: 2 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext {
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;

void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
@@ -52,6 +53,7 @@ struct CudaContext : public CustomOpContext {
cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
}

template <typename T>
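For custom ops that issue their own GEMMs, the new `use_tf32` field can be read from `CudaContext` after `Init`. A hypothetical kernel body is sketched below; the `Ort::Custom` namespace, the `cublas_handle` member, and the include path are assumptions based on the full header, which is truncated above:

```cpp
// Hypothetical custom-op compute function showing how the new field is consumed.
// CudaContext::Init (above) fetches use_tf32 from the kernel context, so a custom
// op can follow the session-level setting for its own FP32 GEMM calls.
#include <cublas_v2.h>
#include "core/providers/cuda/cuda_context.h"  // adjust to your include directories

void MyCustomOpCompute(const OrtKernelContext* kernel_context) {
  Ort::Custom::CudaContext cuda_ctx;
  cuda_ctx.Init(*kernel_context);

  cublasHandle_t cublas = cuda_ctx.cublas_handle;  // handle owned by the CUDA EP
  const bool use_tf32 = cuda_ctx.use_tf32;         // field added in this commit

  // ... select CUBLAS_COMPUTE_32F_FAST_TF32 vs CUBLAS_COMPUTE_32F from use_tf32 ...
  (void)cublas;
  (void)use_tf32;
}
```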
@@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 {
// The strict mode has better accuracy but lower performance.
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32
};
1 change: 1 addition & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -18,4 +18,5 @@ enum CudaResource : int {
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -64,6 +64,7 @@ struct AttentionParameters {
bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
@@ -82,6 +83,7 @@ struct PackedAttentionParameters {
int token_count;
bool has_relative_position_bias;
bool broadcast_res_pos_bias;
bool use_tf32;
};

// Parameters deduced from node attributes and inputs/outputs.
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -84,6 +84,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
parameters.use_tf32 = UseTF32();

// Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr
std::vector<int64_t> bias_dims{weights->Shape().GetDims()[1]};
const TensorShape bias_shape{bias_dims};
@@ -251,7 +253,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -461,7 +461,8 @@ Status UnfusedAttention(
total_sequence_length, sequence_length, qk_head_size,
&alpha, data.k, qk_head_size, present_size_per_batch_k,
data.q, qk_head_size, sequence_length * qk_head_size,
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches,
device_prop, parameters.use_tf32));

DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
@@ -514,7 +515,7 @@ Status UnfusedAttention(
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
18 changes: 10 additions & 8 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -273,13 +273,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>()), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (h2, h1)*(h1, S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(q_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_query_buffer in col-base: (h2, S*B)

// calcualte k, v
@@ -298,13 +298,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
@@ -318,13 +318,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(key->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
}
} else {
@@ -342,13 +342,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
kv_sequence_length = cache_sequence_length;
@@ -372,6 +372,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
device_prop,
#ifdef USE_ROCM
GetTuningContext(),
#else
UseTF32(),
#endif
context->GetComputeStream(),
cublas,
18 changes: 11 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
@@ -37,7 +37,8 @@ Status DecoderQkvToContext(
T* workspace_buffer,
T* output,
T* new_key_cache,
T* new_value_cache) {
T* new_value_cache,
bool use_tf32) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int BN = batch_size * num_heads;
const int BHN = BN * head_size;
@@ -128,14 +129,14 @@
kv_sequence_length, sequence_length, head_size,
&alpha, key_cache, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
kv_sequence_length, sequence_length, head_size,
&alpha, k, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
}

constexpr bool is_unidirectional = false;
@@ -163,14 +164,14 @@
head_size, sequence_length, kv_sequence_length,
&one, value_cache, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, kv_sequence_length,
&one, v, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
}

// scratch3 is BxNxSxH, transpose to output SxBxNxH
@@ -180,6 +181,7 @@

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& device_prop,
bool use_tf32,
Stream* stream,
cublasHandle_t& cublas,
const size_t element_size,
@@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel(
reinterpret_cast<half*>(workspace_buffer),
reinterpret_cast<half*>(output),
reinterpret_cast<half*>(new_key_cache),
reinterpret_cast<half*>(new_value_cache));
reinterpret_cast<half*>(new_value_cache),
use_tf32);
} else {
return DecoderQkvToContext(
device_prop,
@@ -254,7 +257,8 @@
reinterpret_cast<float*>(workspace_buffer),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(new_key_cache),
reinterpret_cast<float*>(new_value_cache));
reinterpret_cast<float*>(new_value_cache),
use_tf32);
}
}

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h
@@ -11,6 +11,7 @@ namespace cuda {

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
bool use_tf32, // Use TF32
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
@@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T1>()), n,
reinterpret_cast<const CudaT*>(input->Data<T1>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

// Update the q, k, and v buffers
parameters.q = gemm_buffer.get();
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
@@ -136,7 +136,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
weights_data, n,
input_data, k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));
} else {
// q
const CudaT* q_weight = weights_data;
@@ -145,15 +145,15 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
q_weight, n,
input_data, k,
&zero, q_data, n, device_prop));
&zero, q_data, n, device_prop, UseTF32()));
// k
const CudaT* k_weight = q_weight + static_cast<int64_t>(hidden_size) * hidden_size;
CudaT* k_data = q_data + static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
k_weight, n,
input_data, k,
&zero, k_data, n, device_prop));
&zero, k_data, n, device_prop, UseTF32()));

// v
const CudaT* v_weight = k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
@@ -162,7 +162,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
v_weight, n,
input_data, k,
&zero, v_data, n, device_prop));
&zero, v_data, n, device_prop, UseTF32()));
}

// Wait for async copy of batch_global_num
@@ -195,7 +195,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(global_weights->Data<T>()), n,
input_data, k,
&zero, global_gemm_buffer, n, device_prop));
&zero, global_gemm_buffer, n, device_prop, UseTF32()));
} else {
// global q
const CudaT* global_q_weight = global_weights_data;
@@ -205,7 +205,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_q_weight, n,
input_data, k,
&zero, global_q, n, device_prop));
&zero, global_q, n, device_prop, UseTF32()));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas,
@@ -226,7 +226,8 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
hidden_size, // ldc
static_cast<int64_t>(max_num_global) * hidden_size, // strideC
batch_size, // batch count
device_prop));
device_prop,
UseTF32()));
}
// global k
const CudaT* global_k_weight = global_weights_data + static_cast<int64_t>(hidden_size) * hidden_size;
@@ -235,7 +236,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_k_weight, n,
input_data, k,
&zero, global_k, n, device_prop));
&zero, global_k, n, device_prop, UseTF32()));

// global v
const CudaT* global_v_weight = global_k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
@@ -244,7 +245,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_v_weight, n,
input_data, k,
&zero, global_v, n, device_prop));
&zero, global_v, n, device_prop, UseTF32()));
}
}

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -94,6 +94,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
parameters.use_tf32 = UseTF32();

ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -268,6 +268,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* relative_position_bias = context->Input<Tensor>(5);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@@ -308,12 +309,12 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublasHandle_t cublas = this->GetCublasHandle(context);

// Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias
// The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel.
// The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel.
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool no_qkv_workspace = false; // need workspace to add bias
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention(
q, qk_head_size, sequence_length * qk_head_size,
&zero,
scaled_qk, sequence_length, sequence_length * sequence_length,
batches, device_prop));
batches, device_prop, parameters.use_tf32));

DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length);

@@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention(
v_head_size, sequence_length, sequence_length,
&one, v, v_head_size, sequence_length * v_head_size,
attention_score, sequence_length, sequence_length * sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v
Status result = LaunchTransposeRemovePadding(
@@ -228,6 +228,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
const Tensor* relative_position_bias = context->Input<Tensor>(6);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(),
key,
value,