Qwen2-VL Batch Bug #2495

Open · LugerW-A opened this issue Nov 25, 2024 · 17 comments
Labels: bug (Something isn't working), triaged (Issue has been triaged by maintainers)
@LugerW-A commented Nov 25, 2024

System Info

x86
TensorRT-LLM 0.16.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Qwen2-VL examples

Expected behavior

Does Qwen2-VL support batched prompts?
When the input is a batch, only the first result returns correctly; the rest are all empty.
print(input_ids.shape)
print(prompt_table.shape)
print(prompt_tasks)
outputs = self.model.generate(
    input_ids,
    input_position_ids=None,
    mrope_params=mrope_params,
    sampling_config=None,
    prompt_table=prompt_table,
    prompt_tasks=prompt_tasks,
    max_new_tokens=max_new_tokens,
    end_id=end_id,
    pad_id=self.model.tokenizer.pad_token_id
    if self.model.tokenizer.pad_token_id is not None else
    self.model.tokenizer.all_special_ids[0],
    top_k=self.args.top_k,
    top_p=self.args.top_p,
    temperature=self.args.temperature,
    repetition_penalty=self.args.repetition_penalty,
    num_beams=self.args.num_beams,
    output_sequence_lengths=True,
    return_dict=True)

Actual behavior

The input_ids only differ in the first (batch) dimension, but the results after the first are incorrect (empty).

Additional notes

none

@LugerW-A LugerW-A added the bug Something isn't working label Nov 25, 2024
@hello-11 hello-11 added the triaged Issue has been triaged by maintainers label Nov 25, 2024
@sunnyqgg (Collaborator)
Hi @LugerW-A, it supports batch inference, but you need to follow the batching procedure from the official Qwen2-VL repo; see https://github.com/QwenLM/Qwen2-VL?tab=readme-ov-file for more info. For example:
# Assumed setup, following the linked README: the processor and the
# process_vision_info helper come from transformers and qwen_vl_utils.
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # example model id

messages1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "xxx/image1.jpg"},
            {"type": "text", "text": "Describe this picture?"},
        ],
    }
]
messages2 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "xxxx/image2.jpg"},
            {"type": "text", "text": "Describe this picture? And what kind of colour does it contain?"},
        ],
    }
]
messages = [messages1, messages2]
texts = [
    processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
    for msg in messages
]
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=texts,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")
inputs = inputs.to("cuda")
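For reference, the linked README continues the batched example by generating and decoding; a sketch, assuming model is the Hugging Face Qwen2VLForConditionalGeneration (not the TRT engine):

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens from each row before decoding.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )  # one decoded string per batch entry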

@sun2011yao commented Nov 29, 2024

@sunnyqgg Hi, following the approach above, the second output is empty. Have you printed the second output result?
When I run it here, the second output is all eos_token_id.

@sunnyqgg (Collaborator)
Hi @sun2011yao, do you specify --batch_size when running with multiple batches?

@sun2011yao

> Hi @sun2011yao, do you specify --batch_size when running with multiple batches?

Yes, I run the command as follows:

    python3 run.py \
        --hf_model_dir ./${MODEL_NAME} \
        --batch_size 2 \
        --image_path ./pics/demo.jpeg \
        --run_profiling \
        --max_new_tokens 50 \
        --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
        --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu/

@sunnyqgg (Collaborator)

Hi,
If you hard-code messages = [messages1, messages2] as above, please don't pass --image_path ./pics/demo.jpeg, otherwise it won't work. I'll add multi-batch support via multiple values for --image_path later.

@sun2011yao commented Nov 29, 2024

> Hi, if you hard-code messages = [messages1, messages2] as above, please don't pass --image_path ./pics/demo.jpeg, otherwise it won't work. I'll add multi-batch support via multiple values for --image_path later.

Hi, I removed --image_path, but the second result is still empty:
[['The image shows a woman sitting on a sandy beach with a dog. The dog is wearing a colorful harness and is sitting on its hind legs, giving a high-five to the woman. The woman is wearing a plaid shirt and is smiling. The'], ['']]

Can you get the correct results on your side?

@sun2011yao

@sunnyqgg Hi, have you been able to reproduce this problem?

@YSF-A commented Dec 2, 2024

I think I hit the same error.

> It supports batch inference, but you need to follow the batching procedure from the official Qwen2-VL repo; see https://github.com/QwenLM/Qwen2-VL?tab=readme-ov-file for more info.

@sunnyqgg Hi, I ran the command as follows:

    python3 run.py --hf_model_dir ${my_hf_model_dir} --visual_engine_dir ${my_visual_engine_dir} --llm_engine_dir ${my_llm_engine_dir} --kv_cache_free_gpu_memory_fraction 0.7 --batch_size 2

And I got an error: IndexError: index 1 is out of bounds for dimension 0 with size 1. This error is raised at https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/multimodal_model_runner.py#L1111

Therefore I added the following code at https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/runtime/multimodal_model_runner.py#L1741:

    image = image.repeat(self.args.batch_size, 1)
    image_grid_thw = image_grid_thw.repeat(self.args.batch_size, 1)
    input_ids = input_ids.repeat(self.args.batch_size, 1)
    attention_mask = attention_mask.repeat(self.args.batch_size, 1)

And the second result is empty.

Does anyone have an idea about this error? Thanks.

@hello-11 hello-11 assigned hello-11 and sunnyqgg and unassigned hello-11 Dec 2, 2024
@sun2011yao

@sunnyqgg Hi, I have some doubts about the following line of code:

    rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions

Should it be modified to

    rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions * params.half_rotary_dim

? But even after the modification, the second output result is still wrong.
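(For intuition, a minimal Python sketch of the offset arithmetic; the sizes below are made up, and the names just mirror the kernel parameters above:)

    # Hypothetical sizes, mirroring the kernel parameters above.
    rotary_embedding_max_positions = 4
    half_rotary_dim = 3
    batch_idx = 1

    # The sin/cos cache is laid out per batch entry as
    # [max_positions, half_rotary_dim], so one entry spans
    # max_positions * half_rotary_dim elements of the flat buffer.
    entry_stride = rotary_embedding_max_positions * half_rotary_dim

    buggy_offset = batch_idx * rotary_embedding_max_positions  # 4
    correct_offset = batch_idx * entry_stride                  # 12

    # The buggy offset still points inside batch 0's cache, so every
    # batch entry after the first reads the wrong rotary coefficients.
    assert buggy_offset < entry_stride <= correct_offset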

@sunnyqgg (Collaborator) commented Dec 3, 2024

Hi,
rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions * params.half_rotary_dim is right; I have already changed it internally.
And I can't reproduce the issue you hit; multi-batch works fine on my side:

[11/27/2024-05:24:31] [TRT-LLM] [I] ---------------------------------------------------------
[11/27/2024-05:24:31] [TRT-LLM] [I] 
[Q] None
[11/27/2024-05:24:31] [TRT-LLM] [I] 
[A]: ['The picture shows a red panda sitting on a wooden platform. The red panda has a white face with black markings around its eyes and ears. The panda']
[11/27/2024-05:24:31] [TRT-LLM] [I] 
[A]: ['The picture shows a bamboo plant with green leaves and stems. The color of the bamboo is green.']

Since I don't have your code, can you check the shapes of input_ids, mrope_params, and output_ids in the file tensorrt_llm/runtime/multimodal_model_runner.py?

    output_ids = self.model.generate(
        input_ids,
        input_position_ids=input_position_ids
        if self.model_type == 'cogvlm' else None,
        mrope_params=mrope_params
        if self.model_type == 'qwen2_vl' else None,
        sampling_config=None,
        prompt_table=prompt_table,
        prompt_tasks=prompt_tasks,
        max_new_tokens=max_new_tokens,
        end_id=end_id,
        pad_id=self.tokenizer.pad_token_id
        if self.tokenizer.pad_token_id is not None else
        self.tokenizer.all_special_ids[0],
        top_k=self.args.top_k,
        top_p=self.args.top_p,
        temperature=self.args.temperature,
        repetition_penalty=self.args.repetition_penalty,
        num_beams=self.args.num_beams,
        output_sequence_lengths=False,
        return_dict=False)

@sun2011yao

@sunnyqgg Hi, thank you for your reply. This issue is a bit strange. Below are my complete running steps and related modifications; can you help check them?

  1. Convert the checkpoint:

    MODEL_NAME=Qwen2-VL-2B-Instruct
    python3 ../qwen/convert_checkpoint.py \
        --model_dir=./${MODEL_NAME} \
        --output_dir=tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
        --dtype float16

  2. Build the engines:

    trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \
        --output_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu \
        --gemm_plugin=float16 \
        --gpt_attention_plugin=float16 \
        --max_batch_size=4 \
        --max_input_len=2048 --max_seq_len=3072 \
        --max_prompt_embedding_table_size=14208
    python build_visual_engine.py --model_type qwen2_vl --model_path ./${MODEL_NAME}

  3. Run:

    python3 run.py \
        --hf_model_dir ./${MODEL_NAME} \
        --batch_size 2 \
        --max_new_tokens 100 \
        --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
        --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/fp16/1-gpu/

Modifications to multimodal_model_runner.py:
messages1 = [{
    "role": "user",
    "content": [
        {"type": "image", "image": raw_image},
        {"type": "text", "text": input_text},
    ],
}]
messages2 = [{
    "role": "user",
    "content": [
        {"type": "image", "image": raw_image},
        {"type": "text", "text": input_text},
    ],
}]
messages = [messages1, messages2]
texts = [
    processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
    for msg in messages
]

input_ids shape: [2, 917]
mrope_params.mrope_rotary_sin_cos shape: [2, 4194304]
mrope_params.mrope_position_deltas shape: [2, 1]
output_ids shape: [2, 1, 3072]

Output text:
[12/03/2024-02:54:58] [TRT-LLM] [I]
[A]: ['The image shows a woman sitting on a sandy beach with a dog. The dog is wearing a colorful harness and is sitting on the sand, with its front paws raised in a high-five gesture. The woman is smiling and appears to be enjoying the moment. The background features the ocean with waves crashing onto the shore, and the sky is clear with a warm, golden hue, suggesting it might be either sunrise or sunset. The overall scene conveys a sense of relaxation and companionship between the woman']
[12/03/2024-02:54:58] [TRT-LLM] [I]
[A]: ['This image']

@alimgl-pixel

I have met a similar case to @sun2011yao's: with 2 images as input for Qwen2-VL-72B, the first returned text describes both images at the same time, and the second text is completely wrong for both images.

@YSF-A commented Dec 5, 2024

> It supports batch inference, but you need to follow the batching procedure from the official Qwen2-VL repo; see https://github.com/QwenLM/Qwen2-VL?tab=readme-ov-file for more info.

> rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions * params.half_rotary_dim is right; I have already changed it internally.

My case is a bit different. After the fix described above, the same two prompts generate the same result normally.
But if the two prompts are different, only one generates normally and the other is empty.

@sunnyqgg (Collaborator) commented Dec 5, 2024

Hi,
You can also apply the following modifications to fix it, or wait for the next public release:

[screenshot of the suggested code changes]

Thanks,
Sunny

@YSF-A commented Dec 6, 2024

> You can also apply the following modifications to fix it, or wait for the next public release:

@sunnyqgg Hi, with the above modifications, it seems that, for two different prompts, one prompt generates results for both images while the other generates nothing. I am not sure whether something is wrong on my end.

@alimgl-pixel commented Dec 7, 2024

@sunnyqgg Following the modifications, it seems it does not work; both results describe both images at the same time.

@sunnyqgg (Collaborator)

Hi,
Please use the latest main code. For multi-batch inference, if you find the second answer contains the first image's content, please update attention_mask_vit in tensorrt_llm/runtime/multimodal_model_runner.py:

           # Start from a fully-masked matrix (min float16 everywhere)...
           attention_mask_vit = torch.full([1, seq_length, seq_length],
                                           torch.finfo(torch.float16).min,
                                           device=image.device,
                                           dtype=image.dtype)
           # ...then unmask one diagonal block per image, so patches only
           # attend within their own cu_seqlens segment.
           for i in range(1, len(cu_seqlens)):
               attention_mask_vit[..., cu_seqlens[i - 1]:cu_seqlens[i],
                                  cu_seqlens[i - 1]:cu_seqlens[i]] = 0
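
(For intuition, a minimal standalone sketch; the patch counts n1 and n2 are assumed values. With two images of n1 and n2 patches, cu_seqlens is [0, n1, n1 + n2], and the loop above yields a block-diagonal mask, so one image's patches never attend to the other's:)

    import torch

    n1, n2 = 2, 3                    # assumed patch counts for two images
    cu_seqlens = [0, n1, n1 + n2]    # cumulative patch sequence lengths
    seq_length = n1 + n2
    mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float16).min)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
             cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    # (mask == 0) shows two diagonal blocks of ones: each image's patches
    # can only attend to patches of the same image.
    print((mask == 0).squeeze(0).int())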

Thanks.
