Fix torch.onnx.export of Qwen2-VL vision encoder (huggingface#34852)
* Fix torch.onnx.export of Qwen2-VL vision encoder

This PR fixes ONNX export support for the vision encoder of Qwen2-VL, whose attention implementation casts `cu_seqlens` to `torch.int32`, leading to errors later on when those values are used for slicing (see the sketch at the end of this description).

https://github.com/huggingface/transformers/blob/c57eafdaa119eecae8557be4c626629bc1adc0fd/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1044-L1046

## Error:

```
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Slice, node name: /blocks.0/attn/Slice_4): axes has inconsistent type tensor(int64)
```

## Code to reproduce issue:

```py
import requests
from PIL import Image
import torch
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
)

# Constants
VISION_MODEL_NAME = "vision_encoder.onnx"

# Load model and processor
model_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Prepare inputs
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
images = [image]
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=[text_prompt], images=images, padding=True, return_tensors="pt")

## Vision model
vision_inputs = dict(
    pixel_values=inputs["pixel_values"],
    grid_thw=inputs["image_grid_thw"],
)
vision_inputs_positional = tuple(vision_inputs.values())

# Test forward pass
vision_outputs = model.visual.forward(*vision_inputs_positional)

torch.onnx.export(
    model.visual,
    args=vision_inputs_positional,
    f=VISION_MODEL_NAME,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(vision_inputs.keys()),
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {
            0: "batch_size * grid_t * grid_h * grid_w",
            1: "channel * temporal_patch_size * patch_size * patch_size",
        },
        "grid_thw": {0: "batch_size"},
        "image_features": {0: "batch_size * grid_t * grid_h * grid_w"},
    },
)

# Load and check the exported model
import onnx

model = onnx.load(VISION_MODEL_NAME)
onnx.checker.check_model(model, full_check=True)
inferred = onnx.shape_inference.infer_shapes(model, check_type=True)
```

* Formatting

* [run-slow] qwen2_vl
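## Sketch of the idea behind the fix:

The authoritative change is at the linked lines above; the following is only a rough sketch of the approach. The helper name `make_cu_seqlens` is illustrative, and gating the dtype on `torch.jit.is_tracing()` is one assumed way to detect the export path. The point is that `torch.int32` is kept for the runtime (flash-attention) path, while the traced/exported graph keeps `grid_thw`'s dtype (`int64`) so that later `Slice` ops see consistent types.

```py
import torch
import torch.nn.functional as F

def make_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor:
    """Cumulative sequence lengths for the vision attention blocks.

    grid_thw: (num_images, 3) int64 tensor of (t, h, w) patch grid sizes.
    """
    cu_seqlens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(
        dim=0,
        # torch.int32 matches what the runtime attention kernels expect, but the
        # exported graph later slices with these values, and ONNX shape inference
        # rejects the resulting int32/int64 mix. While tracing for export, keep
        # grid_thw's dtype (int64) instead.
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    # Prepend a leading zero so cu_seqlens[i]:cu_seqlens[i+1] spans sequence i.
    return F.pad(cu_seqlens, (1, 0), value=0)
```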