Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Special case for K==0 in CUDA MatMul #21525

Merged
merged 5 commits into from
Aug 13, 2024
Merged

Conversation

yuslepukhin
Copy link
Member

@yuslepukhin yuslepukhin commented Jul 26, 2024

Description

This change addresses a case where we multiply two matrices, and their inner dimension is 0.
numpy and Eigen which is being used in our CPU EP implementation correctly handle this case
and output a [M, N] matrix filled with zeros.

Motivation and Context

This is required to support GenAI empty input Lora implementation.

Addresses: #21483

tianleiwu
tianleiwu previously approved these changes Jul 26, 2024
@tianleiwu
Copy link
Contributor

/azp run Windows GPU CI Pipeline

Copy link

No pipelines are associated with this pull request.

tianleiwu
tianleiwu previously approved these changes Aug 2, 2024
@yuslepukhin yuslepukhin merged commit c2911bb into main Aug 13, 2024
95 of 98 checks passed
@yuslepukhin yuslepukhin deleted the yuslepukhin/matmul_zero_k branch August 13, 2024 18:27
@ranjitshs
Copy link
Contributor

ranjitshs commented Aug 14, 2024

@yuslepukhin @tianleiwu
FYI. Our local AIX CI reported the below newly introduced test case failure. I will try to debug.
Please let me know if you have any suggestion.

1: [ RUN      ] MathOpTest.MatMul_ZeroK
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/checkers.cc:390: Failure
1: The difference between cur_expected[i] and cur_actual[i] is NaNQ, which exceeds tolerance, where
1: cur_expected[i] evaluates to 0,
1: cur_actual[i] evaluates to -NaNQ, and
1: tolerance evaluates to 9.9999997473787516e-06.
1: i:3
1: Google Test trace:
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/checkers.cc:566: provider type: CPUExecutionProvider
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/base_tester.cc:827: registered execution providers: CPUExecutionProvider
1: 
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/checkers.cc:390: Failure
1: The difference between cur_expected[i] and cur_actual[i] is NaNQ, which exceeds tolerance, where
1: cur_expected[i] evaluates to 0,
1: cur_actual[i] evaluates to -NaNQ, and
1: tolerance evaluates to 9.9999997473787516e-06.
1: i:5
1: Google Test trace:
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/checkers.cc:566: provider type: CPUExecutionProvider
1: /home/buildusr/jenkins/workspace/onnxruntime-openxl/onnxruntime/onnxruntime/test/providers/base_tester.cc:827: registered execution providers: CPUExecutionProvider
1: 
1:

@yuslepukhin
Copy link
Member Author

We do not have AIX build, however, my hunch is that std::fill() with the actual datatype(0) such as float would do better than memset. Can you try it locally and let us know? I would provide a fix.

@ranjitshs
Copy link
Contributor

Hi @yuslepukhin
This test failure is intermittent and not happening every time, so it's okay for time being.
I will try to see the frequency of failure in local CI system.
Thanks for the std::fill suggestion. I haven't tried because failure is intermittent.
is it more reliable than memset() for data type other than integers family ?

@yuslepukhin
Copy link
Member Author

It depends on the floating-point format used on AIX. I have not worked with AIX for nearly 20 years. memset() assumes that all zeros are fine all datatypes. std::fill() would rely on explicitly constructing a given data type constant that represents zero on any given platform with any given compiler, and, therefore, is more portable.

yuslepukhin added a commit that referenced this pull request Aug 15, 2024
### Description
Replace `memset(0)` with `std::fill(T{})`. This would ensure that all
the types are initialized in a portable way.

### Motivation and Context
Some platforms exhibit intermittent failures with NaN results.
Follow up to: #21525

Cc: @ranjitshs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants