Add use_tf32 cuda provider option
tianleiwu committed Jan 31, 2024
1 parent 04afe77 commit 7b1627c
Showing 35 changed files with 192 additions and 95 deletions.
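For context on how the new option is consumed: the CUDA execution provider exposes it as a key/value provider option. The snippet below is a minimal usage sketch, assuming the option key matches the new field name (use_tf32) and using the standard C API provider-options calls; it is illustrative and not part of this commit.

#include <onnxruntime_cxx_api.h>

int main() {
  const OrtApi& api = Ort::GetApi();

  // Create CUDA provider options and disable TF32 so float GEMM/conv stay in full FP32.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"use_tf32"};  // assumed key name, mirroring the new struct field
  const char* values[] = {"0"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  api.ReleaseCUDAProviderOptions(cuda_options);

  // ... create an Ort::Session with session_options and run as usual.
  return 0;
}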
2 changes: 2 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext {
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;

void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
@@ -52,6 +53,7 @@ struct CudaContext : public CustomOpContext {
cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
}

template <typename T>
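The new CudaContext field lets custom CUDA operators honor the EP-level TF32 setting. A minimal sketch of reading it inside a custom kernel, assuming the Ort::Custom namespace used by the custom-op C++ helpers; the kernel itself is hypothetical:

#include "core/providers/cuda/cuda_context.h"

struct MyTf32AwareKernel {
  // Hypothetical custom-op kernel; only the use_tf32 plumbing matters here.
  void Compute(OrtKernelContext* kernel_context) {
    Ort::Custom::CudaContext cuda_ctx;
    cuda_ctx.Init(*kernel_context);

    // Respect the execution provider's TF32 preference in any custom float GEMM path.
    const bool allow_tf32 = cuda_ctx.use_tf32;
    // ... choose TF32 or full-FP32 math for float kernels based on allow_tf32.
  }
};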
include/onnxruntime/core/providers/cuda/cuda_provider_options.h
@@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 {
// The strict mode has better accuracy but lower performance.
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32

[cpplint] include/onnxruntime/core/providers/cuda/cuda_provider_options.h:40: Lines should be <= 120 characters long [whitespace/line_length] [2]
};
1 change: 1 addition & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -18,4 +18,5 @@ enum CudaResource : int {
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -64,6 +64,7 @@ struct AttentionParameters {
bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
@@ -82,6 +83,7 @@ struct PackedAttentionParameters {
int token_count;
bool has_relative_position_bias;
bool broadcast_res_pos_bias;
bool use_tf32;
};

// Parameters deduced from node attributes and inputs/outputs.
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -84,6 +84,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
parameters.use_tf32 = UseTF32();

[cpplint] onnxruntime/contrib_ops/cuda/bert/attention.cc:88: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
// Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr
std::vector<int64_t> bias_dims{weights->Shape().GetDims()[1]};
const TensorShape bias_shape{bias_dims};
@@ -251,7 +253,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
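Most of the remaining changes thread the flag into cublasGemmHelper and cublasGemmStridedBatchedHelper calls. The helper implementations are not part of this excerpt; the sketch below only illustrates the cuBLAS mechanism such a flag typically maps onto for FP32 GEMMs (the wrapper function itself is hypothetical):

#include <cublas_v2.h>

// Illustrative FP32 GEMM wrapper: pick the cuBLAS compute type from a use_tf32 flag.
inline cublasStatus_t GemmF32(cublasHandle_t handle, int m, int n, int k,
                              const float* alpha, const float* A, int lda,
                              const float* B, int ldb,
                              const float* beta, float* C, int ldc,
                              bool use_tf32) {
  const cublasComputeType_t compute_type =
      use_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32  // allow TF32 tensor cores (reduced mantissa)
               : CUBLAS_COMPUTE_32F;           // keep the full FP32 path
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      alpha, A, CUDA_R_32F, lda,
                      B, CUDA_R_32F, ldb,
                      beta, C, CUDA_R_32F, ldc,
                      compute_type, CUBLAS_GEMM_DEFAULT);
}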
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -461,7 +461,7 @@ Status UnfusedAttention(
total_sequence_length, sequence_length, qk_head_size,
&alpha, data.k, qk_head_size, present_size_per_batch_k,
data.q, qk_head_size, sequence_length * qk_head_size,
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop, parameters.use_tf32));

[cpplint] onnxruntime/contrib_ops/cuda/bert/attention_impl.cu:464: Lines should be <= 120 characters long [whitespace/line_length] [2]

DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
@@ -514,7 +514,7 @@ Status UnfusedAttention(
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
21 changes: 12 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -273,13 +273,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>()), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (h2, h1)*(h1, S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(q_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_query_buffer in col-base: (h2, S*B)

// calcualte k, v

[misspell] onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc:285: "calcualte" is a misspelling of "calculate"
@@ -298,13 +298,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
@@ -318,13 +318,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(key->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
}
} else {
@@ -342,13 +342,13 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
kv_sequence_length = cache_sequence_length;
@@ -372,6 +372,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
device_prop,
#ifdef USE_ROCM
GetTuningContext(),
#else
UseTF32(),

[cpplint] onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc:376: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
#endif
context->GetComputeStream(),
cublas,
@@ -395,7 +397,8 @@ Status DecoderAttention<T>::ComputeInternal(OpKernelContext* context) const {
workspace_p.get(),
output->MutableData<T>(),
nullptr == new_key_cache ? nullptr : new_key_cache->MutableData<T>(),
nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>());
nullptr == new_value_cache ? nullptr : new_value_cache->MutableData<T>()
);
}

} // namespace cuda
18 changes: 11 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
@@ -37,7 +37,8 @@ Status DecoderQkvToContext(
T* workspace_buffer,
T* output,
T* new_key_cache,
T* new_value_cache) {
T* new_value_cache,
bool use_tf32) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int BN = batch_size * num_heads;
const int BHN = BN * head_size;
@@ -128,14 +129,14 @@ Status DecoderQkvToContext(
kv_sequence_length, sequence_length, head_size,
&alpha, key_cache, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
kv_sequence_length, sequence_length, head_size,
&alpha, k, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
}

constexpr bool is_unidirectional = false;
@@ -163,14 +164,14 @@ Status DecoderQkvToContext(
head_size, sequence_length, kv_sequence_length,
&one, value_cache, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, kv_sequence_length,
&one, v, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
}

// scratch3 is BxNxSxH, transpose to output SxBxNxH
@@ -180,6 +181,7 @@

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& device_prop,
bool use_tf32,
Stream* stream,
cublasHandle_t& cublas,
const size_t element_size,
@@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel(
reinterpret_cast<half*>(workspace_buffer),
reinterpret_cast<half*>(output),
reinterpret_cast<half*>(new_key_cache),
reinterpret_cast<half*>(new_value_cache));
reinterpret_cast<half*>(new_value_cache),
use_tf32);
} else {
return DecoderQkvToContext(
device_prop,
@@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel(
reinterpret_cast<float*>(workspace_buffer),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(new_key_cache),
reinterpret_cast<float*>(new_value_cache));
reinterpret_cast<float*>(new_value_cache),
use_tf32);
}
}

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h
@@ -11,6 +11,7 @@ namespace cuda {

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
bool use_tf32, // Use TF32
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
@@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T1>()), n,
reinterpret_cast<const CudaT*>(input->Data<T1>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

// Update the q, k, and v buffers
parameters.q = gemm_buffer.get();
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
@@ -136,7 +136,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
weights_data, n,
input_data, k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));
} else {
// q
const CudaT* q_weight = weights_data;
@@ -145,15 +145,15 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
q_weight, n,
input_data, k,
&zero, q_data, n, device_prop));
&zero, q_data, n, device_prop, UseTF32()));
// k
const CudaT* k_weight = q_weight + static_cast<int64_t>(hidden_size) * hidden_size;
CudaT* k_data = q_data + static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
k_weight, n,
input_data, k,
&zero, k_data, n, device_prop));
&zero, k_data, n, device_prop, UseTF32()));

// v
const CudaT* v_weight = k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
Expand All @@ -162,7 +162,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
v_weight, n,
input_data, k,
&zero, v_data, n, device_prop));
&zero, v_data, n, device_prop, UseTF32()));
}

// Wait for async copy of batch_global_num
@@ -195,7 +195,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(global_weights->Data<T>()), n,
input_data, k,
&zero, global_gemm_buffer, n, device_prop));
&zero, global_gemm_buffer, n, device_prop, UseTF32()));
} else {
// global q
const CudaT* global_q_weight = global_weights_data;
@@ -205,7 +205,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_q_weight, n,
input_data, k,
&zero, global_q, n, device_prop));
&zero, global_q, n, device_prop, UseTF32()));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas,
@@ -226,7 +226,8 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
hidden_size, // ldc
static_cast<int64_t>(max_num_global) * hidden_size, // strideC
batch_size, // batch count
device_prop));
device_prop,
UseTF32()));
}
// global k
const CudaT* global_k_weight = global_weights_data + static_cast<int64_t>(hidden_size) * hidden_size;
Expand All @@ -235,7 +236,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_k_weight, n,
input_data, k,
&zero, global_k, n, device_prop));
&zero, global_k, n, device_prop, UseTF32()));

// global v
const CudaT* global_v_weight = global_k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
Expand All @@ -244,7 +245,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_v_weight, n,
input_data, k,
&zero, global_v, n, device_prop));
&zero, global_v, n, device_prop, UseTF32()));
}
}

3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -268,6 +268,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* relative_position_bias = context->Input<Tensor>(5);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@@ -313,7 +314,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool no_qkv_workspace = false; // need workspace to add bias
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention(
q, qk_head_size, sequence_length * qk_head_size,
&zero,
scaled_qk, sequence_length, sequence_length * sequence_length,
batches, device_prop));
batches, device_prop, parameters.use_tf32));

DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length);

@@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention(
v_head_size, sequence_length, sequence_length,
&one, v, v_head_size, sequence_length * v_head_size,
attention_score, sequence_length, sequence_length * sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v
Status result = LaunchTransposeRemovePadding(
@@ -228,6 +228,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
const Tensor* relative_position_bias = context->Input<Tensor>(6);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(),
key,
value,
@@ -775,7 +775,7 @@ Status UnfusedAttention(
q, qk_head_size, sequence_length * qk_head_size,
&zero,
scaled_qk, sequence_length, sequence_length * sequence_length,
batches, device_prop));
batches, device_prop, parameters.use_tf32));

// Q, K and V are ready now
DUMP_TENSOR_INIT();
@@ -808,7 +808,7 @@ Status UnfusedAttention(
v_head_size, sequence_length, sequence_length,
&one, v, v_head_size, sequence_length * v_head_size,
attention_score, sequence_length, sequence_length * sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v
Status result = LaunchTransposeRemovePadding(
(The remaining changed files in this commit are not shown here.)
