
add lm_head and embed_out tensor parallel #3962

Merged: 4 commits merged into microsoft:master on Oct 9, 2023

Conversation

Yejing-Lai (Contributor)

This PR adds tensor parallelism for the lm_head and embed_out layers. It applies to models whose last layer is named lm_head or embed_out, such as LLaMA, GPT-J, BLOOM, OPT, and so on.

RezaYazdaniAminabadi (Contributor) commented Jul 14, 2023

Hi @Yejing-Lai

Thanks for the nice PR adding parallelism to the last embedding linear layer. I was just wondering how much of a performance impact this has, considering the communication overhead. Also, you can split the weight in different ways: either across the hidden_dim, which requires an all-reduce (what you did here), or across the embedding dimension, which then needs an all-gather. Either way has its own tradeoffs: with the embedding-dim split we can reduce the amount of computation per rank when the vocabulary is very large, while with sharding across the hidden dimension we reduce the amount of communication in this case. I think having both of these options makes sense, and we can decide later which one to enable.
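
For illustration, here is a minimal sketch of the two layouts described above, assuming an lm_head weight of shape [vocab_size, hidden_size]; the function names and signatures are illustrative and not part of this PR:

import torch
import torch.distributed as dist

def lm_head_hidden_split(hidden_states, weight, rank, world_size, group):
    # Shard along the hidden dimension: each rank multiplies its slice of the
    # activations with its slice of the weight, and the partial products are summed.
    shard = hidden_states.shape[-1] // world_size
    partial = torch.matmul(hidden_states[..., rank * shard:(rank + 1) * shard],
                           weight[:, rank * shard:(rank + 1) * shard].t())
    dist.all_reduce(partial, group=group)   # sum partial logits across ranks
    return partial                          # [batch, seq, vocab_size]

def lm_head_vocab_split(hidden_states, weight, rank, world_size, group):
    # Shard along the vocab (embedding) dimension: each rank produces logits
    # for its slice of the vocabulary, then the slices are concatenated.
    shard = weight.shape[0] // world_size
    local = torch.matmul(hidden_states, weight[rank * shard:(rank + 1) * shard].t())
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)
    return torch.cat(gathered, dim=-1)      # [batch, seq, vocab_size]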

Yejing-Lai (Contributor, Author)

Hi @RezaYazdaniAminabadi. We observed an end-to-end performance gain on SPR after adding lm_head/embed_out TP. I think the amount of computation saved by splitting the weight along hidden_dim or embedding_dim is the same: the weight is distributed equally across ranks either way, so each rank performs (hidden_size x vocab_size) / world_size multiply-accumulates per token instead of the full hidden_size x vocab_size. Thanks~

output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard],
                      self.weight.transpose(-1, -2))
if self.mp_group is not None:
    dist.all_reduce(output, group=self.mp_group)
Collaborator:

Should use inference_all_reduce since #3919 is merged.
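
For context, applying that suggestion to the snippet above would look roughly like this (a sketch, assuming dist here is deepspeed.comm, where inference_all_reduce was added in #3919):

# swap the generic collective for the inference-optimized one
if self.mp_group is not None:
    dist.inference_all_reduce(output, group=self.mp_group)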

def forward(self, input):
    assert input.shape[-1] % self.world_size == 0, \
        'Please ensure that input.shape[-1] is divisible by self.world_size'
    input_shard = input.shape[-1] // self.world_size
Collaborator:

We can actually make this work when input.shape[-1] is not divisible by world_size. We also need to consider the relationship between this PR and the follow-up PR #4011.
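
One way the divisibility requirement could be relaxed, sketched purely as an illustration rather than the approach taken in #4011, is to compute per-rank bounds so the remainder is spread over the first ranks:

def shard_bounds(dim_size, rank, world_size):
    # Uneven sharding: the first (dim_size % world_size) ranks get one extra element.
    base, rem = divmod(dim_size, world_size)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return start, end

# In forward(), slice with per-rank bounds instead of a fixed shard size,
# provided the weight was sharded with the same bounds at load time:
#   start, end = shard_bounds(input.shape[-1], self.rank, self.world_size)
#   output = torch.matmul(input[..., start:end], self.weight.transpose(-1, -2))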

delock (Collaborator) commented Jul 24, 2023

Dividing along the hidden dim would be a simpler solution. It is consistent with the other layers, and we can reuse inference_all_reduce to further reduce communication latency. A low communication size also makes it more friendly to scale-out scenarios.

delock (Collaborator) commented Aug 16, 2023

Hi @RezaYazdaniAminabadi, is this PR still under review, or does the embedding parallel method still need more consideration?

if hasattr(replaced_module, "lm_head") and hasattr(replaced_module.lm_head,
                                                   "weight") and not replaced_module.lm_head.weight.is_meta:
    replaced_module = replace_fn(replaced_module, ("lm_head", ), 0, "lm_head")
elif hasattr(replaced_module, "embed_out") and hasattr(replaced_module.embed_out,
                                                       "weight") and not replaced_module.embed_out.weight.is_meta:
    replaced_module = replace_fn(replaced_module, ("embed_out", ), 0, "embed_out")
Contributor:

Hi @RezaYazdaniAminabadi @delock, I refactored some code here so that it no longer integrates so deeply with the existing replace logic. It now replaces the lm_head or embed_out Linear after all other replace/load logic has finished, which is more decoupled and cleaner.

dc3671 (Contributor) commented Aug 16, 2023

About the performance: in our simple test it gives around a 10% end-to-end improvement in per-token latency for BLOOM-176B.

dc3671 (Contributor) commented Aug 16, 2023

Also, if we need to make this feature optional, I think we need to add an item to DeepSpeed's config. Maybe this needs further discussion?
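
For example, such a switch might be exposed through the inference config along these lines; note that the key name "replace_lm_head" is purely hypothetical and not an existing DeepSpeed option:

# Hypothetical sketch: "replace_lm_head" is an invented key, not a real DeepSpeed
# config entry; the surrounding fields follow the usual inference config layout.
inference_config = {
    "dtype": "fp16",
    "tensor_parallel": {"tp_size": 2},
    "replace_lm_head": True,  # invented: opt in to lm_head/embed_out tensor parallelism
}
# engine = deepspeed.init_inference(model, config=inference_config)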

RezaYazdaniAminabadi (Contributor) left a comment

All LGTM, thanks everyone for making this part of inference faster :)

dc3671 (Contributor) commented Oct 7, 2023

Hi @tjruwase, I have fixed the merge conflict.

dc3671 (Contributor) commented Oct 9, 2023

Hi @tjruwase, the CI problems are fixed.

@tjruwase added this pull request to the merge queue on Oct 9, 2023
Merged via the queue into microsoft:master with commit 6763e2d on Oct 9, 2023
15 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Oct 9, 2023
* add lm_head and embed_out tensor parallel

* fix load lm_head.weight name issue

* replace all_reduce with inference_all_reduce

* refactor lm_head tensor parallel

---------

Co-authored-by: Chen, Zhenhuan <[email protected]>
baodii pushed a commit to baodii/DeepSpeed that referenced this pull request Nov 7, 2023
jianan-gu added a commit to jianan-gu/DeepSpeedSYCLSupport that referenced this pull request Jan 5, 2024