
remove weight parallelism #137

Merged: 6 commits merged into databricks:main on Aug 12, 2024

Conversation

@eitanturok (Contributor) commented Aug 9, 2024

What does this PR do?

This PR removes weight parallelism as we never use it. Tagging @tgale96.

Since we use FSDP's weight parallelism rather than MegaBlocks' own custom implementation, I wanted to remove the custom weight parallelism code.

Specifically, we

  1. Remove test_parallelism.py, since it only tests that weight parallelism and expert parallelism produce the same results, which is moot once weight parallelism is gone.
  2. Remove moe_weight_parallelism and weight_parallel_group from the args.
  3. Remove weight parallelism from all the layers.

Because moe_weight_parallelism defaults to False and weight_parallel_group defaults to None, mpu.get_weight_parallel_world_size(args) always returned 1 and mpu.get_weight_parallel_rank(args) always returned 0. This allowed us to drastically simplify mlp.create_dmoe_expert_weights(); see the sketch below.
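
To make that concrete, here is a minimal, hypothetical sketch (not the actual MegaBlocks source; function names and shapes are illustrative) of why a fixed world size of 1 and rank of 0 lets the weight-creation path drop its sharding logic entirely:

```python
import torch

def create_expert_weights_before(num_experts, rows, columns, init_method):
    # Original-style logic: allocate the full weight tensor, then slice out
    # the shard owned by this weight-parallel rank.
    world_size = 1  # mpu.get_weight_parallel_world_size(args) always returned 1
    rank = 0        # mpu.get_weight_parallel_rank(args) always returned 0
    master = torch.empty(num_experts, rows, columns)
    init_method(master)
    rows_per_rank = rows // world_size
    start = rank * rows_per_rank
    return master[:, start:start + rows_per_rank, :].clone()

def create_expert_weights_after(num_experts, rows, columns, init_method):
    # With world_size fixed at 1 and rank at 0, the slice above is always the
    # whole tensor, so the sharding code can simply be deleted.
    weights = torch.empty(num_experts, rows, columns)
    init_method(weights)
    return weights
```

In this sketch, both functions produce the same weights under the same RNG state, which is the sense in which the simplification is behavior-preserving.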

Also, can I get an extra close review of my changes to the MemoryOptimizedMLP.parallel_forward() method? I noticed that the group would always be None there, but I am hesitant to hard-code that and am not sure it is the right thing to do (see the illustration below).
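
For context on the group=None concern, here is a small hedged illustration (a hypothetical helper, not MegaBlocks code): in torch.distributed, passing group=None to a collective means "use the default (world) process group", so hard-coding None only changes behavior if a non-default weight-parallel group could ever be passed in.

```python
import torch
import torch.distributed as dist

def gather_shards(shard: torch.Tensor, group=None) -> torch.Tensor:
    # group=None falls back to the default (world) process group in
    # torch.distributed, for both get_world_size and all_gather.
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard, group=group)
    return torch.cat(gathered, dim=0)
```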

Also, I ran all tests locally and they pass.

(Also, enjoy this nice PR template that I added!)

What issue(s) does this change relate to?

Before submitting

  • Have you read the contributor guidelines?
  • Is this change a documentation change or typo fix? If so, skip the rest of this checklist.
  • Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you update any related tests and add any new tests related to your change? (see testing)
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@tgale96 (Contributor) commented Aug 9, 2024

So this looks great but I think you can remove a lot more code!

weight_parallel.py can be completely removed, I think, along with all uses of it, including the parallel_forward function you called out, which I don't think is used anywhere anyway.

Deleting that file will help you track down every use of weight parallelism as well, since it all gets routed into that one.

@eitanturok (Contributor, Author) commented Aug 9, 2024

  • removed parallel_forward
  • deleted weight_parallel.py
  • all tests still pass
  • After discussing with the repo maintainers, we decided to delete weight parallelism rather than deprecate it

megablocks/layers/mlp.py: review thread resolved (outdated)
@tgale96 (Contributor) commented Aug 12, 2024

LGTM! The last thing I might do is to just grep weight_parallel, if you haven't already. But I think you got everything.

@eitanturok (Contributor, Author) commented Aug 12, 2024

Already grepped! Will merge this later today.

@mvpatel2000 (Contributor) left a comment

LGTM

@mihir-db merged commit 27d3d2c into databricks:main on Aug 12, 2024 (3 checks passed).
@eitanturok deleted the weight-parallelism branch on August 20, 2024.