[CUDA] Add use_tf32 provider option (for FP32 GEMM) #19357

Merged: 13 commits, Feb 6, 2024
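
This PR threads a new `use_tf32` option (default: enabled) from the CUDA execution provider down to the cuBLAS GEMM helpers, so FP32 GEMMs can be forced to run in full FP32 precision instead of TF32 on Ampere and newer GPUs. As a rough illustration of the user-facing side, the sketch below disables the option through the V2 CUDA provider-options C API; the string key `"use_tf32"` is assumed to match the new struct field, and the model path is a placeholder.

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  const OrtApi& api = Ort::GetApi();

  // Create CUDA provider options and flip the new flag off ("1" is the default).
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"use_tf32"};   // assumed to mirror the new struct field name
  const char* values[] = {"0"};        // "0" = strict FP32 GEMM, "1" = allow TF32
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  // Attach the CUDA EP to the session options and build a session as usual.
  Ort::SessionOptions session_options;
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
      session_options, cuda_options));

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "use_tf32_demo");
  Ort::Session session(env, "model.onnx", session_options);  // placeholder model path

  api.ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}
```

Leaving the key unset (or setting it to "1") keeps the default behavior of allowing TF32 for FP32 GEMMs.
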
2 changes: 2 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext {
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;
bool use_tf32 = true;

void Init(const OrtKernelContext& kernel_ctx) {
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
@@ -52,6 +53,7 @@
cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
use_tf32 = FetchResource<bool>(kernel_ctx, CudaResource::use_tf32_t);
}

template <typename T>
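
Custom-op authors pick up the flag through `CudaContext`, which fetches it exactly like the other CUDA resources in the header above. A minimal sketch, assuming the `Ort::Custom` namespace and include path of the surrounding header and a hypothetical kernel:

```cpp
#include "onnxruntime/core/providers/cuda/cuda_context.h"  // path as in this PR's include tree

// Hypothetical custom-op kernel; only the CudaContext plumbing is taken from the header above.
struct MyCudaKernel {
  void Compute(OrtKernelContext* kernel_ctx) {
    Ort::Custom::CudaContext cuda_ctx;
    cuda_ctx.Init(*kernel_ctx);  // fetches cuda_stream, handles, prefer_nhwc, use_tf32, ...

    if (cuda_ctx.use_tf32) {
      // FP32 GEMMs may use TF32 tensor cores (faster, reduced mantissa precision).
    } else {
      // Caller requested strict FP32 math; avoid TF32 code paths.
    }
  }
};
```
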
include/onnxruntime/core/providers/cuda/cuda_provider_options.h
@@ -37,4 +37,5 @@
// The strict mode has better accuracy but lower performance.
int prefer_nhwc = 0; // make the CUDA EP NHWC preferred
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
int use_tf32 = 1; // use TF32

};
1 change: 1 addition & 0 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -18,4 +18,5 @@ enum CudaResource : int {
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -64,6 +64,7 @@ struct AttentionParameters {
bool pass_past_in_kv;
float mask_filter_value;
float scale;
bool use_tf32;
AttentionMaskType mask_type;
AttentionQkvFormat qkv_format;
};
@@ -82,6 +83,7 @@ struct PackedAttentionParameters {
int token_count;
bool has_relative_position_bias;
bool broadcast_res_pos_bias;
bool use_tf32;
};

// Parameters deduced from node attributes and inputs/outputs.
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -84,6 +84,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
parameters.use_tf32 = UseTF32();

// Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr
std::vector<int64_t> bias_dims{weights->Shape().GetDims()[1]};
const TensorShape bias_shape{bias_dims};
@@ -251,7 +253,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
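
Most of the remaining changes follow this pattern: `cublasGemmHelper` and its strided-batched sibling gain a trailing `use_tf32` argument. The helpers themselves are defined elsewhere in the repo and are not part of this excerpt; the sketch below only illustrates, under that assumption, how such a flag is typically honored for float GEMMs in cuBLAS: choose the TF32 fast path only when the flag is set and the GPU is Ampere (SM 8.0) or newer.

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>

// Assumed sketch of a TF32-aware float GEMM wrapper; not the repo's cublasGemmHelper.
inline cublasStatus_t GemmFloatSketch(cublasHandle_t handle,
                                      cublasOperation_t transa, cublasOperation_t transb,
                                      int m, int n, int k,
                                      const float* alpha,
                                      const float* A, int lda,
                                      const float* B, int ldb,
                                      const float* beta,
                                      float* C, int ldc,
                                      const cudaDeviceProp& prop, bool use_tf32) {
  // TF32 tensor cores exist only on compute capability 8.0+; otherwise use plain FP32.
  const bool tf32_ok = use_tf32 && prop.major >= 8;
  const cublasComputeType_t compute_type =
      tf32_ok ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;

  return cublasGemmEx(handle, transa, transb, m, n, k,
                      alpha, A, CUDA_R_32F, lda,
                      B, CUDA_R_32F, ldb,
                      beta, C, CUDA_R_32F, ldc,
                      compute_type, CUBLAS_GEMM_DEFAULT);
}
```

An alternative is setting `cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH)` on the handle, but a per-call compute type keeps the choice local to each GEMM.
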
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -461,7 +461,8 @@ Status UnfusedAttention(
total_sequence_length, sequence_length, qk_head_size,
&alpha, data.k, qk_head_size, present_size_per_batch_k,
data.q, qk_head_size, sequence_length * qk_head_size,
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches,
device_prop, parameters.use_tf32));

DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length);
@@ -514,7 +515,7 @@ Status UnfusedAttention(
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
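
The strided-batched calls above now carry the same flag as their final argument. Continuing the assumed sketch from the previous section (again, not the repo's actual helper), the flag maps onto the compute type of `cublasGemmStridedBatchedEx`:

```cpp
#include <cublas_v2.h>

// Assumed sketch of the strided-batched counterpart; not the repo's helper.
inline cublasStatus_t StridedBatchedGemmFloatSketch(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const float* alpha,
    const float* A, int lda, long long strideA,
    const float* B, int ldb, long long strideB,
    const float* beta, float* C, int ldc, long long strideC,
    int batch_count, bool use_tf32) {
  // Same switch as the single-GEMM sketch; a compute-capability check would apply here too.
  const cublasComputeType_t compute_type =
      use_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;

  return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k,
                                    alpha, A, CUDA_R_32F, lda, strideA,
                                    B, CUDA_R_32F, ldb, strideB,
                                    beta, C, CUDA_R_32F, ldc, strideC,
                                    batch_count, compute_type, CUBLAS_GEMM_DEFAULT);
}
```
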
18 changes: 10 additions & 8 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc
@@ -273,16 +273,16 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>()), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (h2, h1)*(h1, S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(q_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_query_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_query_buffer in col-base: (h2, S*B)

// calculate k, v
n = 2 * hidden_size;
k = hidden_size;
if (!has_layer_state_ || !use_past_) {
@@ -298,13 +298,13 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
gemm_kv_buffer_p = GetScratchBuffer<T>(static_cast<size_t>(batch_size) * 2 * key_sequence_length * hidden_size,
@@ -318,13 +318,13 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(key->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
}
} else {
@@ -342,13 +342,13 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
reinterpret_cast<const CudaT*>(bias->Data<T>() + hidden_size), n,
GetConstOnes<CudaT>(m, Stream(context)), 1,
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// matmul: (2*h2, h1)*(h1, T_S*B)
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(kv_weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(query->Data<T>()), k,
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop));
&one, reinterpret_cast<CudaT*>(gemm_kv_buffer_p.get()), n, device_prop, UseTF32()));
// gemm_kv_buffer in col-base: (2*h2, T_S*B)
} else {
kv_sequence_length = cache_sequence_length;
@@ -372,6 +372,8 @@
device_prop,
#ifdef USE_ROCM
GetTuningContext(),
#else
UseTF32(),
#endif
context->GetComputeStream(),
cublas,
18 changes: 11 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu
@@ -37,7 +37,8 @@ Status DecoderQkvToContext(
T* workspace_buffer,
T* output,
T* new_key_cache,
T* new_value_cache) {
T* new_value_cache,
bool use_tf32) {
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int BN = batch_size * num_heads;
const int BHN = BN * head_size;
@@ -128,14 +129,14 @@
kv_sequence_length, sequence_length, head_size,
&alpha, key_cache, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
kv_sequence_length, sequence_length, head_size,
&alpha, k, head_size, strideA,
q, head_size, strideB,
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop));
&zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32));
}

constexpr bool is_unidirectional = false;
@@ -163,14 +164,14 @@
head_size, sequence_length, kv_sequence_length,
&one, value_cache, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, kv_sequence_length,
&one, v, head_size, strideA,
scratch2, kv_sequence_length, temp_matrix_size,
&zero, scratch3, head_size, strideB, BN, device_prop));
&zero, scratch3, head_size, strideB, BN, device_prop, use_tf32));
}

// scratch3 is BxNxSxH, transpose to output SxBxNxH
@@ -180,6 +181,7 @@

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& device_prop,
bool use_tf32,
Stream* stream,
cublasHandle_t& cublas,
const size_t element_size,
@@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel(
reinterpret_cast<half*>(workspace_buffer),
reinterpret_cast<half*>(output),
reinterpret_cast<half*>(new_key_cache),
reinterpret_cast<half*>(new_value_cache));
reinterpret_cast<half*>(new_value_cache),
use_tf32);
} else {
return DecoderQkvToContext(
device_prop,
@@ -254,7 +257,8 @@
reinterpret_cast<float*>(workspace_buffer),
reinterpret_cast<float*>(output),
reinterpret_cast<float*>(new_key_cache),
reinterpret_cast<float*>(new_value_cache));
reinterpret_cast<float*>(new_value_cache),
use_tf32);
}
}

1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h
@@ -11,6 +11,7 @@ namespace cuda {

Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
bool use_tf32, // Use TF32
Stream* stream, // ORT Stream
cublasHandle_t& cublas, // Cublas handle
const size_t element_size, // Element size of input tensor
@@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T1>()), n,
reinterpret_cast<const CudaT*>(input->Data<T1>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

// Update the q, k, and v buffers
parameters.q = gemm_buffer.get();
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
@@ -136,7 +136,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
weights_data, n,
input_data, k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));
} else {
// q
const CudaT* q_weight = weights_data;
Expand All @@ -145,15 +145,15 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
q_weight, n,
input_data, k,
&zero, q_data, n, device_prop));
&zero, q_data, n, device_prop, UseTF32()));
// k
const CudaT* k_weight = q_weight + static_cast<int64_t>(hidden_size) * hidden_size;
CudaT* k_data = q_data + static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
k_weight, n,
input_data, k,
&zero, k_data, n, device_prop));
&zero, k_data, n, device_prop, UseTF32()));

// v
const CudaT* v_weight = k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
Expand All @@ -162,7 +162,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
v_weight, n,
input_data, k,
&zero, v_data, n, device_prop));
&zero, v_data, n, device_prop, UseTF32()));
}

// Wait for async copy of batch_global_num
@@ -195,7 +195,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(global_weights->Data<T>()), n,
input_data, k,
&zero, global_gemm_buffer, n, device_prop));
&zero, global_gemm_buffer, n, device_prop, UseTF32()));
} else {
// global q
const CudaT* global_q_weight = global_weights_data;
@@ -205,7 +205,7 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_q_weight, n,
input_data, k,
&zero, global_q, n, device_prop));
&zero, global_q, n, device_prop, UseTF32()));
} else {
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas,
@@ -226,7 +226,8 @@
hidden_size, // ldc
static_cast<int64_t>(max_num_global) * hidden_size, // strideC
batch_size, // batch count
device_prop));
device_prop,
UseTF32()));
}
// global k
const CudaT* global_k_weight = global_weights_data + static_cast<int64_t>(hidden_size) * hidden_size;
Expand All @@ -235,7 +236,7 @@ Status LongformerAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_k_weight, n,
input_data, k,
&zero, global_k, n, device_prop));
&zero, global_k, n, device_prop, UseTF32()));

// global v
const CudaT* global_v_weight = global_k_weight + static_cast<int64_t>(hidden_size) * hidden_size;
@@ -244,7 +245,7 @@
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
global_v_weight, n,
input_data, k,
&zero, global_v, n, device_prop));
&zero, global_v, n, device_prop, UseTF32()));
}
}

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -94,6 +94,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {

auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
parameters.use_tf32 = UseTF32();

ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -268,6 +268,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* relative_position_bias = context->Input<Tensor>(5);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@@ -308,12 +309,12 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
cublasHandle_t cublas = this->GetCublasHandle(context);

// Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias
// The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel.
// The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel.
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
reinterpret_cast<const CudaT*>(weights->Data<T>()), n,
reinterpret_cast<const CudaT*>(input->Data<T>()), k,
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop, UseTF32()));

constexpr size_t element_size = sizeof(T);
constexpr bool no_qkv_workspace = false; // need workspace to add bias
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention(
q, qk_head_size, sequence_length * qk_head_size,
&zero,
scaled_qk, sequence_length, sequence_length * sequence_length,
batches, device_prop));
batches, device_prop, parameters.use_tf32));

DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length);

@@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention(
v_head_size, sequence_length, sequence_length,
&one, v, v_head_size, sequence_length * v_head_size,
attention_score, sequence_length, sequence_length * sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop));
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v
Status result = LaunchTransposeRemovePadding(
@@ -228,6 +228,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
const Tensor* relative_position_bias = context->Input<Tensor>(6);

PackedAttentionParameters parameters;
parameters.use_tf32 = UseTF32();
ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(),
key,
value,