Fix torch.onnx.export of Qwen2-VL vision encoder (huggingface#34852)
* Fix torch.onnx.export of Qwen2-VL vision encoder

This PR fixes ONNX export support for the vision encoder of Qwen2-VL. The current code casts `cu_seqlens` to `torch.int32`, which leads to errors later on when those values are used for slicing.

https://github.com/huggingface/transformers/blob/c57eafdaa119eecae8557be4c626629bc1adc0fd/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1044-L1046
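For reference, the lines linked above compute `cu_seqlens` like this (the same snippet the diff at the bottom of this commit changes):

```py
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32  # this int32 cast is what breaks ONNX shape inference
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
```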

## Error:
```
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Slice, node name: /blocks.0/attn/Slice_4): axes has inconsistent type tensor(int64)
```

## Code to reproduce the issue:
```py
import requests
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
)

# Constants
VISION_MODEL_NAME = "vision_encoder.onnx"

# Load model and processor
model_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Prepare inputs
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
conversation = [
    {
        "role": "user",
        "content": [
            { "type": "image" },
            { "type": "text", "text": "Describe this image."},
        ],
    },
]
images = [image]
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=[text_prompt], images=images, padding=True, return_tensors="pt")

## Vision model
vision_inputs = dict(
    pixel_values=inputs["pixel_values"],
    grid_thw=inputs["image_grid_thw"],
)
vision_inputs_positional = tuple(vision_inputs.values())
vision_outputs = model.visual.forward(*vision_inputs_positional)  # Test forward pass
torch.onnx.export(
    model.visual,
    args=vision_inputs_positional,
    f=VISION_MODEL_NAME,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(vision_inputs.keys()),
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {
            0: "batch_size * grid_t * grid_h * grid_w",
            1: "channel * temporal_patch_size * patch_size * patch_size",
        },
        "grid_thw": {0: "batch_size"},
        "image_features": {0: "batch_size * grid_t * grid_h * grid_w"},
    },
)

# Load and check the exported model
import onnx
model = onnx.load(VISION_MODEL_NAME)
onnx.checker.check_model(model, full_check=True)
inferred = onnx.shape_inference.infer_shapes(model, check_type=True)
```
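
As an optional sanity check (not part of the original repro; a minimal sketch assuming `onnxruntime` and `numpy` are installed, reusing `vision_inputs` and `vision_outputs` from the script above, with illustrative tolerances), the exported graph can be compared against the PyTorch forward pass:

```py
# Continuation of the script above: run the exported model with ONNX Runtime
# and compare its output to the PyTorch forward pass.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(VISION_MODEL_NAME, providers=["CPUExecutionProvider"])
(onnx_features,) = session.run(
    ["image_features"],
    {name: tensor.numpy() for name, tensor in vision_inputs.items()},
)
np.testing.assert_allclose(
    vision_outputs.detach().numpy(), onnx_features, rtol=1e-3, atol=1e-4
)
```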

* Formatting

* [run-slow] qwen2_vl
xenova authored Nov 26, 2024
1 parent d5cf91b commit 1f6b423
Showing 1 changed file with 1 addition and 1 deletion.
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (1 addition, 1 deletion):

```diff
@@ -1025,7 +1025,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0, dtype=torch.int32
+            dim=0, dtype=grid_thw.dtype
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
```
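The fix works because `grid_thw` is built as an ordinary index tensor and is therefore `torch.int64` by default, so `cu_seqlens` now stays int64 instead of being cast to `torch.int32`, and the Slice inputs the exporter emits have a consistent type. A standalone sketch of the computation with illustrative values:

```py
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 2, 2]])  # one image: (grid_t, grid_h, grid_w); int64 by default
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2],  # patches per frame: grid_h * grid_w
    grid_thw[:, 0],                   # repeated grid_t times
).cumsum(dim=0, dtype=grid_thw.dtype)  # keeps int64 instead of casting to int32
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)  # prepend 0 -> cumulative boundaries
print(cu_seqlens)  # tensor([0, 4])
```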
