[CPU] Fix MatMulNBits accuracy level (microsoft#22963)
### Description
Fix the MatMulNBits accuracy level.



fajin-corp authored and ankitm3k committed Dec 11, 2024
1 parent 14950cc commit 3828c33
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -33,6 +33,7 @@ constexpr size_t A = 0,
 };

 typedef enum {
+  Level0, /*!< input fp32, accumulator fp32 */
   Level1, /*!< input fp32, accumulator fp32 */
   Level2, /*!< input fp16, accumulator fp16 */
   Level3, /*!< input bf16, accumulator fp32 */
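
The hunk above gives the enum a `Level0` entry, so each enumerator's value matches the number in its name. A minimal sketch of why that matters when a numeric accuracy level is compared against the enum; `AccuracyLevelSketch` and `IsLevel4` are hypothetical names, and `Level4` is assumed because the hunk is cut off after `Level3`:

```cpp
#include <cstdint>

// Hypothetical stand-in for the enum in matmul_nbits.cc. Only Level0..Level3
// are visible in the hunk; Level4 is an assumption made for illustration.
typedef enum {
  Level0,
  Level1,
  Level2,
  Level3,
  Level4,
} AccuracyLevelSketch;

// With Level0 present, enumerator values line up with their names.
static_assert(Level0 == 0, "Level0 must be 0");
static_assert(Level4 == 4, "Level4 must be 4");

// Hypothetical helper: true when the numeric attribute selects Level4.
bool IsLevel4(std::int64_t accuracy_level) {
  return accuracy_level == static_cast<std::int64_t>(Level4);
}
```

Without a leading `Level0`, an enum that starts at `Level1` would assign `Level1` the value 0, and every comparison against a numeric level would be off by one.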
7 changes: 4 additions & 3 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -274,11 +274,12 @@ void TestMatMulNBitsTyped() {
   base_opts.block_size = block_size;
   base_opts.accuracy_level = accuracy_level;

-  if constexpr (std::is_same<AType, MLFloat16>::value) {
+  if (base_opts.accuracy_level == 4) {
+    base_opts.output_abs_error = 0.1f;
+    base_opts.output_rel_error = 0.02f;
+  } else if constexpr (std::is_same<AType, MLFloat16>::value) {
     base_opts.output_abs_error = 0.055f;
     base_opts.output_rel_error = 0.02f;
-  } else if (base_opts.accuracy_level == 4) {
-    base_opts.output_abs_error = 0.1f;
   }

   {
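
The test change reorders the tolerance selection: the accuracy level 4 check now runs first for every input type, and the `MLFloat16` tolerances apply only to the other levels. The sketch below restates that ordering in isolation. `Half`, `Tolerances`, and `SelectTolerances` are hypothetical stand-ins for `MLFloat16`, the test's options struct, and the inline branch; the zero defaults are placeholders, not the test's real defaults.

```cpp
#include <type_traits>

// Hypothetical stand-in for MLFloat16.
struct Half {};

// Hypothetical stand-in for the tolerance fields of the test options.
struct Tolerances {
  float abs_error = 0.0f;  // placeholder default, not the real one
  float rel_error = 0.0f;  // placeholder default, not the real one
};

template <typename AType>
Tolerances SelectTolerances(int accuracy_level) {
  Tolerances t;
  if (accuracy_level == 4) {
    // Level 4 takes precedence regardless of the input type.
    t.abs_error = 0.1f;
    t.rel_error = 0.02f;
  } else if constexpr (std::is_same<AType, Half>::value) {
    // Half-precision inputs at any other level.
    t.abs_error = 0.055f;
    t.rel_error = 0.02f;
  }
  return t;
}
```

Under the previous ordering, a half-precision test at level 4 would have taken the first branch and used the tighter 0.055 absolute tolerance. With the new ordering it gets 0.1, the same as the fp32 case.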
