
Commit

Update MaskFilterValue and test values for u8s8
raoanag committed Mar 1, 2024
1 parent f8eecc8 commit 93c9966
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
@@ -416,7 +416,9 @@ class DmlOperatorQAttention : public DmlOperator
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
- mhaOperatorDesc.MaskFilterValue = std::numeric_limits<float>::lowest();
+ // Set MaskFilterValue to lowest float for Causal Mask

[cpplint, GitHub Actions / Lint C++, reported by reviewdog] Warning on line 419 of onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp: Line ends in whitespace. Consider deleting these extra spaces. [whitespace/end_of_line] [4]
+ mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits<float>::lowest() :
+                                    kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
mhaOperatorDesc.HeadCount = numHeads;
mhaOperatorDesc.MaskType = maskType;
if (hasPast)
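The change above pins MaskFilterValue to the lowest representable float when the attention is unidirectional (causal), and otherwise keeps reading the optional MaskFilterValue attribute with its usual -10000.0f default. As a minimal sketch of what a mask filter value does, the snippet below adds it to the raw scores of masked-out key positions before softmax; the MaskedSoftmax function and the sample scores are illustrative, not part of the DML operator, and the rationale for lowest() on the causal path is my inference rather than something stated in the commit.

```cpp
// Illustrative sketch only -- not the DML operator's implementation.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Add a mask filter value to the raw attention scores of masked-out key
// positions (here: a causal mask for one query at `query_pos`), then softmax.
std::vector<float> MaskedSoftmax(std::vector<float> scores, size_t query_pos, float mask_filter_value)
{
    for (size_t k = 0; k < scores.size(); ++k)
    {
        if (k > query_pos)  // causal: a query must not attend to future keys
        {
            scores[k] += mask_filter_value;
        }
    }

    const float max_score = *std::max_element(scores.begin(), scores.end());
    float sum = 0.0f;
    for (float& s : scores) { s = std::exp(s - max_score); sum += s; }
    for (float& s : scores) s /= sum;
    return scores;
}

int main()
{
    const std::vector<float> scores{1.0f, 2.0f, 3.0f, 4.0f};

    // lowest() suppresses future positions no matter how large the raw scores are,
    // which is presumably why the causal path pins MaskFilterValue to it.
    for (float p : MaskedSoftmax(scores, 1, std::numeric_limits<float>::lowest())) std::printf("%g ", p);
    std::printf("\n");

    // The attribute default (-10000.0f) also drives masked probabilities to ~0
    // for typical score magnitudes, but an unusually large raw score could leak through.
    for (float p : MaskedSoftmax(scores, 1, -10000.0f)) std::printf("%g ", p);
    std::printf("\n");
    return 0;
}
```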
4 changes: 2 additions & 2 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -911,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);

- constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
- constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
+ constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();

[cpplint, GitHub Actions / Lint C++, reported by reviewdog] Warning on line 914 of onnxruntime/test/contrib_ops/quantize_attention_op_test.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
+ constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
constexpr int32_t weight_range = weight_max - weight_min;

std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
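The test change narrows the generated weight range: int8 weights (the u8s8 path) now span half the signed range, and every weight type gets max() / 2 as the upper bound. The commit does not state why, but a likely reason is keeping intermediate u8*s8 products and accumulations in the quantized GEMM away from saturation. The compile-time sketch below only makes the resulting per-type ranges concrete; the WeightRange helper is hypothetical and not part of the test file.

```cpp
// Illustrative sketch only -- mirrors the constexpr initializers in the diff above.
#include <cstdint>
#include <limits>
#include <type_traits>
#include <utility>

// Hypothetical helper returning the [min, max] range the test draws weights from.
// For int8_t weights both ends of the range are halved; uint8_t weights keep
// min() = 0 and get max() / 2 = 127.
template <typename WeightT>
constexpr std::pair<int32_t, int32_t> WeightRange()
{
    constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t>
                                       ? std::numeric_limits<int8_t>::min() / 2
                                       : std::numeric_limits<WeightT>::min();
    constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
    return {weight_min, weight_max};
}

static_assert(WeightRange<int8_t>() == std::pair<int32_t, int32_t>(-64, 63));
static_assert(WeightRange<uint8_t>() == std::pair<int32_t, int32_t>(0, 127));
```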
