Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[luci/pass] Refactor FuseAddWithFullyConnectedPass #13846

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

jiwaszki
Copy link
Contributor

This commit changes the order of searching for the pattern.

ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz [email protected]

This commit changes the order of searching for the pattern.

ONE-DCO-1.0-Signed-off-by: Jan Iwaszkiewicz <[email protected]>
@jiwaszki
Copy link
Contributor Author

For: #13685

@jiwaszki jiwaszki marked this pull request as ready for review September 2, 2024 15:40
@jiwaszki jiwaszki changed the title [DRAFT][luci/pass] Refactor FuseAddWithFullyConnectedPass [luci/pass] Refactor FuseAddWithFullyConnectedPass Sep 3, 2024
@jiwaszki jiwaszki added the PR/ready for review It is ready to review. Please review it. label Sep 3, 2024

auto fused_bias = luci::clone(addition);

// Add existing bias values
if (auto const_bias = dynamic_cast<luci::CircleConst *>(fc->bias()))
{
assert(const_bias->dtype() == loco::DataType::FLOAT32);
RETURN_FALSE_UNLESS(const_bias->dtype() == loco::DataType::FLOAT32);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check this inside common_pass_checks? luci::clone adds a node into the graph (graph is modified), so all checks had to be done before luci::clone.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this check is specific to float32.. It would be better to use a pattern class (PTAL how FuseGeluPass finds different patterns).

Copy link
Contributor Author

@jiwaszki jiwaszki Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jinevening could this refactor be done in separate PR? In my opinion, it would be better to push it as-is and create another PR that will introduce pattern search like FuseGeluPass.

jinevening
jinevening previously approved these changes Sep 24, 2024
Copy link
Contributor

@jinevening jinevening left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

if (not(cond)) \
return false;

bool fc_with_add_pattern_check(const loco::DataType dtype, luci::CircleAdd **add,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) why use pointers of pointer? add is not updated here.
It would be better not to use double pointer if not necessary
and also use reference if need update in this method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@seanshpark as for add I fully understand the concern. This is changed.
As for double pointers, it was all required by luci::fill to get proper handling (i.e. do not loose pointer locally in the function, which resulted in segfaults). Now I propose use of struct to keep track of nodes in the pattern. I see it as proper middle-ground solution, that also organize the required nodes in one place. What do you think?


RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q) Why are you using const luci::CircleAdd & ?


RETURN_FALSE_UNLESS(luci::fill(fc, addition).with_commutative_args_of(*add));
bool fc_with_add_pattern_check(const loco::DataType dtype, const luci::CircleAdd &add,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing loco::DataType,

template <typename DT>
bool fc_with_add_pattern_check(const luci::CircleAdd &add,

?

// TODO Support scalar addition
if (rank == 0)
return false;
PatternNodes nodes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems you want to find members of this nodes.
there are similar styles in other Pass that you can refer.
plz check compiler/luci/pass/src/FuseRmsNormPass.cpp that was recently added.

@seanshpark
Copy link
Contributor

@jiwaszki , are you still willing to work on this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
PR/ready for review It is ready to review. Please review it.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants