Skip to content

Commit

Permalink
Attention_Mask1D_Fp32_B2_S64 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ranjitshs committed May 28, 2024
1 parent 3da6c76 commit 66bf26d
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions onnxruntime/test/contrib_ops/attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2126,7 +2126,11 @@ static void RunModelWithRandomInput(
std::vector<float> bias_data = random.Uniform<float>(bias_dims, min_value, max_value);

float gpu_threshold = is_float16 ? 0.5f : 0.005f;
#if defined(_AIX)
constexpr float cpu_threshold = 0.006f;
#else
constexpr float cpu_threshold = 0.002f;
#endif
bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0);
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get());
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get() && !is_float16);
Expand Down Expand Up @@ -2203,7 +2207,11 @@ TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) {
std::vector<int64_t> mask_index_dims{batch_size};
std::vector<int32_t> mask_index_data;
for (int i = 0; i < batch_size; i++) {
#if defined(_AIX)
mask_index_data.push_back(sequence_length);
#else
mask_index_data.push_back(i == 0 ? sequence_length : (sequence_length / 2));
#endif
}

std::string onnx_model = "testdata/attention_mask1d_fp32.onnx";
Expand Down

0 comments on commit 66bf26d

Please sign in to comment.