Disable gemm activation for non-float data types #19612

Merged
merged 11 commits into main from user/sheilk/disable-gemm-activation-for-fp16 on Feb 24, 2024

Conversation

@smk2007 (Member) commented on Feb 22, 2024

Description

Disable gemm activation for non-float data types

Motivation and Context

When a float16 model contains a Gemm+Relu subgraph, gemm_activation_fusion kicks in, eliminates the two nodes, and replaces them with a FusedGemm. However, FusedGemm is registered only for the float data type, so the resulting model fails to load.

Disable the fusion for non-float data types.
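A minimal sketch of the kind of guard this change implies, assuming the fusion pass inspects the Gemm node's first input. Names follow onnxruntime conventions (NodeArg, InputDefs, TypeAsProto), but this is illustrative, not the exact code from gemm_activation_fusion.cc:

```cpp
// Illustrative sketch: skip fusion unless the Gemm input element type is
// float, since FusedGemm is registered only for float. Not the verbatim
// onnxruntime source.
const NodeArg* input_arg = node.InputDefs()[0];
const auto* type_proto = input_arg->TypeAsProto();
if (type_proto == nullptr ||
    type_proto->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
  continue;  // e.g. a float16 Gemm+Relu subgraph is left unfused
}
```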

@tianleiwu (Contributor) commented:

Based on my understanding, there is another bug: the fusion will produce a FusedGemm for the CUDA EP, but FusedGemm has no CUDA implementation.

Could you check that both nodes are assigned to the CPU EP before the fusion?

@smk2007 (Member, Author) commented on Feb 23, 2024

> Based on my understanding, there is another bug: the fusion will produce a FusedGemm for the CUDA EP, but FusedGemm has no CUDA implementation.
>
> Could you check that both nodes are assigned to the CPU EP before the fusion?

AFAIK GemmActivationFusion already checks GetCompatibleExecutionProviders(), and graph_transformer_utils.cc initializes the GemmActivationFusion transformer with only the cpu_ep as compatible:

```cpp
transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
```
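For context, a sketch of the execution-provider guard pattern used by onnxruntime graph transformers. graph_utils::IsSupportedProvider and GetCompatibleExecutionProviders are real onnxruntime helpers, but the surrounding loop is illustrative, not the verbatim GemmActivationFusion code:

```cpp
// Sketch: a fusion pass only considers nodes assigned to a compatible EP.
for (auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
  Node* node = graph.GetNode(node_index);
  if (node == nullptr) continue;
  // The transformer was constructed with only the CPU EP as compatible,
  // so nodes placed on the CUDA EP are skipped and never fused.
  if (!graph_utils::IsSupportedProvider(*node, GetCompatibleExecutionProviders())) {
    continue;
  }
  // ... Gemm+Relu matching and FusedGemm replacement ...
}
```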

tianleiwu previously approved these changes on Feb 23, 2024.
smk2007 merged commit 46c4d7f into main on Feb 24, 2024 (84 of 94 checks passed).
smk2007 deleted the user/sheilk/disable-gemm-activation-for-fp16 branch on February 24, 2024 at 02:20.
maggie1059 pushed a commit that referenced this pull request on Mar 5, 2024 (the commit message mirrors the description above; co-authored by Sheil Kumar <[email protected]>).