Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP for scalar_mul_add #3454

Conversation

aarushjain29
Copy link
Contributor

@aarushjain29 aarushjain29 commented Sep 18, 2024

Draft PR

In the current example from SD Clip, there are 2 key kernels add_kernel and mul_add_kernel. After the computation of these kernels are completed, their results are fed into an mlir kernel which I think is a convolution.

I am implementing a scalar_mul_add in simplify_algebra. This will specifically handle where the multiplication operation is scalar and followed by a convolution. If these cases are met then directly perform a * (x+b) without expanding or rewriting it.

There needs to be an exception in mul_add. If the multiplication is scalar then it should not match and hence rewrite will not be done.

@@ -440,11 +440,44 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

if(a_ins->get_shape().scalar())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also check that it goes into convolution or gemm for the exception. It should also be done in the matcher.

@aarushjain29 aarushjain29 force-pushed the 3432-improve-simplify_algebra-to-find-more-horizontal-fusion-opportunities branch from e174012 to 7c2fdf5 Compare September 25, 2024 17:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve simplify_algebra to find more horizontal fusion opportunities
2 participants