
[CPU] add arithmetic_mode impl for bf16_emitters #27737

Open
wants to merge 8 commits into master from liubo/bf16_emitter_trunc_impl
Conversation


@liubo-intel liubo-intel commented Nov 26, 2024

Details:

  • Add an arithmetic_mode implementation for bf16_emitters to fix the 'inf' output issue when input data falls outside the range [bf16_min, bf16_max] during f32->bf16 conversion, which may lead to 'UNK' outputs in some LLM cases (see the sketch below).
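
A minimal scalar sketch of the overflow behaviour described above (illustrative only, not the JIT code in this PR; f32_to_bf16_rne stands in for vcvtneps2bf16's round-to-nearest-even conversion and ignores NaN/Inf inputs):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative round-to-nearest-even f32 -> bf16, mimicking vcvtneps2bf16.
static uint16_t f32_to_bf16_rne(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    bits += 0x7FFFu + ((bits >> 16) & 1u);   // round the 16 bits that get dropped
    return static_cast<uint16_t>(bits >> 16);
}

int main() {
    const uint32_t bf16_max_bits = 0x7F7F0000u;  // largest finite bf16, widened to f32
    float bf16_max;
    std::memcpy(&bf16_max, &bf16_max_bits, sizeof(bf16_max));

    float x = 3.4e38f;  // finite in f32, but above bf16_max
    std::printf("round-to-nearest: 0x%04X\n", f32_to_bf16_rne(x));        // 0x7F80 = +inf
    float clamped = x > bf16_max ? bf16_max : x;                          // saturation mode
    std::printf("saturated:        0x%04X\n", f32_to_bf16_rne(clamped));  // 0x7F7F = bf16_max
    return 0;
}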

Tickets:

@liubo-intel liubo-intel requested review from a team as code owners November 26, 2024 02:52
@github-actions github-actions bot added the category: CPU OpenVINO CPU plugin label Nov 26, 2024
@liubo-intel (Contributor Author)

Hi @chenhu-wang and @dmitry-gorokhov, could you please help review this PR? Thanks!

@chenhu-wang (Contributor)

@liubo-intel, you use a fixed "saturation" mode on the avx512_core_bf16 ISA, but on some ISAs such as avx2 a fixed "truncation" mode is used. This is not aligned.
We already have arithmetic_mode mode_ in jit_store_emitter. jit_uni_vcvtneps2bf16 is used when storing f32 to bf16, but the mode information is not respected there. I think we should pass mode_ (truncation or saturation) into jit_uni_vcvtneps2bf16 as a field and convert according to that mode, with "saturation" as the default; a rough sketch of this idea is shown below.
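
A rough sketch of the kind of interface change suggested here (illustrative only; the enum values, constructor signature, and method names are not the actual OpenVINO definitions):

// Mirror the store emitter's arithmetic mode inside the bf16 converter so that
// callers can request saturation or truncation explicitly.
enum class arithmetic_mode { saturation, truncation };

class jit_uni_vcvtneps2bf16 {
public:
    explicit jit_uni_vcvtneps2bf16(arithmetic_mode mode = arithmetic_mode::saturation)
        : mode_(mode) {}

    void emit() const {
        if (mode_ == arithmetic_mode::saturation) {
            // clamp the f32 input to [bf16_min, bf16_max] before rounding
        } else {
            // drop the low 16 bits of the f32 representation directly
        }
    }

private:
    arithmetic_mode mode_;  // propagated from jit_store_emitter's mode_
};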

@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from 8f98525 to 99e994a Compare November 28, 2024 11:11
@liubo-intel liubo-intel changed the title [CPU] add truncation impl for vcvtneps2bf16 ISA [CPU] add arithmetic_mode impl for bf16_emitters Nov 28, 2024
@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from 99e994a to f0b3f10 Compare November 29, 2024 06:04

chenhu-wang commented Dec 3, 2024

@liubo-intel, positive-overflow data is 0x7F7FXXXX, where X can be any value. The conversion should do something like the following function:

bf16 fp32_to_bf16(float data) {
    if (data is overflow) {
        if (saturation mode)
            return bf16_max;    // bf16_max is 0x7F7F
        else // truncation mode
            return 0x7F7F;       // drop the 16 lower bits directly
    } else {
        return vcvtneps2bf16(data);
    }
}

As we can see, saturation mode and truncation mode return the same result for overflowed data. I think we can remove the explicit mode for simplicity and extend it later when an explicit mode is really needed.
For the vectorized implementation, we can saturate the f32 overflow data 0x7F7FXXXX to the f32 value 0x7F7F0000, then do the nearest-even rounding on the vector together with the other data. Since the round bits are all 0 for 0x7F7F0000, it stays 0x7F7F0000 after rounding, which keeps the result correct. So all we need is a pre-processing step that saturates the overflowed input; the remaining vcvtneps2bf16 path is unchanged. Negative overflow follows the same logic. A minimal intrinsics sketch of this pre-saturation is shown below.
What do you think?
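
A minimal intrinsics sketch of this pre-saturation idea (assumptions: AVX512F plus AVX512_BF16 are available; the emitter in this PR generates equivalent instructions rather than calling intrinsics; NaN/Inf inputs are not treated specially here, which becomes relevant later in the thread):

#include <immintrin.h>

// Pre-saturate the overflowed lanes, then reuse the unchanged ne-rounding conversion.
static __m256bh cvt_f32_to_bf16_saturated(__m512 x) {
    const __m512 bf16_max = _mm512_set1_ps(3.38953139e38f);    // 0x7F7F0000 as f32
    const __m512 bf16_min = _mm512_set1_ps(-3.38953139e38f);   // 0xFF7F0000 as f32
    // Overflowed lanes become +/-bf16_max; their 16 round bits are zero,
    // so the nearest-even rounding below leaves them unchanged.
    const __m512 clamped = _mm512_min_ps(_mm512_max_ps(x, bf16_min), bf16_max);
    return _mm512_cvtneps_pbh(clamped);                        // conversion step unchanged
}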


liubo-intel commented Dec 3, 2024

> (quoting @chenhu-wang's previous comment)

Hi @chenhu-wang, as we synced offline, 'vcvtneps2bf16' is not used for truncation mode because this instruction outputs 'inf' for overflowed data instead of truncated values.
I have also tried using the 'saturate_input()' function for both modes (saturation and truncation) and keeping the subsequent processing the same as before this PR, so that the 'vcvtneps2bf16' instruction is still used for truncation mode on bf16 platforms. However, this approach makes the "src/frontends/onnx/tests/tests_python/test_backend." CI fail, which I think means it does not align with the ONNX truncation implementation. So it seems we need to keep both the saturation and truncation implementations to keep the outputs correct.
As for the performance concern, maybe let's consider it after perf tests?

@chenhu-wang (Contributor)

> (quoting @liubo-intel's reply above)

After saturation, the input to vcvtneps2bf16 is bf16_max for overflowed data, and converting bf16_max with vcvtneps2bf16 still gives bf16_max, so it is not clear why inf is generated. How does the ONNX truncation implementation work? Could you please elaborate in more detail?

@liubo-intel (Contributor Author)

> (quoting the two previous comments)

Hi @chenhu-wang, the failed CI case is an ONNX backend test case (the same one shown in the CI failure of commit f0b3f10 in this PR). I'm not sure whether it is worth digging deeper into the ONNX backend implementation or involving ONNX-team colleagues, as long as the performance impact of the current method is within an acceptable range. Hi @dmitry-gorokhov, what is your suggestion on this?


liubo-intel commented Dec 6, 2024

Hi @chenhu-wang and @dmitry-gorokhov, I think I know why the ONNX backend test fails when we use the 'saturation' method (function) for truncation mode: this 'saturation' function changes the original (f32) input values ['nan, -nan, inf, -inf'], which is not expected.
So it seems we should also handle these special values in the 'saturation' method (function) besides the normal overflow values, roughly as sketched below.
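
A scalar sketch of that handling (illustrative only; the JIT code in this PR does the equivalent with vector instructions such as vfixupimmps and a selector table rather than an explicit isfinite() branch):

#include <cmath>
#include <cstdint>
#include <cstring>

// Clamp only finite overflow values; NaN and +/-Inf pass through untouched so
// they still convert to the corresponding bf16 NaN/Inf encodings.
static float saturate_input_scalar(float x) {
    const uint32_t max_bits = 0x7F7F0000u, min_bits = 0xFF7F0000u;
    float bf16_max, bf16_min;
    std::memcpy(&bf16_max, &max_bits, sizeof(bf16_max));
    std::memcpy(&bf16_min, &min_bits, sizeof(bf16_min));
    if (!std::isfinite(x))
        return x;                      // keep nan, -nan, inf, -inf as-is
    if (x > bf16_max) return bf16_max;
    if (x < bf16_min) return bf16_min;
    return x;
}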

@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from 160c7b8 to 72444fb Compare December 6, 2024 07:17
Comment on lines 55 to 82
h->uni_vpaddd(aux, in, aux);
h->vfixupimmps(aux, in, table_val("selector"), 0);
h->uni_vpaddd(aux, clamped, aux);
Contributor

Could you please double-check whether we can remove h->vfixupimmps(aux, in, table_val("selector"), 0); here? The lines above could spoil the inf/nan bits again.

@liubo-intel (Contributor Author) Dec 6, 2024

done

Comment on lines 58 to 68
Vmm in = Vmm(in_vec_idxs[0]);
Vmm clamped = Vmm(aux_vec_idxs[0]);
saturate_input(clamped, in, "bf16_min", "bf16_max");
Contributor

We could consider reusing Vmm "in" as Vmm "clamped": Vmm "in" would store the clamped data and save one aux register. As I understand it, clamping does not break "keep_source_intact".

Contributor Author

saturate_input(in, in, "bf16_min", "bf16_max"); produces wrong outputs, maybe because the following steps still need the original 'in' data?

Contributor

My point is:

saturate_input(clamped, in, "bf16_min", "bf16_max");
h->uni_vmovups(in, clamped);

clamped is just an aux vec and is not held for long. Just a matter of taste.

@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from 60ffe50 to a5c2669 Compare December 6, 2024 13:49

yuxu42 commented Dec 9, 2024

Hi @dmitry-gorokhov, could you please review? Thanks!

@wenjiew wenjiew added this to the 2025.0 milestone Dec 9, 2024
@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from a5c2669 to fd3e1fb Compare December 11, 2024 09:57
@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch 2 times, most recently from ff1294d to b43892d Compare December 12, 2024 07:56

liubo-intel commented Dec 12, 2024

Hi @chenhu-wang and @dmitry-gorokhov, since I found that the truncation-impl instructions may affect the model's inference time in certain situations, how about limiting the truncation impl to eltwise Constant inputs only? (commit d5c00a2). From my understanding, overflow cases during f32->bf16 mostly come from constant input values (models are trained in f32 precision), whereas overflowed activation values can be (or should be) handled by the runtime processing kernels.

…p acc fix and minimize the impact on performance
@liubo-intel liubo-intel force-pushed the liubo/bf16_emitter_trunc_impl branch from b43892d to d5c00a2 Compare December 12, 2024 08:34
Labels
category: CPU OpenVINO CPU plugin
4 participants