Add support of Qwen models (7b, 14b, 72b) to DeepSpeed-FastGen #4913
Merged
Conversation
@microsoft-github-policy-service agree
@ZonePG thank you for this contribution! Could you please run pre-commit and commit any modified files?
@mrwyattii |
mrwyattii approved these changes on Jan 9, 2024
github-merge-queue bot removed this pull request from the merge queue due to failed status checks on Jan 10, 2024
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request on Feb 17, 2024:
Add support of Qwen models (7b, 14b, 72b) to DeepSpeed-FastGen (microsoft#4913)

This PR adds support for the Qwen models 7b, 14b, and 72b.

### Test Code

For the MII pipeline:

```python
from mii import pipeline

pipe = pipeline("Qwen/Qwen-7B-Chat")
# 151643 is Qwen's <|endoftext|> token id; set it so generation stops correctly.
pipe.tokenizer.tokenizer.eos_token_id = 151643
output = pipe(["DeepSpeed is"], max_new_tokens=128, do_sample=False)
print(output)
```

For Huggingface:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto",
                                             trust_remote_code=True, fp16=True).eval()
inputs = tokenizer('DeepSpeed is', return_tensors='pt')
inputs = inputs.to(model.device)
pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0)
test = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
print(test)
```

### Qwen 7B

Huggingface output with prompt "DeepSpeed is":

```
a high-performance, low-latency database management system designed for real-time analytics and machine learning applications. It is built on top of Apache Arrow, a columnar in-memory data format, and is optimized for processing large volumes of data in parallel.\n\nDeepSpeed offers several key features that make it well-suited for real-time analytics and machine learning applications:\n\n1. High Performance: DeepSpeed is designed to deliver high performance by leveraging parallel processing and optimized data structures. It can process large volumes of data in real-time, making it ideal for applications that require real-time analytics.\n\n2. Low Latency: DeepSpeed is designed to minimize latency by
```

DeepSpeed-FastGen output with prompt "DeepSpeed is":

```
a high-performance, low-latency database management system designed for real-time analytics and machine learning applications. It is built on top of Apache Arrow, a columnar in-memory data format, and is optimized for processing large volumes of data in parallel.\n\nDeepSpeed offers several key features that make it well-suited for real-time analytics and machine learning applications:\n\n1. High Performance: DeepSpeed is designed to deliver high performance by leveraging parallel processing and optimized data structures. It can process large volumes of data in real-time, making it ideal for applications that require real-time analytics.\n\n2. Low Latency: DeepSpeed is designed to minimize latency by
```

### Qwen 72B

Huggingface output with prompt "DeepSpeed is" (the 72B chat model answers in Chinese):

```
是一个开源的深度学习优化库,它提供了多种优化技术,包括模型并行、数据并行、混合并行、ZeRO内存优化等。它可以帮助用户在大规模GPU集群上训练深度学习模型,提高训练速度,减少内存使用。\n在Deepspeed中,模型并行是一种将模型的不同部分分配到不同的GPU上的技术。这样可以处理模型太大,无法放在一个GPU上的问题。数据并行是将数据集分成多个部分,每个部分在不同的GPU上进行训练。混合并行则是结合了模型并行和数据并行,以更有效地利用GPU资源
```

DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding:

```
是一个开源的深度学习优化库,它提供了多种优化技术,包括模型并行、数据并行、混合并行、ZeRO内存优化等。它可以帮助用户在大规模GPU集群上训练深度学习模型,提高训练速度,减少内存使用。\n在Deepspeed中,模型并行是一种将模型的不同部分分配到不同的GPU上的技术。这样可以处理模型太大,无法放在一个GPU上的问题。数据并行是将数据集分成多个部分,每个部分在不同的GPU上进行训练。混合并行则是结合了模型并行和数据并行,以更有效地利用GPU资源
```

In both the 7B and 72B tests the DeepSpeed-FastGen output matches the Huggingface baseline exactly. The 72B answer translates roughly as: "is an open-source deep learning optimization library that provides a variety of optimization techniques, including model parallelism, data parallelism, hybrid parallelism, and ZeRO memory optimization. It helps users train deep learning models on large-scale GPU clusters, improving training speed and reducing memory usage. In DeepSpeed, model parallelism places different parts of a model on different GPUs, handling models too large to fit on a single GPU; data parallelism splits the dataset into parts trained on different GPUs; hybrid parallelism combines the two to use GPU resources more efficiently."

---------

Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
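For anyone adapting these tests, the same check can also be run against a persistent DeepSpeed-FastGen deployment rather than the one-off pipeline. A minimal sketch using MII's `serve` API (this script is illustrative and not part of the PR; the generation arguments simply mirror the pipeline test above):

```python
import mii

# Stand up a persistent FastGen server for the same Qwen checkpoint
# exercised by the pipeline test above.
client = mii.serve("Qwen/Qwen-7B-Chat")

# Query the running deployment with the same prompt.
response = client.generate(["DeepSpeed is"], max_new_tokens=128)
print(response)

# Tear the server down when finished.
client.terminate_server()
```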
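The 8-way sharded 72B run can be reproduced by starting the pipeline script with the DeepSpeed launcher, which spawns one process per GPU and lets MII shard the model across them. A sketch, assuming the test is saved as `run_qwen72b.py` (the filename and the 72B-Chat checkpoint name are illustrative):

```python
# run_qwen72b.py
# Launch with: deepspeed --num_gpus 8 run_qwen72b.py
from mii import pipeline

pipe = pipeline("Qwen/Qwen-72B-Chat")
# Same <|endoftext|> override as the 7B test above; verify the id for your checkpoint.
pipe.tokenizer.tokenizer.eos_token_id = 151643
output = pipe(["DeepSpeed is"], max_new_tokens=128, do_sample=False)
print(output)
```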