[WIP] Support open-clip onnx export #1466

Draft · wants to merge 16 commits into main
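This draft wires open_clip checkpoints into the ONNX exporter: a new `OpenCLIPOnnxConfig`, a model patcher for `text_global_pool`, and registration of the `zero-shot-image-classification` task for the `open_clip` library. A minimal usage sketch, assuming library detection resolves the checkpoint to open_clip once the PR is finished (the output directory is a placeholder; the checkpoint is the one used in the new tests):

```python
from optimum.exporters.onnx import main_export

main_export(
    "laion/CLIP-ViT-B-16-laion2B-s34B-b88K",   # open_clip checkpoint used in the new tests
    output="open_clip_onnx",                   # hypothetical output directory
    task="zero-shot-image-classification",     # task registered for open-clip in this PR
)
```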
47 changes: 47 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -62,6 +62,7 @@
)
from .model_patcher import (
    FalconModelPatcher,
    OpenCLIPModelPatcher,
    SAMModelPatcher,
    SentenceTransformersCLIPPatcher,
    SentenceTransformersTransformerPatcher,
@@ -869,6 +870,52 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
        }


class OpenCLIPOnnxConfig(CLIPOnnxConfig):
    DEFAULT_ONNX_OPSET = 18

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "text_batch_size"},
            "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "attention_mask": {0: "text_batch_size"},
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "text_features": {0: "text_batch_size"},
            "image_features": {0: "image_batch_size"},
            "logit_scale": {},
        }

    def rename_ambiguous_inputs(self, inputs):
        # open_clip forward signatures expect `image` and `text` rather than `pixel_values` and `input_ids`.
        model_inputs = {}
        model_inputs["image"] = inputs["pixel_values"]
        model_inputs["text"] = inputs["input_ids"]
        return model_inputs

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        # Override the sequence_length shape here in the kwargs: the dummy text input must match the
        # tokenizer's fixed context length (typically 77 for CLIP-style models).
        kwargs["sequence_length"] = self._preprocessors[0].model_max_length
        return super().generate_dummy_inputs(framework, **kwargs)

    def generate_dummy_inputs_for_validation(
        self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        if "attention_mask" in reference_model_inputs:
            reference_model_inputs.pop("attention_mask")
        if "image" in onnx_input_names and "pixel_values" in reference_model_inputs:
            reference_model_inputs["image"] = reference_model_inputs.pop("pixel_values")
        if "text" in onnx_input_names and "input_ids" in reference_model_inputs:
            reference_model_inputs["text"] = reference_model_inputs.pop("input_ids")
        return super().generate_dummy_inputs_for_validation(reference_model_inputs)

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return OpenCLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
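For context, a rough sketch (not part of the diff) of consuming the resulting export with ONNX Runtime, assuming the exported graph keeps the `image`/`text` input names produced by `rename_ambiguous_inputs` and the output names declared above; the file name and shapes are placeholders:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("open_clip_onnx/model.onnx")  # hypothetical export path

image = np.random.rand(1, 3, 224, 224).astype(np.float32)         # dummy pixel_values
text = np.random.randint(0, 49408, size=(1, 77), dtype=np.int64)  # dummy input_ids padded to the context length

text_features, image_features, logit_scale = session.run(
    ["text_features", "image_features", "logit_scale"],
    {"image": image, "text": text},
)
```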
36 changes: 36 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
@@ -14,6 +14,7 @@
# limitations under the License.
import dataclasses
import functools
import importlib.util
import inspect
import math
import sys
@@ -25,6 +26,7 @@
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_torch_available

from ...utils.import_utils import is_open_clip_available

if is_torch_available():
    import torch
@@ -794,3 +796,37 @@ def patched_forward(input_ids, attention_mask, pixel_values):
return {"text_embeds": text_embeds, "image_embeds": image_embeds}

self.patched_forward = patched_forward

if is_open_clip_available():
    import open_clip

    def _text_global_pool_patched(x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"):
        if pool_type == "first":
            pooled, tokens = x[:, 0], x[:, 1:]
        elif pool_type == "last":
            pooled, tokens = x[:, -1], x[:, :-1]
        elif pool_type == "argmax":
            # Take features from the eot embedding (the eot token is the highest token id in each sequence).
            assert text is not None
            text = text.to(dtype=torch.int32)  # ONNX Runtime is unable to run argmax on int64 inputs, hence this cast.
            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
        else:
            pooled = tokens = x
        return pooled, tokens


class OpenCLIPModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Store the original code object rather than the function itself: __enter__ mutates the function's
        # __code__ in place, so a plain reference to the function would not allow restoring it in __exit__.
        self.original_text_global_pool_code = open_clip.transformer.text_global_pool.__code__

    def __enter__(self):
        # Swapping __code__ patches the function in place, so open_clip call sites that already hold a
        # reference to `text_global_pool` pick up the patched implementation during export.
        open_clip.transformer.text_global_pool.__code__ = _text_global_pool_patched.__code__

    def __exit__(self, exc_type, exc_value, traceback):
        open_clip.transformer.text_global_pool.__code__ = self.original_text_global_pool_code

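A side note on the patching technique above: assigning to a function's `__code__` mutates the function object in place, so open_clip call sites that captured a reference to `text_global_pool` at import time also see the patched behaviour, and keeping the original code object is what lets `__exit__` undo the patch. A minimal, self-contained illustration (plain Python, not optimum code):

```python
def pool(x):
    return f"original({x})"


def _pool_patched(x):
    return f"patched({x})"


original_code = pool.__code__            # keep the code object so the patch can be undone
pool.__code__ = _pool_patched.__code__   # patch in place: every existing reference is affected
assert pool("eot") == "patched(eot)"

pool.__code__ = original_code            # restore the original behaviour
assert pool("eot") == "original(eot)"
```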
28 changes: 24 additions & 4 deletions optimum/exporters/tasks.py
@@ -196,6 +196,10 @@ class TasksManager:
        "image-classification": "create_model",
    }

    _OPEN_CLIP_TASKS_TO_MODEL_LOADERS = {
        "zero-shot-image-classification": "create_model_and_transforms",
    }

    _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS = {
        "feature-extraction": "SentenceTransformer",
        "sentence-similarity": "SentenceTransformer",
@@ -205,6 +209,7 @@
        "diffusers": _DIFFUSERS_TASKS_TO_MODEL_LOADERS,
        "sentence_transformers": _SENTENCE_TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
        "timm": _TIMM_TASKS_TO_MODEL_LOADERS,
        "open_clip": _OPEN_CLIP_TASKS_TO_MODEL_LOADERS,
        "transformers": _TRANSFORMERS_TASKS_TO_MODEL_LOADERS,
    }

@@ -401,6 +406,10 @@ class TasksManager:
            onnx="CamembertOnnxConfig",
            tflite="CamembertTFLiteConfig",
        ),
        "open-clip": supported_tasks_mapping(
            "zero-shot-image-classification",
            onnx="OpenCLIPOnnxConfig",
        ),
        "clip": supported_tasks_mapping(
            "feature-extraction",
            "zero-shot-image-classification",
@@ -1656,7 +1665,7 @@ def standardize_model_attributes(
        full_model_path = Path(model_name_or_path) / subfolder
        is_local = full_model_path.is_dir()

-        if library_name == "timm":
+        if library_name == "timm" or library_name == "open_clip":
            # Retrieve model config
            config_path = full_model_path / "config.json"

@@ -1673,9 +1682,12 @@
            # Set config as in transformers
            setattr(model, "config", model_config)

-            # Update model_type for model
-            with open(config_path) as fp:
-                model_type = json.load(fp)["architecture"]
+            if library_name == "timm":
+                # Update model_type for model
+                with open(config_path) as fp:
+                    model_type = json.load(fp)["architecture"]
+            else:
+                model_type = "open-clip"

            setattr(model.config, "model_type", model_type)
        elif library_name == "sentence_transformers":
@@ -1796,6 +1808,14 @@ def get_model_from_task(
            model = model_class(
                model_name_or_path, device=device, cache_folder=cache_folder, use_auth_token=use_auth_token
            )
            return model
        elif library_name == "open_clip":
            model, _, _ = model_class(f"hf-hub:{model_name_or_path}", cache_dir=cache_dir, output_dict=True)
            TasksManager.standardize_model_attributes(
                model_name_or_path, model, subfolder, revision, cache_dir, library_name
            )
            return model

        else:
            try:
                if framework == "pt":
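For reference, the new `open_clip` loader entry resolves to `open_clip.create_model_and_transforms`, called with the `hf-hub:` prefix exactly as in `get_model_from_task` above. A rough standalone equivalent, using the checkpoint from the new tests (treat the exact keyword arguments as an assumption about the installed open_clip version):

```python
import open_clip

# Returns (model, preprocess_train, preprocess_val); output_dict=True makes forward() return a dict
# with "image_features", "text_features" and "logit_scale", matching the ONNX config's output names.
model, _, preprocess = open_clip.create_model_and_transforms(
    "hf-hub:laion/CLIP-ViT-B-16-laion2B-s34B-b88K",
    output_dict=True,
)
```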
5 changes: 5 additions & 0 deletions optimum/utils/import_utils.py
@@ -69,6 +69,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_auto_gptq_available = _is_package_available("auto_gptq")
_timm_available = _is_package_available("timm")
_sentence_transformers_available = _is_package_available("sentence_transformers")
_open_clip_available = _is_package_available("open_clip")

torch_version = None
if is_torch_available():
@@ -127,6 +128,10 @@ def is_timm_available():
    return _timm_available


def is_open_clip_available():
    return _open_clip_available


def is_sentence_transformers_available():
    return _sentence_transformers_available

5 changes: 5 additions & 0 deletions optimum/utils/normalized_config.py
@@ -58,6 +58,11 @@ def __getattr__(self, attr_name):
        for attr in attr_name[:-1]:
            config = getattr(config, attr)

        # We cast potential dictionaries to PretrainedConfig so that getattr also works on nested
        # structures, where nested entries are not always PretrainedConfig instances themselves (e.g. timm, open_clip).
        if isinstance(config, dict):
            config = PretrainedConfig.from_dict(config)

        attr = getattr(config, leaf_attr_name, None)

        # If the attribute was not specified manually, try to fallback on the attribute_map.
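A tiny illustration of why the cast helps (the nested section below is hypothetical; timm and open_clip configs store such sections as plain dicts inside config.json):

```python
from transformers import PretrainedConfig

text_cfg = {"width": 512, "layers": 12}        # hypothetical nested config section
config = PretrainedConfig.from_dict(text_cfg)  # wrap the dict so attribute access works

assert config.width == 512                     # mirrors the getattr path in __getattr__ above
```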
5 changes: 5 additions & 0 deletions optimum/utils/testing_utils.py
@@ -28,6 +28,7 @@
    is_accelerate_available,
    is_auto_gptq_available,
    is_diffusers_available,
    is_open_clip_available,
    is_sentence_transformers_available,
    is_timm_available,
)
@@ -143,6 +144,10 @@ def require_timm(test_case):
    return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)


def require_open_clip(test_case):
    return unittest.skipUnless(is_open_clip_available(), "test requires open_clip")(test_case)


def require_sentence_transformers(test_case):
    return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case)

4 changes: 4 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -325,3 +325,7 @@
"sentence-transformers-clip": "sentence-transformers/all-MiniLM-L6-v2",
"sentence-transformers-transformer": "sentence-transformers/clip-ViT-B-32-multilingual-v1",
}

PYTORCH_OPEN_CLIP_MODEL = {
"open-clip": "laion/CLIP-ViT-B-16-laion2B-s34B-b88K",
}
26 changes: 26 additions & 0 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -41,6 +41,7 @@

from ..exporters_utils import (
    PYTORCH_EXPORT_MODELS_TINY,
    PYTORCH_OPEN_CLIP_MODEL,
    PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
    PYTORCH_STABLE_DIFFUSION_MODEL,
    PYTORCH_TIMM_MODEL,
@@ -296,6 +297,31 @@ def test_exporters_cli_fp16_timm(
    ):
        self._onnx_export(model_name, task, monolith, no_post_process, device="cuda", fp16=True)

    @parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
    @require_torch
    @require_vision
    @require_open_clip
    def test_exporters_cli_pytorch_cpu_open_clip(self, model_type: str, model_name: str):
        self._onnx_export(model_name, model_type)

    @parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
    @require_torch_gpu
    @require_vision
    @require_open_clip
    @slow
    @pytest.mark.run_slow
    def test_exporters_cli_pytorch_gpu_open_clip(self, model_type: str, model_name: str):
        self._onnx_export(model_name, model_type, device="cuda")

    @parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
    @require_torch_gpu
    @require_vision
    @require_open_clip
    @slow
    @pytest.mark.run_slow
    def test_exporters_cli_fp16_open_clip(self, model_type: str, model_name: str):
        self._onnx_export(model_name, model_type, device="cuda", fp16=True)

    @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
    @require_torch
    @require_vision
1 change: 1 addition & 0 deletions tests/exporters/onnx/test_onnx_export.py
@@ -49,6 +49,7 @@
    PYTORCH_SENTENCE_TRANSFORMERS_MODEL,
    PYTORCH_STABLE_DIFFUSION_MODEL,
    PYTORCH_TIMM_MODEL,
    PYTORCH_OPEN_CLIP_MODEL,
    TENSORFLOW_EXPORT_MODELS,
    VALIDATE_EXPORT_ON_SHAPES_SLOW,
)