
Implement BF16 using generic_float class #3578

Merged: 74 commits into develop from bf16 on Nov 21, 2024
Conversation

@richagadgil (Contributor) commented Oct 30, 2024

Uses the generic_float class (#3522) to create the bf16 class.

BF16 has 1 sign bit, 8 bits for the exponent, and 7 bits for the mantissa: `bf16 = migraphx::generic_float<7, 8>;`
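For reference, the alias can be exercised like this (a minimal sketch: only `migraphx::generic_float<7, 8>` comes from this PR; the header path and the float conversion operators are assumptions):

```cpp
#include <migraphx/generic_float.hpp> // assumed header location

// 1 sign bit + 8 exponent bits + 7 mantissa bits = 16 bits total.
// generic_float<MantissaSize, ExponentSize> packs sign|exponent|mantissa.
using bf16 = migraphx::generic_float<7, 8>;

static_assert(sizeof(bf16) == 2, "bf16 should occupy 16 bits");

// BF16 shares float32's 8-bit exponent, so narrowing keeps the full
// dynamic range and only drops the low 16 bits of mantissa precision.
float roundtrip(float x)
{
    bf16 y{x};                    // float32 -> bf16 (loses precision)
    return static_cast<float>(y); // bf16 -> float32 (exact widening)
}
```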

Summary of changes:

  1. generic_float.cpp: Change the subnormal (exponent == 0) conversion to differentiate between the FP16 and BF16 types (see the sketch after this list)
  2. migraphx.h, shape.hpp, hip_gemm_impl.cpp, gemm_impl.cpp: Add the BF16 shape type
  3. type_traits.hpp: Add traits for the BF16 type
  4. tests/: Add tests for the BF16 type
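
Regarding item 1: the subnormal decode depends on both bit widths, which is why FP16 and BF16 need different scale factors. A minimal sketch of the math (`decode_subnormal` is a hypothetical name, not the PR's actual code):

```cpp
#include <cmath>
#include <cstdint>

// With E exponent bits and M mantissa bits, bias = 2^(E-1) - 1, and a
// subnormal (exponent == 0) encodes mantissa * 2^(1 - bias - M).
// FP16 (M=10, E=5): mantissa * 2^-24.  BF16 (M=7, E=8): mantissa * 2^-133.
template <unsigned MantissaSize, unsigned ExponentSize>
float decode_subnormal(std::uint32_t mantissa)
{
    constexpr int bias = (1 << (ExponentSize - 1)) - 1;
    return std::ldexp(static_cast<float>(mantissa),
                      1 - bias - static_cast<int>(MantissaSize));
}
```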

Review thread on test/op_shape_test.cpp (outdated, resolved).
A review comment on this code excerpt:

```cpp
{
    static std::string format()
    {
        // TODO: no standard format in numpy for bf16
```

Collaborator: Should this be an issue or is this tracked somewhere already?

@TedThemistokleous (Collaborator) commented:

@richagadgil, update your description. I think we've done a pass on this with some comments.

Second review thread on test/op_shape_test.cpp (outdated, resolved).
@TedThemistokleous added the roadmap label (Tasks to finish for a release) on Nov 18, 2024
@pfultz2 (Collaborator) commented Nov 20, 2024

For the CI failures, just add an if statement to skip bf16 for now in the compile_math case in jit.cpp. That should get it passing, and we can work on enabling it in a later PR.
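
A sketch of the kind of guard being suggested (everything here is hypothetical shorthand: `migraphx::shape::bf16_type` is an assumed enum spelling, and `compile_math_check` stands in for the existing per-type test body in jit.cpp):

```cpp
for(auto t : all_types) // the shape types the compile_math case iterates over
{
    if(t == migraphx::shape::bf16_type) // assumed enum name; skip bf16 for now
        continue;
    compile_math_check(t); // placeholder for the existing check
}
```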

@migraphx-bot (Collaborator) commented:
| Test | Batch | Rate new (47a181) | Rate old (0f36aa) | Diff |
|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,254.36 | 3,261.99 | -0.23% |
| torchvision-resnet50_fp16 | 64 | 6,989.58 | 6,984.41 | 0.07% |
| torchvision-densenet121 | 32 | 2,435.66 | 2,434.46 | 0.05% |
| torchvision-densenet121_fp16 | 32 | 4,085.94 | 4,068.77 | 0.42% |
| torchvision-inceptionv3 | 32 | 1,628.48 | 1,630.14 | -0.10% |
| torchvision-inceptionv3_fp16 | 32 | 2,745.44 | 2,746.22 | -0.03% |
| cadene-inceptionv4 | 16 | 764.67 | 765.59 | -0.12% |
| cadene-resnext64x4 | 16 | 810.97 | 809.78 | 0.15% |
| slim-mobilenet | 64 | 7,467.12 | 7,474.57 | -0.10% |
| slim-nasnetalarge | 64 | 208.49 | 208.58 | -0.05% |
| slim-resnet50v2 | 64 | 3,442.62 | 3,441.49 | 0.03% |
| bert-mrpc-onnx | 8 | 1,148.81 | 1,150.80 | -0.17% |
| bert-mrpc-tf | 1 | 466.40 | 465.54 | 0.18% |
| pytorch-examples-wlang-gru | 1 | 419.06 | 420.06 | -0.24% |
| pytorch-examples-wlang-lstm | 1 | 473.00 | 381.98 | 23.83% 🔆 |
| torchvision-resnet50_1 | 1 | 761.83 | 750.44 | 1.52% |
| cadene-dpn92_1 | 1 | 402.29 | 398.35 | 0.99% |
| cadene-resnext101_1 | 1 | 382.31 | 382.96 | -0.17% |
| onnx-taau-downsample | 1 | 345.79 | 346.08 | -0.08% |
| dlrm-criteoterabyte | 1 | 33.34 | 33.35 | -0.02% |
| dlrm-criteoterabyte_fp16 | 1 | 52.76 | 52.68 | 0.15% |
| agentmodel | 1 | 8,177.28 | 8,091.53 | 1.06% |
| unet_fp16 | 2 | 58.83 | 58.77 | 0.09% |
| resnet50v1_fp16 | 1 | 926.85 | 943.16 | -1.73% |
| resnet50v1_int8 | 1 | 994.66 | 1,012.12 | -1.72% |
| bert_base_cased_fp16 | 64 | 1,170.84 | 1,169.97 | 0.07% |
| bert_large_uncased_fp16 | 32 | 363.41 | 363.75 | -0.10% |
| bert_large_fp16 | 1 | 198.71 | 199.03 | -0.16% |
| distilgpt2_fp16 | 16 | 2,197.22 | 2,201.98 | -0.22% |
| yolov5s | 1 | 535.63 | 539.79 | -0.77% |
| tinyllama | 1 | 43.65 | 43.42 | 0.53% |
| vicuna-fastchat | 1 | 177.06 | 175.75 | 0.75% |
| whisper-tiny-encoder | 1 | 418.16 | 418.02 | 0.03% |
| whisper-tiny-decoder | 1 | 423.43 | 428.37 | -1.15% |

Check results before merge 🔆

@migraphx-bot (Collaborator) commented:

✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@causten merged commit 952a257 into develop on Nov 21, 2024 (40 of 45 checks passed)
@causten deleted the bf16 branch on November 21, 2024 at 02:33
Labels: roadmap (Tasks to finish for a release)
7 participants