diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 1370f5c4c5e10..108173474db46 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext { bool cudnn_conv1d_pad_to_nc1d = false; bool enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; + bool use_tf32 = true; void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -52,6 +53,7 @@ struct CudaContext : public CustomOpContext { cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); } template diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 82bb8ba83be4a..6d53760ab60b5 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 { // The strict mode has better accuracy but lower performance. int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not + int use_tf32 = 1; // use TF32 }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index c0e6328f27122..1fef077860be3 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -18,4 +18,5 @@ enum CudaResource : int { cudnn_conv1d_pad_to_nc1d_t, enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, + use_tf32_t, }; \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 8afeb874750b4..a34f41d2938c6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -64,6 +64,7 @@ struct AttentionParameters { bool pass_past_in_kv; float mask_filter_value; float scale; + bool use_tf32; AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -82,6 +83,7 @@ struct PackedAttentionParameters { int token_count; bool has_relative_position_bias; bool broadcast_res_pos_bias; + bool use_tf32; }; // Parameters deduced from node attributes and inputs/outputs. 
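Note on the option plumbed through above: `use_tf32` is added to `OrtCUDAProviderOptionsV2` with a default of 1 and exposed to custom ops via `CudaResource::use_tf32_t`. A minimal sketch of how an application could turn it off through the generic CUDA provider-options API follows; the helper name and the surrounding session setup are illustrative assumptions, not part of this change, and error handling is abbreviated.

#include <onnxruntime_cxx_api.h>

// Sketch: build session options with TF32 disabled for the CUDA EP.
// Assumes the standard ORT C/C++ API headers; the function name is illustrative.
Ort::SessionOptions MakeSessionOptionsWithoutTF32() {
  Ort::SessionOptions session_options;

  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));

  // "use_tf32" is the provider option key added in this change; "0" disables TF32,
  // "1" (the default) allows TF32 for float GEMM/MatMul and cuDNN convolution.
  const char* keys[] = {"use_tf32"};
  const char* values[] = {"0"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::ThrowOnError(Ort::GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options));
  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);
  return session_options;
}

The same key/value pair can be passed from Python in the provider options dictionary of CUDAExecutionProvider, which is what the Python test at the end of this diff exercises.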
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bf6431cf1afb2..7a807342ad685 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -84,6 +84,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + // Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr std::vector bias_dims{weights->Shape().GetDims()[1]}; const TensorShape bias_shape{bias_dims}; @@ -251,7 +253,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 54c9a5da1e9da..c20f42c4d06bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -461,7 +461,8 @@ Status UnfusedAttention( total_sequence_length, sequence_length, qk_head_size, &alpha, data.k, qk_head_size, present_size_per_batch_k, data.q, qk_head_size, sequence_length * qk_head_size, - &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, + device_prop, parameters.use_tf32)); DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); @@ -514,7 +515,7 @@ Status UnfusedAttention( v_head_size, sequence_length, total_sequence_length, &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 3f703ae3d05e6..ceee17c2a2d01 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -273,13 +273,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data()), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (h2, h1)*(h1, S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(q_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, 
UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) // calcualte k, v @@ -298,13 +298,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, @@ -318,13 +318,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(key->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } } else { @@ -342,13 +342,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { kv_sequence_length = cache_sequence_length; @@ -372,6 +372,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { device_prop, #ifdef USE_ROCM GetTuningContext(), +#else + UseTF32(), #endif context->GetComputeStream(), cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index 1dc22a9c8ea98..e24d9da94c964 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -37,7 +37,8 @@ Status DecoderQkvToContext( T* workspace_buffer, T* output, T* new_key_cache, - T* new_value_cache) { + T* new_value_cache, + bool use_tf32) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -128,14 +129,14 @@ Status DecoderQkvToContext( kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, q, head_size, strideB, - &zero, 
scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } constexpr bool is_unidirectional = false; @@ -163,14 +164,14 @@ Status DecoderQkvToContext( head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } // scratch3 is BxNxSxH, transpose to output SxBxNxH @@ -180,6 +181,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const cudaDeviceProp& device_prop, + bool use_tf32, Stream* stream, cublasHandle_t& cublas, const size_t element_size, @@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } else { return DecoderQkvToContext( device_prop, @@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h index 9db9ccb45e330..f9667a613e648 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -11,6 +11,7 @@ namespace cuda { Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties + bool use_tf32, // Use TF32 Stream* stream, // ORT Stream cublasHandle_t& cublas, // Cublas handle const size_t element_size, // Element size of input tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 72ede2e22b557..07a6fbd60e171 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index e556ae4a490e9..9c5d0e9834f6f 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -136,7 +136,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, weights_data, n, input_data, k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); } else { // q const CudaT* q_weight = weights_data; @@ -145,7 +145,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, q_weight, n, input_data, k, - &zero, q_data, n, device_prop)); + &zero, q_data, n, device_prop, UseTF32())); // k const CudaT* k_weight = q_weight + static_cast(hidden_size) * hidden_size; CudaT* k_data = q_data + static_cast(batch_size) * sequence_length * hidden_size; @@ -153,7 +153,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, k_weight, n, input_data, k, - &zero, k_data, n, device_prop)); + &zero, k_data, n, device_prop, UseTF32())); // v const CudaT* v_weight = k_weight + static_cast(hidden_size) * hidden_size; @@ -162,7 +162,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, v_weight, n, input_data, k, - &zero, v_data, n, device_prop)); + &zero, v_data, n, device_prop, UseTF32())); } // Wait for async copy of batch_global_num @@ -195,7 +195,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(global_weights->Data()), n, input_data, k, - &zero, global_gemm_buffer, n, device_prop)); + &zero, global_gemm_buffer, n, device_prop, UseTF32())); } else { // global q const CudaT* global_q_weight = global_weights_data; @@ -205,7 +205,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_q_weight, n, input_data, k, - &zero, global_q, n, device_prop)); + &zero, global_q, n, device_prop, UseTF32())); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -226,7 +226,8 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { hidden_size, // ldc static_cast(max_num_global) * hidden_size, // strideC batch_size, // batch count - device_prop)); + device_prop, + UseTF32())); } // global k const CudaT* global_k_weight = global_weights_data + static_cast(hidden_size) * hidden_size; @@ -235,7 +236,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_k_weight, n, input_data, k, - &zero, global_k, n, device_prop)); + &zero, global_k, n, device_prop, UseTF32())); // global v const CudaT* global_v_weight = global_k_weight + static_cast(hidden_size) * hidden_size; @@ -244,7 +245,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_v_weight, n, input_data, k, - &zero, global_v, n, device_prop)); + &zero, global_v, n, device_prop, UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f978f50c6851f..2ef011cdd9a21 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ 
b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index ec8b1d051b3d9..55deed55dfd33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -268,6 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -308,12 +309,12 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = this->GetCublasHandle(context); // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias - // The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel. + // The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel. CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 3b52320839403..ce7ac3796dbe1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); @@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 1b026e64778e3..b4a162989978c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -228,6 +228,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); 
PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 83af018a97ea6..49029da12a308 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -775,7 +775,7 @@ Status UnfusedAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); // Q, K and V are ready now DUMP_TENSOR_INIT(); @@ -808,7 +808,7 @@ Status UnfusedAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 92ba808dd85c2..05f55d9106d0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -200,7 +200,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c D, BNS, head_size, &one, reinterpret_cast(weight_tensor.template Data()), (int)D, reinterpret_cast(workspace.get()), (int)head_size, - &zero, gemm_output, ld_gemm_output, device_prop)); + &zero, gemm_output, ld_gemm_output, device_prop, UseTF32())); auto status = LaunchGatedRelativePositionBiasKernel( device_prop, stream, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 705f2d49fe2bf..001b6070d5e1a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -106,6 +106,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_tensor = context->Input(8); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bbcb7de99781f..0534ed6dc7fc0 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -117,7 +117,8 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 5b0e61e197014..015df70c8ec3c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -135,7 +135,8 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { &zero, 
reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index d0bb2321edf0a..55f0b5570e0ee 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -78,6 +78,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool UseTF32() const { return info_.use_tf32; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 81ddc38820914..c96381e3e68b1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -33,6 +33,7 @@ constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_dur constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kUseTF32 = "use_tf32"; } // namespace provider_option_names } // namespace cuda @@ -115,6 +116,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -167,6 +169,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; @@ -188,6 +191,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 04eea2f6c8e94..1cac3d1513698 
100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -76,6 +76,9 @@ struct CUDAExecutionProviderInfo { bool use_ep_level_unified_stream{false}; + // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. + bool use_tf32{true}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -100,7 +103,8 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { (static_cast(info.cudnn_conv1d_pad_to_nc1d) << 26) ^ (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ (static_cast(info.prefer_nhwc) << 28) ^ - (static_cast(info.use_ep_level_unified_stream) << 29); + (static_cast(info.use_ep_level_unified_stream) << 29) ^ + (static_cast(info.use_tf32) << 30); onnxruntime::HashCombine(data, value); onnxruntime::HashCombine(info.gpu_mem_limit, value); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index e3106e41e77c8..288da23f35ec8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -90,6 +90,10 @@ class CudaKernel : public OpKernel { return stream->cublas_handle_; } + bool UseTF32() const { + return provider_->UseTF32(); + } + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 892e8d5329eba..103c79c93b2ca 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -225,6 +225,7 @@ struct CUDA_Provider : Provider { info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; + info.use_tf32 = params->use_tf32 != 0; return std::make_shared(info); } @@ -258,6 +259,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; + cuda_options.use_tf32 = internal_options.use_tf32; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 0a256394b7d99..3c0bf183362dd 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -212,6 +212,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::use_tf32_t: + return reinterpret_cast(ep_info_.use_tf32); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index 3e50116eafd17..ee0334e552022 100644 --- 
a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -51,25 +51,27 @@ Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, CudaT one = cuda::ToCudaType::FromFloat(1.0f); CudaT zero = cuda::ToCudaType::FromFloat(0.0f); - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(static_cast(einsum_cuda_assets)->cublas_handle_, - CUBLAS_OP_N, - CUBLAS_OP_N, - static_cast(N), - static_cast(M), - static_cast(K), - &one, - reinterpret_cast(input_2_data), - static_cast(N), - static_cast(right_stride), - reinterpret_cast(input_1_data), - static_cast(K), - static_cast(left_stride), - &zero, - reinterpret_cast(output_data), - static_cast(N), - static_cast(output_stride), - static_cast(num_batches), - static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp())); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + static_cast(einsum_cuda_assets)->cublas_handle_, + CUBLAS_OP_N, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &one, + reinterpret_cast(input_2_data), + static_cast(N), + static_cast(right_stride), + reinterpret_cast(input_1_data), + static_cast(K), + static_cast(left_stride), + &zero, + reinterpret_cast(output_data), + static_cast(N), + static_cast(output_stride), + static_cast(num_batches), + static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp(), + static_cast(einsum_cuda_assets)->cuda_ep_->UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 8fe23c9a036cc..4e61e0c8c69c6 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -118,7 +118,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const b_data, N, GetConstOnes(M, Stream(ctx)), 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( @@ -130,7 +130,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const GetConstOnes(N, Stream(ctx)), N, b_data, 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else { // B is (M, N), no broadcast needed. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, static_cast(M) * N * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); @@ -153,7 +153,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const // ideally we need to set the output buffer contents to 0 if bias is missing, // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer B != nullptr ? 
&beta : &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index e4c37c52a1780..6e126fbeadce8 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -173,7 +173,8 @@ Status FuncMatMul( &cuda_zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { @@ -195,7 +196,8 @@ Status FuncMatMul( ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } @@ -213,12 +215,12 @@ Status FuncMatMul( ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + cuda_kernel->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -238,7 +240,8 @@ Status FuncMatMul( Y_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } @@ -321,7 +324,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help &zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(left_X->Shape(), right_X->Shape(), transa, transb, trans_batch_a_, trans_batch_b_, stride_A, stride_B, stride_C, batch_count)) { @@ -343,7 +347,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } @@ -361,12 +366,12 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? 
CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + this->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -386,7 +391,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help output_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 510cc5cfbb7dd..053c66ddcb34a 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -29,13 +29,15 @@ cublasGemmHelper(cublasHandle_t handle, const float* B, int ldb, const float* beta, float* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - // TF32 uses 10 bit mantissa which has sufficient margin of precision for most use cases. It gets 8x throughput than FP32 in A100. - // It can be overrided by setting environment variable NVIDIA_TF32_OVERRIDE = 0 to disable TF32 - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); + // To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 or set provider option use_tf32 = 0 + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemm(handle, @@ -58,7 +60,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const double* B, int ldb, const double* beta, double* C, int ldc, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemm(handle, transa, transb, @@ -79,7 +82,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const half* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -121,7 +125,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -155,10 +160,11 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const BFloat16* alpha, const 
BFloat16* A, int lda, + const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -169,7 +175,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&) { + BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -185,7 +191,17 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const float* beta, float* Carray[], int ldc, int batch_count, - const cudaDeviceProp&) { + const cudaDeviceProp& prop, + bool use_tf32) { +// The caller shall check memory alignments of the matrices when use_tf32 is true. +#if defined(USE_CUDA) + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); +#else + ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); +#endif + return cublasSgemmBatched(handle, transa, transb, @@ -208,7 +224,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const double* beta, double* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmBatched(handle, transa, transb, @@ -231,7 +248,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const half* beta, half* Carray[], int ldc, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -266,11 +284,12 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], + int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, + BFloat16* Carray[], int ldc, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -282,7 +301,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOpera #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&) { + const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -301,15 +321,14 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, float* C, int ldc, long long int strideC, int batch_count, - 
const cudaDeviceProp& prop) { -#ifdef ENABLE_TRAINING_OPS + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); -#else - ORT_UNUSED_PARAMETER(prop); -#endif + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemmStridedBatched(handle, @@ -337,7 +356,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, double* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmStridedBatched(handle, transa, transb, @@ -363,7 +383,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -411,7 +432,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -447,49 +469,66 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const BFloat16* alpha, const BFloat16* A, int lda, + long long int strideA, const BFloat16* B, int ldb, + long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, + long long int strideC, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 - return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, - ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); + return cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, + ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT); } #else -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&) { +inline 
cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + int, const BFloat16*, const BFloat16*, int, long long int, + const BFloat16*, int, long long int, const BFloat16*, BFloat16*, + int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif // transpose using geam -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, + float* C, int ldc) { return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, + double* C, int ldc) { return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } bool CanUse_cublasTransposeHelper_MLFloat16(int m, int n); -cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); + +cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, + int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); // copy -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { return cublasScopy(handle, n, x, incx, y, incy); } -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { return cublasDcopy(handle, n, x, incx, y, incy); } -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); + +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index c0b7d4722d3e4..70bf08d65401a 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -101,6 +101,10 @@ class 
RocmKernel : public OpKernel { return static_cast(provider_->GetTuningContext()); } + bool UseTF32() const { + return false; + } + // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished template diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 7cbc37cb64c5a..d93f70785c093 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -115,7 +115,8 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -154,7 +155,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with extra arguments (see cublasGemmHelper) template rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_operation transa, @@ -165,7 +166,8 @@ rocblas_status rocblasGemmHelper(rocblas_handle handle, const Scalar* B, int ldb, const Scalar* beta, Scalar* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -404,7 +406,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with extra arguments (see cublasGemmStridedBatchedHelper) template rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_operation transa, @@ -419,7 +421,8 @@ rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, Scalar* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, @@ -445,7 +448,8 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, __half* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 32ae15e71acc6..bb8732784945d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1555,6 +1555,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; cuda_options_converted.use_ep_level_unified_stream = 0; + cuda_options_converted.use_tf32 = 1; return cuda_options_converted; } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu index fd9e9c4fd1612..8b05b96ec38a9 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu @@ -56,6 +56,9 @@ class GemmBenchmark : public IKernelExplorer { typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = 
ToCudaType::FromFloat(0.0f); + + // TF32 is enabled by default. To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 + constexpr bool use_tf32 = true; CUBLAS_CALL_THROW(cublasGemmHelper( params_.cublas_handle, CUBLAS_OP_N, @@ -69,7 +72,8 @@ class GemmBenchmark : public IKernelExplorer { &zero, params_.output_, params_.n_, - device_prop_)); + device_prop_, + use_tf32)); } private: @@ -79,11 +83,11 @@ class GemmBenchmark : public IKernelExplorer { cudaDeviceProp device_prop_; }; -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ .def("Run", &name::Run); KE_REGISTER(m) { diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 09baf8def05f6..31ef62e69bb88 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,7 +433,8 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - float gpu_threshold = is_float16 ? 0.15f : 0.005f; + // TF32 in SM >= 80 is enabled by default, so a larger threshold is needed for float when TF32 is enabled. + float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f); gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 5b41806b646af..91b6c71e735a8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -428,6 +428,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0"
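Inside cublasGemmHelper and its batched variants, the new flag only selects the cuBLAS math mode before the float GEMM call (wrapped in the RAII CublasMathModeSetter). The same toggle in isolation, as a sketch assuming CUDA 11+ and a compute-capability 8.0+ device; the function name and the explicit save/restore are illustrative, not the helper's exact implementation.

#include <cublas_v2.h>

// Sketch: run an FP32 GEMM with TF32 either allowed or kept at strict FP32.
// TF32 only takes effect on compute capability >= 8.0 (Ampere and newer);
// NVIDIA_TF32_OVERRIDE=0 still disables it globally regardless of this flag.
cublasStatus_t SgemmMaybeTF32(cublasHandle_t handle, bool use_tf32,
                              int m, int n, int k,
                              const float* A, int lda,
                              const float* B, int ldb,
                              float* C, int ldc) {
  const float alpha = 1.0f;
  const float beta = 0.0f;

  cublasMath_t saved_mode = CUBLAS_DEFAULT_MATH;
  cublasGetMathMode(handle, &saved_mode);  // remember the current mode

  cublasSetMathMode(handle, use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH);
  cublasStatus_t status = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                      m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc);

  cublasSetMathMode(handle, saved_mode);  // restore, mirroring the RAII setter
  return status;
}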