Skip to content

Commit

Permalink
fix conversion for text embeddings for fp16 models (#968)
Browse files Browse the repository at this point in the history
* fix conversion for text embeddings for fp16 models

* fix rebasing issue

* apply review comments

* Update tests/openvino/utils_tests.py
  • Loading branch information
eaidova authored Nov 8, 2024
1 parent d357376 commit 222748e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
7 changes: 7 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
GptNeoxJapaneseModelPatcher,
GptNeoxModelPatcher,
IBertModelPatcher,
InputEmbeddingPatcher,
InternLM2Patcher,
InternLMModelPatcher,
InternVLChatImageEmbeddingModelPatcher,
Expand Down Expand Up @@ -1264,6 +1265,12 @@ def rename_ambiguous_inputs(self, inputs):
model_inputs["input"] = inputs["input_ids"]
return model_inputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
# making 16bit tracable overrides embeedings input signature these changes required to prevent this issue
return InputEmbeddingPatcher(self, model, model_kwargs)


class LlavaConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
Expand Down
21 changes: 21 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2991,3 +2991,24 @@ def __init__(
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


class InputEmbeddingPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any],
):
model.__orig_forward = model.forward

def forward(self, input):
return self.__orig_forward(input)

model.forward = types.MethodType(forward, model)

super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward

0 comments on commit 222748e

Please sign in to comment.