[WIP] Support open-clip onnx export #1466

Draft · wants to merge 16 commits into main

Conversation

isaac-chung

What does this PR do?

Fixes #1450

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@isaac-chung
Author

@fxmarty here's the WIP. Currently when I run the export command, I get

(v1) isaacchung@Isaacs-MBP optimum % optimum-cli export onnx -m laion/CLIP-ViT-B-32-laion2B-s34B-b79K --framework pt clip_onnx
Automatic task detection to zero-shot-image-classification.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
/Users/isaacchung/work/transformers/src/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.
  warnings.warn(
Using the export variant default. Available variants are:
        - default: The default ONNX variant.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Using framework PyTorch: 2.0.0
Traceback (most recent call last):
  File "/Users/isaacchung/virtualenv/v1/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/Users/isaacchung/work/optimum/optimum/commands/optimum_cli.py", line 163, in main
    service.run()
  File "/Users/isaacchung/work/optimum/optimum/commands/export/onnx.py", line 239, in run
    main_export(
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/__main__.py", line 505, in main_export
    _, onnx_outputs = export_models(
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/convert.py", line 752, in export_models
    export(
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/convert.py", line 855, in export
    export_output = export_pytorch(
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/convert.py", line 542, in export_pytorch
    dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/base.py", line 454, in generate_dummy_inputs
    dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/base.py", line 222, in _create_dummy_input_generator_classes
    first_inputs_gen = self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config, **kwargs)
  File "/Users/isaacchung/work/optimum/optimum/utils/input_generators.py", line 343, in __init__
    self.vocab_size = normalized_config.vocab_size
  File "/Users/isaacchung/work/optimum/optimum/utils/normalized_config.py", line 109, in __getattr__
    return super().__getattr__(attr_name)
  File "/Users/isaacchung/work/optimum/optimum/utils/normalized_config.py", line 69, in __getattr__
    raise AttributeError(f'Could not find the attribute named "{leaf_attr_name}" in the normalized config.')
AttributeError: Could not find the attribute named "vocab_size" in the normalized config.

I'm stepping away for a bit this evening but I will continue afterwards.

@fxmarty
Contributor

fxmarty commented Oct 18, 2023

@isaac-chung Apologies for the convoluted error. Here is a patch for this one: isaac-chung#2

There is then another error about the naming of the model inputs. One approach you could take is similar to

def rename_ambiguous_inputs(self, inputs):
&
inputs = self.rename_ambiguous_inputs(inputs)
so that we keep the same naming as transformers.

Hopefully the inputs/outputs of the open_clip model then match those of transformers as well, and we don't run into more issues.

@isaac-chung
Author

isaac-chung commented Oct 18, 2023

@fxmarty amazing, thank you! 🙌 I've merged your patch into this branch, and will continue reconciling the input/output naming differences.

In terms of tests, should we use one of these LAION models (the smallest)? Or something smaller?

@fxmarty
Contributor

fxmarty commented Oct 18, 2023

Yes, it appears there are no small open_clip models on the Hub. Probably the most reasonable option (if you don't want to go into the hassle of adding a tiny testing model) is to just add a nightly test (@slow) for open_clip with a somewhat "small" model.
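For reference, a rough sketch of what such a nightly test could look like, assuming a pytest slow marker and optimum's main_export entry point (the test name, marker, and model id are placeholders rather than the PR's final choices):

import pytest

from optimum.exporters.onnx import main_export


@pytest.mark.slow  # only run in the nightly/slow CI job
def test_open_clip_onnx_export(tmp_path):
    # "small-ish" LAION checkpoint; a tiny dedicated testing model could replace it later
    main_export(
        model_name_or_path="laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
        output=tmp_path,
        task="zero-shot-image-classification",
    )
    assert (tmp_path / "model.onnx").exists()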

@isaac-chung
Author

isaac-chung commented Oct 18, 2023

Yep, the @slow option sounds good. Perhaps the exercise of adding a tiny testing model can be separate.

For input mapping, the key matching was straightforward. However, it seems like there's some shape mismatch. It looks like CLIP tokenizers pad inputs to len=77 (see the Colab notebook), but inputs["input_ids"] has shape (batch_size, 16):

...
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/model_patcher.py", line 118, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/open_clip/model.py", line 248, in forward
    text_features = self.encode_text(text, normalize=True) if text is not None else None
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/open_clip/model.py", line 233, in encode_text
    x = x + self.positional_embedding.to(cast_dtype)
RuntimeError: The size of tensor a (16) must match the size of tensor b (77) at non-singleton dimension 1

It does seem like maybe_save_preprocessors was run and the preprocessors were saved, but maybe not initialized properly? Have you seen this kind of mismatch before?

tokenizer: CLIPTokenizerFast(name_or_path='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', vocab_size=49408, model_max_length=77, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
        49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
        49407: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

@fxmarty
Contributor

fxmarty commented Oct 19, 2023

@isaac-chung Thank you. It appears that the tokenize function from open_clip and the open_clip.tokenizer.HFTokenizer pad by default to tokenizer.tokenizer.model_max_length, while the Transformers CLIP tokenizer does not. It is unclear to me why there is this difference.
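A quick way to see the difference (the shapes in the comments are what I would expect from the two tokenizers, not values taken from this thread):

import open_clip
from transformers import AutoTokenizer

# open_clip pads every batch to the fixed context length (77 for CLIP models)
tokens = open_clip.tokenize(["a photo of a cat"])
print(tokens.shape)  # torch.Size([1, 77])

# the transformers CLIP tokenizer only pads to the longest item in the batch
hf_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
hf_tokens = hf_tokenizer(["a photo of a cat"], return_tensors="pt")
print(hf_tokens["input_ids"].shape)  # e.g. torch.Size([1, 7]), nowhere near 77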

Something you could try in the OpenCLIPOnnxConfig is first to remove the dynamic axis for the sequence length by redefining the inputs property:

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "text_batch_size"},
            "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "attention_mask": {0: "text_batch_size"},
        }

As for the shape itself used during the export, it should be held in these kwargs:

def generate_dummy_inputs(self, framework: str = "pt", **kwargs) -> Dict:

What I may suggest is to override the shape from the OpenCLIP ONNX config:

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        # override sequence_length shape here in the kwargs
        return super().generate_dummy_inputs(framework, **kwargs)

You will need to access the tokenizer, though, which hopefully you can find under the self._preprocessors attribute.
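For instance, something along these lines (a sketch only; it assumes the tokenizer does end up in self._preprocessors and that the base config accepts a sequence_length shape kwarg, as the snippet above suggests):

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        # open_clip pads text to the tokenizer's model_max_length (77 for CLIP),
        # so pin the dummy sequence length to that value instead of the default
        # dynamic one
        max_length = 77
        for preprocessor in self._preprocessors or []:
            if hasattr(preprocessor, "model_max_length"):
                max_length = preprocessor.model_max_length
                break
        kwargs["sequence_length"] = max_length
        return super().generate_dummy_inputs(framework=framework, **kwargs)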

@isaac-chung
Author

Thanks for the tips so far, @fxmarty! I am wondering if we could draw some inspiration from Pix2StructOnnxConfig, esp. def overwrite_shape_and_generate_input(...), where they also require padding inputs to the max length.

# pix2struct takes inputs whose so-called sequence length is **static** to max_patches, so we do NOT use
# the passed sequence_length that behaves as a dynamic shape.

Probably needs to be massaged to fit into our OpenCLIPOnnxConfig. Will investigate.

@fxmarty
Contributor

fxmarty commented Oct 20, 2023

@isaac-chung Yes, that is a good option as well! It is just that overwrite_shape_and_generate_input is not defined at the top-level OnnxConfig, so you would need to define it there (and basically just generate inputs, no override of shapes at the top level), and redefine it in OpenClipOnnxConfig to override the relevant shapes.

@isaac-chung
Author

I see what you mean now, thanks again! Will continue with the generate_dummy_inputs approach. Seems more straightforward 💪

@isaac-chung
Author

Seeing a different error, so I'll count this as progress :D It seems like this issue has been resolved, but specifying --opset 18 and installing torch nightly yielded the following

 File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/convert.py", line 883, in export
    config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/base.py", line 302, in fix_dynamic_axes
    session = InferenceSession(model_path.as_posix(), providers=providers, sess_options=session_options)
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ArgMax(13) node with name '/ArgMax'

This looks similar to this issue but doesn't seem to apply to the types here (read that we only deal with fp16).

@fxmarty
Contributor

fxmarty commented Oct 23, 2023

Hi @isaac-chung, could you check where the argmax is in the modeling code? What do you mean by we only deal with fp16?

@isaac-chung
Author

The argmax is within model.encode_text. In this line, self.text_pool_type is arg_max.

What do you mean by we only deal with fp16?

I misread things, please disregard.

@fxmarty
Contributor

fxmarty commented Oct 23, 2023

Thank you! Can you check the dtype of text here https://github.com/mlfoundations/open_clip/blob/7b8dd2cbaf9cca13ca5b1defa6a321a145eb166c/src/open_clip/transformer.py#L558?

A solution would be to use the ModelPatcher (see e.g.

class SpeechT5ModelPatcher(ModelPatcher):
    def __enter__(self):
        self.patch_ops()
        self._model.speecht5.decoder.prenet.forward = types.MethodType(
            patched_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
        )
        setattr(self._model, self.orig_forward_name, self.patched_forward)

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore_ops()
        setattr(self._model, self.orig_forward_name, self.orig_forward)
        self._model.speecht5.decoder.prenet.forward = types.MethodType(
            self.original_speecht5_prenet_forward, self._model.speecht5.decoder.prenet
        )
&
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return SpeechT5ModelPatcher(self, model, model_kwargs=model_kwargs)
) to cast the text argument in text_global_pool to a dtype accepted by ONNX Runtime: https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md

Something like:

if is_open_clip_available():
    import open_clip

class OpenClipModelPatcher(...):
    def __enter__(...):
        ...
        open_clip.transformer.text_global_pool = text_global_pool_patched

    def __exit__(...):
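
Fleshed out a little along the lines of the SpeechT5 patcher above, that could look roughly like this (an illustration only, not necessarily the final implementation):

class OpenClipModelPatcher(ModelPatcher):
    def __enter__(self):
        self.patch_ops()
        # keep a handle on the original so it can be restored on exit
        self._original_text_global_pool = open_clip.transformer.text_global_pool
        open_clip.transformer.text_global_pool = _text_global_pool_patched
        setattr(self._model, self.orig_forward_name, self.patched_forward)

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore_ops()
        open_clip.transformer.text_global_pool = self._original_text_global_pool
        setattr(self._model, self.orig_forward_name, self.orig_forward)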

@isaac-chung
Author

text is of type torch.Tensor.

And great, thank you! Let me take a look and report back.

@fxmarty
Contributor

fxmarty commented Oct 23, 2023

I meant dtype!

@isaac-chung
Author

Whoops, let me see. I got int64. Maybe we should cast it to something like int8?

In [5]: text.dtype
Out[5]: torch.int64

It is odd that the main branch does not have the same method.

@fxmarty
Contributor

fxmarty commented Oct 23, 2023

@isaac-chung According to https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md you could try to add a cast to torch.int32 before the operation.
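For example, the patched pooling function could look roughly like this (the signature is assumed from the open_clip source linked above, only the argmax branch is shown, and the int32 cast is the single behavioural change):

import torch

def _text_global_pool_patched(x, text=None, pool_type: str = "argmax"):
    if pool_type == "argmax":
        assert text is not None
        # text holds int64 token ids and ONNX Runtime has no ArgMax kernel for
        # int64, so cast to int32 before the argmax; the EOS token has the
        # highest id in the CLIP vocabulary, so argmax finds its position
        eos_positions = text.to(torch.int32).argmax(dim=-1)
        pooled, tokens = x[torch.arange(x.shape[0]), eos_positions.to(torch.long)], x
    else:
        # other pool types are left to the original implementation
        pooled, tokens = x, x
    return pooled, tokens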

@isaac-chung
Author

isaac-chung commented Oct 23, 2023

Turns out the newest commit that introduced open_clip.transformer.text_global_pool is not in v2.22.0 yet. So I installed from source.

With the newest commit here, the same NOT IMPLEMENTED error still persists. Wondering if my implementation is missing something glaringly obvious 🤔


    def __enter__(self):
        self.patch_ops()
        open_clip.transformer.text_global_pool = _text_global_pool_patched
Contributor


Can you try:

Suggested change:
- open_clip.transformer.text_global_pool = _text_global_pool_patched
+ open_clip.transformer.text_global_pool.__code__ = _text_global_pool_patched.__code__

I believe that, given that the model is already instantiated, the old text_global_pool is still used despite the patching. This can be checked by printing stuff in _text_global_pool_patched.
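A standalone illustration of why the __code__ swap matters (plain Python, nothing optimum-specific): references bound before the patch keep pointing at the old function object, but mutating its __code__ changes the behaviour of every existing reference.

def text_global_pool():
    return "original"

# analogous to the reference the already-instantiated open_clip model holds
bound_reference = text_global_pool

def _text_global_pool_patched():
    return "patched"

# Rebinding the name (text_global_pool = _text_global_pool_patched) would leave
# bound_reference running the old code. Swapping __code__ instead mutates the
# original function object, so every existing reference picks up the new behaviour.
text_global_pool.__code__ = _text_global_pool_patched.__code__
print(bound_reference())  # -> "patched"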

Author

isaac-chung commented Oct 25, 2023


Thanks! Yep, with the addition of __code__, the print statement works within _text_global_pool_patched. The torch export seems to have run without issues, and the ONNX input names are now ['image', 'text']. Later this error is raised:

  File "/Users/isaacchung/work/optimum/optimum/exporters/onnx/base.py", line 331, in fix_dynamic_axes
    outputs = session.run(None, onnx_inputs)
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 216, in run
    self._validate_input(list(input_feed.keys()))
  File "/Users/isaacchung/virtualenv/v1/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 198, in _validate_input
    raise ValueError(
ValueError: Required inputs (['image', 'text']) are missing from input feed (['input_ids', 'pixel_values', 'attention_mask']).

Interestingly, in fix_dynamic_axes, setting to_fix = [] allows the ONNX export to succeed, but the following warning is raised, which tells me that the model is likely incorrect without fix_dynamic_axes.

The ONNX export succeeded with the warning: The exported ONNX model does not have the exact same outputs as what is provided in OpenCLIPOnnxConfig. Difference: 4624.

Contributor


Yes, the input names should be fixed there as well. You can check how the inputs are generated and whether the input remap is called, and where/why/why not.
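For reference, the remap mentioned earlier (rename_ambiguous_inputs) could be overridden along these lines in the OpenCLIP ONNX config (a sketch of the direction only; the exact hook the PR ends up using may differ):

    def rename_ambiguous_inputs(self, inputs):
        # translate the transformers-style dummy input names into the names
        # open_clip's forward() (and therefore the exported ONNX graph) expects
        model_inputs = {}
        model_inputs["image"] = inputs["pixel_values"]
        model_inputs["text"] = inputs["input_ids"]
        return model_inputs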

Author

isaac-chung commented Oct 25, 2023


Made some good progress in this new commit. I feel that we are close. The only diff now is that the values are mismatching. It's not a small difference given the ATOL. Maybe the preprocessing is off. There's no stochastic part either.
[edit]: Could it be the int64 -> int32 conversion we added? The difference might be a bit too large for that.

Validating ONNX model clip_onnx/model.onnx...
        -[✓] ONNX model output names match reference model (logit_scale, image_features, text_features)
        - Validating ONNX Model output "text_features":
                -[✓] (2, 512) matches (2, 512)
                -[x] values not close enough, max diff: 0.3609406352043152 (atol: 1e-05)
        - Validating ONNX Model output "image_features":
                -[✓] (2, 512) matches (2, 512)
                -[x] values not close enough, max diff: 0.36094051599502563 (atol: 1e-05)
        - Validating ONNX Model output "logit_scale":
                -[✓] () matches ()
                -[✓] all values close (atol: 1e-05)
The ONNX export succeeded with the warning: The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance 1e-05:
- text_features: max diff = 0.3609406352043152
- image_features: max diff = 0.36094051599502563.
 The exported model was saved at: clip_onnx

Contributor


Awesome, great work! I wonder if 0.36 is actually small compared to the output distribution or not. It could be worth checking the mean absolute difference. Maybe there's a dropout with training=True somewhere (like here https://github.com/huggingface/transformers/blob/0baa9246cb1ddac355db1df7824a521426599eb7/src/transformers/models/speecht5/modeling_speecht5.py#L720, though I doubt it)
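Something like the following would put the 0.36 in context (the arguments are placeholders for the reference and ONNX Runtime outputs from the validation step):

import numpy as np

def compare_outputs(ref_output: np.ndarray, ort_output: np.ndarray):
    abs_diff = np.abs(ref_output - ort_output)
    print(f"max abs diff:  {abs_diff.max():.6f}")
    print(f"mean abs diff: {abs_diff.mean():.6f}")
    # scale of the reference values, to judge whether 0.36 is actually large
    print(f"mean |ref|:    {np.abs(ref_output).mean():.6f}")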

Author


Same error - so it does not seem to be a device issue.


mertalev commented Oct 27, 2023


I'm actually working on exporting open_clip models and ran into an issue when calling torch.jit.trace. Not sure if it's the same issue, but I had to explicitly do

for param in model.parameters():
   param.requires_grad_(False)

.eval() and with torch.no_grad() didn't help; it would get stuck while tracing unless I added the above snippet.

Author


Where did you add the snippet btw?


Just before calling torch.jit.trace. Here's the link

Author


Thanks. Just tested it out by adding that within TasksManager.get_model_from_task and setting jit=False; still the same error. I don't think the issue here is jit-related.

@isaac-chung
Author

Just a heads up - I anticipate a busier week this week, so hoping to continue the investigation afterwards. If any new ideas pop to mind, feel free to continue the discussion and I can catch back up.

@isaac-chung
Author

It is possible that the discrepancy is due to the task being auto-detected as "zero-shot-classification" instead of "feature-extraction". Currently only the image/text embeds are in the output, whereas for ZS-classification we need the logits (I was reading this). The scale more-or-less matches the magnitude of logit_scale.

I haven't been able to confirm this yet - will start looking into adding the feature-extraction task as well and how to add the logic for generating those logits from features in the model config. It seems like this might require the model patcher when the task is ZS-classification?
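For what it's worth, the zero-shot logits are just a scaled similarity of the (already normalized) features, so deriving them from the feature-extraction outputs would be something along these lines (standard CLIP scoring, assuming open_clip's encode_image/encode_text were called with normalize=True):

import torch

def clip_zero_shot_logits(image_features: torch.Tensor, text_features: torch.Tensor, logit_scale: torch.Tensor):
    # cosine similarities scaled by the learned temperature
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text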

@isaac-chung
Author

Sorry for the long pause.
Printing out the validation step's reference outputs vs. the ONNX outputs, at a glance the values look very close.

Validating ONNX model clip_onnx/model.onnx...
ref_outputs_dict={'image_features': tensor([[ 0.0082, -0.0139, -0.0986,  ...,  0.0043,  0.0170, -0.0243],
        [ 0.0125, -0.0246, -0.0875,  ...,  0.0078,  0.0198, -0.0267]]), 'text_features': tensor([[-0.0114, -0.0058, -0.0060,  ..., -0.0212, -0.0546, -0.0145],
        [-0.0216, -0.0326, -0.0414,  ...,  0.0304,  0.0093,  0.0012]]), 'logit_scale': tensor(100.0000)}
onnx_outputs=[array([[ 0.00823098, -0.01392082, -0.09864607, ...,  0.00433102,
         0.01700564, -0.02425753],
       [ 0.01247843, -0.02459271, -0.08752848, ...,  0.00782712,
         0.01979438, -0.02666864]], dtype=float32), array([[-0.01138339, -0.00584862, -0.0059775 , ..., -0.02118963,
        -0.05462093, -0.01450986],
       [-0.02160461, -0.03263872, -0.04141301, ...,  0.03040083,
         0.00926695,  0.00119737]], dtype=float32), array(100.00001, dtype=float32)]

Validation would only pass if ATOL_FOR_VALIDATION = 1.

@kechan

kechan commented May 31, 2024

Thanks all very much, it looks like a lot of work went into supporting CLIP indeed. I am looking to try the ONNX export.
I am using the L14 model and here's the CLI output. Will use it and report back if I see any issues.

optimum-cli export onnx -m laion/CLIP-ViT-L-14-laion2B-s32B-b82K --framework pt clip_onnx
.../python310_env/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Automatic task detection to zero-shot-image-classification.
.../python310_env/lib/python3.10/site-packages/transformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use CLIPImageProcessor instead.
warnings.warn(
Using the export variant default. Available variants are:
- default: The default ONNX variant.

***** Exporting submodel 1/1: CLIPModel *****
Using framework PyTorch: 2.2.1
.../python310_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:279: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
.../python310_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:319: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
.../python310_env/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:86: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if input_shape[-1] > 1 or self.sliding_window is not None:
.../python310_env/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:162: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if past_key_values_length > 0:
.../python310_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:287: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
.../python310_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py:296: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
.../python310_env/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:5859: UserWarning: Exporting aten::index operator of advanced indexing in opset 14 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.
warnings.warn(
Post-processing the exported models...
Deduplicating shared (tied) weights...

Validating ONNX model clip_onnx/model.onnx...
        -[✓] ONNX model output names match reference model (image_embeds, text_embeds, logits_per_text, logits_per_image)
        - Validating ONNX Model output "logits_per_image":
                -[✓] (2, 2) matches (2, 2)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "logits_per_text":
                -[✓] (2, 2) matches (2, 2)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "text_embeds":
                -[✓] (2, 768) matches (2, 768)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "image_embeds":
                -[✓] (2, 768) matches (2, 768)
                -[✓] all values close (atol: 1e-05)
The ONNX export succeeded and the exported model was saved at: clip_onnx

@sushilkhadkaanon

sushilkhadkaanon commented Aug 29, 2024

Hi @isaac-chung @fxmarty, I've cloned your repo and installed optimum from isaac-chung:support-open-clip-onnx-export, but it's still not working for me:

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_native_multi_head_attention' to ONNX opset version 18 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

I think this issue might be due to newer versions of the dependencies. Could you please share info about your environment so that I can reproduce? Thanks!

@sushilkhadkaanon

sushilkhadkaanon commented Aug 30, 2024

Hi @kechan, were you able to export to ONNX?
