
[BUG] Deepspeed inference does not support the Qwen model #4840

Closed
jiahe7ay opened this issue Dec 19, 2023 · 4 comments
Labels: bug (Something isn't working), inference

Comments

@jiahe7ay

jiahe7ay commented Dec 19, 2023

Describe the bug
I use deepspeed.init_inference to accelerate inference for the Qwen model. Compared with running the model without deepspeed.init_inference, I see no speedup.

Then I asserted that the Qwen transformer layers had been replaced with the DeepSpeedTransformerInference class, but the replacement did not happen.

I'm curious about one thing: Qwen is a decoder-only architecture, similar to the GPT model. Why does the initialization fail?

To Reproduce
The code is:

from time import perf_counter

import numpy as np
import torch
import deepspeed

def measure_latency(model, tokenizer, payload, generation_args, device):
    input_ids = tokenizer(payload, return_tensors="pt").input_ids.to(device)
    latencies = []
    # warm up
    for _ in range(2):
        _ =  model.generate(input_ids, **generation_args)
    # Timed run
    for _ in range(10):
        start_time = perf_counter()
        _ = model.generate(input_ids, **generation_args)
        latency = perf_counter() - start_time
        latencies.append(latency)
    # Compute run statistics
    time_avg_ms = 1000 * np.mean(latencies)
    time_std_ms = 1000 * np.std(latencies)
    time_p95_ms = 1000 * np.percentile(latencies,95)
    return f"P95 latency (ms) - {time_p95_ms}; Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f};", time_p95_ms


# init deepspeed inference engine
ds_model = deepspeed.init_inference(
    model=qwen_model,       # Transformers model
    mp_size=1,              # number of GPUs
    dtype=torch.float16,    # dtype of the weights (fp16)
    replace_method="auto",  # let DS automatically identify the layers to replace
    replace_with_kernel_inject=True,  # inject DeepSpeed's optimized inference kernels
)
print(f"model is loaded on device {ds_model.module.device}")

from deepspeed.ops.transformer.inference import DeepSpeedTransformerInference
assert isinstance(ds_model.module.transformer.h[0], DeepSpeedTransformerInference), "Model not successfully initialized"

# Test model
example = "My name is Philipp and I"
input_ids = tokenizer(example, return_tensors="pt").input_ids.to(ds_model.module.device)
logits = ds_model.generate(input_ids, do_sample=True, max_length=100)
tokenizer.decode(logits[0].tolist())

payload = (
    "Hello my name is Philipp. I am getting in touch with you because i didn't get a response from you. What do I need to do to get my new card which I have requested 2 weeks ago? Please help me and answer this email in the next 7 days. Best regards and have a nice weekend but it"
    * 2
)

print(f'Payload sequence length is: {len(tokenizer(payload)["input_ids"])}')

# generation arguments
generation_args = dict(do_sample=False, num_beams=1, min_length=128, max_new_tokens=128)
ds_results = measure_latency(ds_model, tokenizer, payload, generation_args, ds_model.module.device)

print(f"DeepSpeed model: {ds_results[0]}")

Expected behavior
Inference of the Qwen model is accelerated after wrapping it with deepspeed.init_inference.

System info (please complete the following information):

  • OS: Ubuntu 20.04
  • GPU count and types: 1x A800
  • DeepSpeed version: 0.11.2
  • transformers version: 4.34.0
  • torch version: 2.0.3
  • CUDA version: 11.1
  • Python version:
  • Any other relevant info about your setup:

Additional context
Qwen is a very popular large model, and I hope official support for it is released soon.
Qwen: https://github.com/QwenLM/Qwen

@jiahe7ay jiahe7ay added the bug (Something isn't working) and inference labels Dec 19, 2023
@zhudongwork

I have the same error.

Here is my code:

import os
import torch
import deepspeed
import numpy as np
import transformers

from time import perf_counter
from transformers import AutoTokenizer, AutoModelForCausalLM
from deepspeed.ops.transformer.inference import DeepSpeedTransformerInference

transformers.logging.set_verbosity_error()


model_name = "/mnt/model_repository/Qwen-14B-Chat"
payload = "hello"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

with deepspeed.OnDevice(dtype=torch.float16, device="meta", enabled=True):
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True)

ds_model = deepspeed.init_inference(
    model=model,      
    mp_size=2,        
    dtype=torch.float16, 
    replace_method="auto", 
    replace_with_kernel_inject=True, 
)

assert isinstance(ds_model.module.transformer.h[0], DeepSpeedTransformerInference), "Model not successfully initialized"


def test_inference():
    # run model inference
    input_ids = tokenizer(payload, return_tensors="pt").input_ids.to(model.device)
    logits = ds_model.generate(input_ids, do_sample=True, max_length=100)
    print(tokenizer.decode(logits[0].tolist()))


if __name__ == "__main__":
    test_inference()

command:

deepspeed --num_nodes 1  --num_gpus 2 --master_port 3600 --hostfile hostfile qwen-ds.py

error info:

Traceback (most recent call last):
  File "/mnt/deepspeed/qwen-ds.py", line 23, in <module>
    ds_model = deepspeed.init_inference(
               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/deepspeed/__init__.py", line 342, in init_inference
    engine = InferenceEngine(model, config=ds_inference_config)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/deepspeed/inference/engine.py", line 173, in __init__
    self.module.to(device)
  File "/root/miniconda3/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2460, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 857, in _apply
    self._buffers[key] = fn(buf)
                         ^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!

ds_report output:

[2023-12-26 08:01:48,938] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.1
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/root/miniconda3/lib/python3.11/site-packages/torch']
torch version .................... 2.1.0+cu121
deepspeed install path ........... ['/root/miniconda3/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.12.5, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 58.97 GB

@ZonePG
Contributor

ZonePG commented Jan 14, 2024

refer to #4913

Install DeepSpeed from the latest source code and consider utilizing DeepSpeed-MII for optimal performance.
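
For reference, a minimal DeepSpeed-MII (FastGen) sketch might look like the following; the checkpoint path, prompt, and max_new_tokens value are placeholders to adjust for your setup, and this assumes deepspeed-mii is installed (pip install deepspeed-mii):

import mii

# Build a non-persistent text-generation pipeline backed by DeepSpeed-FastGen.
# The model path below is a placeholder for a local Qwen chat checkpoint.
pipe = mii.pipeline("/path/to/Qwen-14B-Chat")

# Generate a completion; max_new_tokens caps the number of generated tokens.
responses = pipe(["DeepSpeed is"], max_new_tokens=128)
print(responses[0].generated_text)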

@rayquazaMega

rayquazaMega commented Jan 15, 2024

refer to #4913

Install DeepSpeed from the latest source code and consider utilizing DeepSpeed-MII for optimal performance.

@ZonePG Thank you for your awesome work! I tried the following code (as jiahe7ay did):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import deepspeed
from deepspeed.ops.transformer.inference import DeepSpeedTransformerInference
import torch
#import time
tokenizer = AutoTokenizer.from_pretrained("/data3/mnt/Qwen-VL-master/checkpoints/qwen/Qwen-VL-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("/data3/mnt/Qwen-VL-master/checkpoints/qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True).eval()
inputs = tokenizer('DeepSpeed is', return_tensors='pt')
inputs = inputs.to(model.device)

model = deepspeed.init_inference(
    model=model,          
    dtype=torch.float16, 
    replace_method="auto", 
    replace_with_kernel_inject=True, 
)

assert isinstance(model.module.transformer.h[0], DeepSpeedTransformerInference), "Model not successfully initialized"

and found that the assert still fails and inference does not speed up. Is it because of Qwen-VL’s vision head? How should I use DeepSpeed acceleration correctly?

@ZonePG
Contributor

ZonePG commented Jan 15, 2024

@rayquazaMega I have not used the VL model; I think it's not well supported currently. For the Chat or Base model, I would recommend deepspeed-mii, which is described in https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen and https://github.com/microsoft/DeepSpeed-MII
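
As a rough sketch of the persistent-deployment path in deepspeed-mii (the model path and deployment name below are placeholders, and this again assumes deepspeed-mii is installed):

import mii

# Start a persistent DeepSpeed-MII inference server for a local Qwen chat checkpoint.
mii.serve("/path/to/Qwen-14B-Chat", deployment_name="qwen-deployment")

# Connect a client to the running deployment and generate text.
client = mii.client("qwen-deployment")
response = client.generate("DeepSpeed is", max_new_tokens=128)
print(response)

# Shut the server down when finished.
client.terminate_server()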
