
Add propagate_precision pass #2853

Open · wants to merge 28 commits into base: develop

Conversation

pfultz2 (Collaborator) commented Mar 2, 2024

No description provided.

@@ -0,0 +1,20 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Member


License

Comment on lines +357 to +359
bool is_integral() const { return std::is_integral<type>{}; }
bool is_signed() const { return std::is_signed<type>{}; }
bool is_unsigned() const { return std::is_unsigned<type>{}; }
Member


Curious why auto was changed to bool?

Collaborator


I think this is for readability/clarity. I don't see why these would resolve to anything but bool. Unless we want to use ::value here, as specified by the STL?

https://en.cppreference.com/w/cpp/types/is_unsigned

Collaborator Author


Doing x.is_integral() != y.is_integral() will fail to compile because, with auto, the two calls return different types. So I explicitly convert to bool instead.
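
A minimal standalone sketch of the pattern under discussion (the wrapper name and mixed_compare are illustrative, not the actual shape.hpp code): returning bool gives every instantiation the same return type, whereas an auto return would yield a distinct std::is_integral<T> type per instantiation, which is what the author reports breaking comparisons like the one above.

#include <type_traits>

// Illustrative stand-in for the templated helper being discussed
template <class T>
struct typed_value
{
    using type = T;
    // With an auto return type this would yield std::is_integral<type>{}, a
    // different trait type for each T; the explicit bool keeps them uniform.
    bool is_integral() const { return std::is_integral<type>{}; }
};

bool mixed_compare(const typed_value<int>& x, const typed_value<float>& y)
{
    // Both calls return plain bool, so the comparison is straightforward.
    return x.is_integral() != y.is_integral();
}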

friend bool operator>=(const precision& xp, const precision& yp)
{
return (xp > yp) or (xp == yp);
}
Collaborator


This might seem like an odd ask, but why not make these xor instead of or? If one side is true then the other can't be, so the result is the same either way. Could using xor here speed things up or help catch errors?

Collaborator Author


xor won't short-circuit.
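
For illustration, a small sketch (hypothetical helpers, not PR code) of the difference: logical or skips the right operand when the left is already true, while ^ on bools always evaluates both sides.

#include <iostream>

bool greater() { std::cout << "greater evaluated\n"; return true; }
bool equal()   { std::cout << "equal evaluated\n";   return false; }

int main()
{
    bool a = greater() or equal(); // short-circuits: equal() is never called
    bool b = greater() ^ equal();  // no short-circuit: both calls always run
    std::cout << a << " " << b << "\n";
    return 0;
}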

@@ -0,0 +1,191 @@
#include <migraphx/propagate_precision.hpp>
Collaborator


Add license here too

@@ -0,0 +1,158 @@
#include <migraphx/propagate_precision.hpp>
Collaborator


License

auto mul = m2.add_instruction(migraphx::make_op("mul"), sqrt, y);
m2.add_return({mul});
}
EXPECT(m1.sort() == m2.sort());
Collaborator


I understand this is done to preserve precision through all these divides/square roots, but are we not worried about the added overhead now that we've converted the fp16 set of ops to double? Or is compute not a concern here?

Collaborator Author


But we are converting to double anyway. For elementwise ops the compute shouldn't add much overhead, since these are all essentially unary operators.


codecov bot commented Mar 2, 2024

Codecov Report

Attention: Patch coverage is 94.11765%, with 5 lines in your changes missing coverage. Please review.

Project coverage is 91.76%. Comparing base (84fc9f0) to head (da8471d).

Files Patch % Lines
src/propagate_precision.cpp 94.04% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2853      +/-   ##
===========================================
+ Coverage    91.75%   91.76%   +0.01%     
===========================================
  Files          473      475       +2     
  Lines        17958    18043      +85     
===========================================
+ Hits         16478    16558      +80     
- Misses        1480     1485       +5     

☔ View full report in Codecov by Sentry.

migraphx-bot (Collaborator) commented Mar 2, 2024

Test | Batch | Rate new (0f785b) | Rate old (5ba023) | Diff | Compare
torchvision-resnet50 64 2,826.08 2,824.95 0.04%
torchvision-resnet50_fp16 64 6,579.77 6,575.26 0.07%
torchvision-densenet121 32 2,105.76 2,101.96 0.18%
torchvision-densenet121_fp16 32 3,696.98 3,683.24 0.37%
torchvision-inceptionv3 32 1,602.40 1,606.31 -0.24%
torchvision-inceptionv3_fp16 32 2,551.04 2,555.93 -0.19%
cadene-inceptionv4 16 717.22 717.78 -0.08%
cadene-resnext64x4 16 680.55 680.75 -0.03%
slim-mobilenet 64 5,900.63 5,910.66 -0.17%
slim-nasnetalarge 64 153.92 153.88 0.03%
slim-resnet50v2 64 2,592.86 2,590.48 0.09%
bert-mrpc-onnx 8 920.94 960.32 -4.10% 🔴
bert-mrpc-tf 1 399.57 400.86 -0.32%
pytorch-examples-wlang-gru 1 392.57 394.04 -0.37%
pytorch-examples-wlang-lstm 1 368.63 366.42 0.60%
torchvision-resnet50_1 1 603.29 606.33 -0.50%
cadene-dpn92_1 1 389.75 393.42 -0.93%
cadene-resnext101_1 1 331.98 332.05 -0.02%
onnx-taau-downsample 1 307.25 307.56 -0.10%
dlrm-criteoterabyte 1 28.79 28.80 -0.01%
dlrm-criteoterabyte_fp16 1 48.40 48.29 0.24%
agentmodel 1 7,243.91 7,346.89 -1.40%
unet_fp16 2 57.79 57.56 0.39%
resnet50v1_fp16 1 910.81 917.25 -0.70%
resnet50v1_int8 1 794.12 815.05 -2.57%
bert_base_cased_fp16 64 1,053.46 1,053.37 0.01%
bert_large_uncased_fp16 32 301.66 301.70 -0.02%
bert_large_fp16 1 158.70 158.88 -0.12%
distilgpt2_fp16 16 1,858.05 1,860.49 -0.13%
yolov5s 1 475.81 481.01 -1.08%
tinyllama 1 32.99 33.01 -0.06%
vicuna-fastchat 1 157.29 159.19 -1.19%
whisper-tiny-encoder 1 348.02 347.33 0.20%
whisper-tiny-decoder 1 395.52 396.69 -0.30%

This build is not recommended to merge 🔴

migraphx-bot (Collaborator) commented Mar 2, 2024


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

return result;
}

void propagate_precision::apply(module_pass_manager& mpm) const
umangyadav (Member) commented Mar 4, 2024

Can you write docstrings for all of these functions that describe what they are supposed to do and how they work?
Also add how the pass is supposed to work and how it helps with precision, accuracy, or performance.
We can read the code, but it is not time-efficient for everyone to get a high-level understanding that way.
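
As an illustration of the kind of comment being requested, a hedged sketch of a docstring for the pass entry point; the behavior described is inferred from the discussion in this PR, not taken from the source.

/// Propagates higher-precision types through the graph: when an instruction is
/// already computed in a wider type (for example a float32 reduce in an
/// otherwise fp16 graph), adjacent elementwise/reduce instructions are upgraded
/// to that wider type so precision is not repeatedly lost and regained across
/// converts. (Wording inferred from the PR discussion, not from the source.)
void propagate_precision::apply(module_pass_manager& mpm) const;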

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(propagate_reduce)
Member


Can you add a test where the pass doesn't do anything?
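
For example, a hedged sketch of what such a no-op test could look like, following the TEST_CASE/EXPECT style of this file and assuming its existing includes; run_pass is assumed to be the file's local helper that applies the propagate_precision pass.

TEST_CASE(propagate_noop)
{
    auto create = [] {
        migraphx::module m;
        auto x   = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3}});
        auto y   = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 3}});
        auto mul = m.add_instruction(migraphx::make_op("mul"), x, y);
        m.add_return({mul});
        return m;
    };
    migraphx::module m1 = create();
    run_pass(m1); // assumed local helper that runs propagate_precision
    // Everything is already float, so the pass should leave the module unchanged.
    migraphx::module m2 = create();
    EXPECT(m1.sort() == m2.sort());
}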

CharlieL7 (Collaborator) left a comment

For some background: where are we failing accuracy because of precision changes?

@@ -0,0 +1,20 @@
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_PROMOTE_PRECISION_HPP
Collaborator


Is there a reason why the include guard is named differently?

pfultz2 (Collaborator Author) commented Apr 4, 2024

For some background: where are we failing accuracy because of precision changes?

This is related to the fp16 inaccuracy with llamav2 (see #2556). #2883 will use FP32 for large reduce_means, but that still isn't enough to get accurate results (or avoid NaNs). So this pass will use FP32 for the x^2/n on the input, and it will use FP32 for the rsqrt(mean + epsilon) that follows the reduce_mean.
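
To make that concrete, a hedged sketch (using the module-building API shown in the tests above, with made-up shapes and epsilon) of the kind of graph this describes: the mean of x^2 and the rsqrt(mean + epsilon) run in float with converts at the boundaries, while the rest of the graph stays fp16. This is an illustration of the idea, not the pass's actual output.

migraphx::module m;
auto x  = m.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {1, 4, 64}});
// compute x^2/n in float rather than half
auto xf = m.add_instruction(
    migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), x);
auto x2   = m.add_instruction(migraphx::make_op("mul"), xf, xf);
auto mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), x2);
// rsqrt(mean + epsilon) also stays in float
auto eps  = m.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-6f}});
auto epsb = m.add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", {1, 4, 1}}}), eps);
auto add  = m.add_instruction(migraphx::make_op("add"), mean, epsb);
auto rs   = m.add_instruction(migraphx::make_op("rsqrt"), add);
// convert back to half for the rest of the fp16 graph
auto rsh  = m.add_instruction(
    migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), rs);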
