Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser changes to handle MatMulIntegerToFloat #3445

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from

Conversation

TedThemistokleous
Copy link
Collaborator

@TedThemistokleous TedThemistokleous commented Sep 16, 2024

Changes to MatMul parser to handle the Microsoft Contrib operator MatMulintegarToFloat

Since we have the scale and zero points in our operands we can just perform a multiplied after int8 biases are added and then insert a regular dot on the scaled input values which should give the same output as the input data types.

Able to leverage the existing set of tests for matmul

Needs #3526 as there's a bug with dequantizelinear this has uncovered

@TedThemistokleous TedThemistokleous self-assigned this Sep 16, 2024
@TedThemistokleous
Copy link
Collaborator Author

TedThemistokleous commented Sep 16, 2024

TODO:

  • Add Parser tests for err cases
  • Add parser tests for base case
  • Add parser test for bias and zero point cases
  • Add verify tests for all of the above

@TedThemistokleous TedThemistokleous added onnxruntime PR changes interaction between MIGraphX and Onnxruntime Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase UAI labels Sep 16, 2024
Copy link

codecov bot commented Sep 16, 2024

Codecov Report

Attention: Patch coverage is 90.10989% with 9 lines in your changes missing coverage. Please review.

Project coverage is 92.17%. Comparing base (f5df004) to head (8a41f16).

Files with missing lines Patch % Lines
src/onnx/parse_matmul.cpp 90.10% 9 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3445      +/-   ##
===========================================
- Coverage    92.17%   92.17%   -0.01%     
===========================================
  Files          513      513              
  Lines        21536    21603      +67     
===========================================
+ Hits         19851    19912      +61     
- Misses        1685     1691       +6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Updated parser to handle bias case as well as bad scale conditions

Initial float/half tests
bad scale tests
bad bias tests
avoid tidy screaming about complexity
TedThemistokleous and others added 2 commits October 11, 2024 17:45
Use dequantizelinear which elminates the need to add in shifts due to int8/uint8 mismatches

still needs parser tests
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Show resolved Hide resolved
MIGRAPHX_THROW("PARSE_QUANT_DOT_SCALED: Bias have same dim as matrix B column");
}

has_valid_scale_bias = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As against invalid? ;-)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If scale bias doesn't exist there isn't a bias at the end of the matmulintergertofloat added then.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was simply wondering if has_scale_bias isn't what the intent is? :-)

src/onnx/parse_matmul.cpp Show resolved Hide resolved
return dequantized_op;
}

static instruction_ref handle_scaled_output(const onnx_parser::node_info& info,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many parameters. Ideally they should be handled by a struct parameter.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're the same amount of a parameters gathered by the operator. These are all needed for dequantize steps and adding the proper unsqueeze->transpose paths. Order matters here with respect to matrix input A or B

src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
Use the parsed in op name for error messages to help logging should parser errors occur.
Change naming to be agnostic of input index.
@TedThemistokleous TedThemistokleous force-pushed the add_matmulintegertofloat_contrib_op branch from 42b787d to 9660e11 Compare October 31, 2024 22:00
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
bool a1_has_no_zp = (a1 == zp_a1);

auto unsq_scale_a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {-1}}}), scale_a0);
if(not a0_has_no_zp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Nit) Style: perhaps two negatives are not required, if there is a variable like a0_has_zp.

src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Outdated Show resolved Hide resolved
src/onnx/parse_matmul.cpp Show resolved Hide resolved
Ted Themistokleous added 2 commits November 7, 2024 14:03
Clean up uint8 handling for quant_dot. Fix tests
Copy link
Contributor

@lakhinderwalia lakhinderwalia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for following up the comments. Approved.

@migraphx-bot
Copy link
Collaborator

Test Batch Rate new
8a41f1
Rate old
4b96e1
Diff Compare
torchvision-resnet50 64 3,258.72 3,260.40 -0.05%
torchvision-resnet50_fp16 64 6,995.05 6,981.88 0.19%
torchvision-densenet121 32 2,437.13 2,436.50 0.03%
torchvision-densenet121_fp16 32 4,095.43 4,081.96 0.33%
torchvision-inceptionv3 32 1,638.27 1,638.04 0.01%
torchvision-inceptionv3_fp16 32 2,763.67 2,760.86 0.10%
cadene-inceptionv4 16 775.85 776.56 -0.09%
cadene-resnext64x4 16 811.91 811.67 0.03%
slim-mobilenet 64 7,534.34 7,540.50 -0.08%
slim-nasnetalarge 64 211.46 211.49 -0.01%
slim-resnet50v2 64 3,505.35 3,506.73 -0.04%
bert-mrpc-onnx 8 1,149.97 1,147.08 0.25%
bert-mrpc-tf 1 464.52 465.87 -0.29%
pytorch-examples-wlang-gru 1 422.89 423.73 -0.20%
pytorch-examples-wlang-lstm 1 394.60 389.07 1.42%
torchvision-resnet50_1 1 788.90 788.22 0.09%
cadene-dpn92_1 1 398.61 402.19 -0.89%
cadene-resnext101_1 1 376.73 382.83 -1.59%
onnx-taau-downsample 1 343.11 343.07 0.01%
dlrm-criteoterabyte 1 33.34 33.34 -0.00%
dlrm-criteoterabyte_fp16 1 52.71 52.75 -0.08%
agentmodel 1 8,175.55 8,325.15 -1.80%
unet_fp16 2 58.90 58.80 0.17%
resnet50v1_fp16 1 950.73 953.06 -0.24%
resnet50v1_int8 1 1,017.39 1,005.99 1.13%
bert_base_cased_fp16 64 1,169.96 1,170.44 -0.04%
bert_large_uncased_fp16 32 363.68 363.37 0.08%
bert_large_fp16 1 200.11 198.99 0.56%
distilgpt2_fp16 16 2,200.03 2,201.23 -0.05%
yolov5s 1 541.03 536.00 0.94%
tinyllama 1 43.44 43.45 -0.02%
vicuna-fastchat 1 170.00 174.10 -2.35%
whisper-tiny-encoder 1 418.05 418.74 -0.17%
whisper-tiny-decoder 1 436.60 425.97 2.50%

This build is OK for merge ✅

@migraphx-bot
Copy link
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

🔴bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output


     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

Copy link
Collaborator

@CharlieL7 CharlieL7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From our conversation, need to test/handle higher dimensional matrix contractions (matrix mul). Also transpose with permutation = {0, 1} probably does nothing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Onnx Operators Adding or modifying an Onnx Operator in the MIGraphX codebase onnxruntime PR changes interaction between MIGraphX and Onnxruntime UAI
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants