add ipex backend
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Nov 25, 2024
1 parent 348190d commit 0455f54
Showing 2 changed files with 26 additions and 6 deletions.
8 changes: 4 additions & 4 deletions sentence_transformers/SentenceTransformer.py
@@ -128,7 +128,7 @@ class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin):
         model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model
             card data object that contains information about the model. This is used to generate a model card when saving
             the model. If not set, a default model card data object is created.
-        backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino".
+        backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", "openvino", or "ipex".
             See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information
             on the different backends.
@@ -177,7 +177,7 @@ def __init__(
         tokenizer_kwargs: dict[str, Any] | None = None,
         config_kwargs: dict[str, Any] | None = None,
         model_card_data: SentenceTransformerModelCardData | None = None,
-        backend: Literal["torch", "onnx", "openvino"] = "torch",
+        backend: Literal["torch", "onnx", "openvino", "ipex"] = "torch",
     ) -> None:
         # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
         self.prompts = prompts or {}
@@ -382,8 +382,8 @@ def __init__(
         # Pass the model to the model card data for later use in generating a model card upon saving this model
         self.model_card_data.register_model(self)

-    def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
-        """Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".
+    def get_backend(self) -> Literal["torch", "onnx", "openvino", "ipex"]:
+        """Return the backend used for inference, which can be one of "torch", "onnx", "openvino", or "ipex".

         Returns:
             str: The backend used for inference.
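With this change, selecting the IPEX backend works like any other backend choice. A minimal sketch (the model name is only an example, not part of this commit):

    from sentence_transformers import SentenceTransformer

    # Load a Sentence Transformers model with the new IPEX backend
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="ipex")
    print(model.get_backend())  # "ipex"

    embeddings = model.encode(["An example sentence encoded with the IPEX backend."])
    print(embeddings.shape)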
24 changes: 22 additions & 2 deletions sentence_transformers/models/Transformer.py
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
         tokenizer_name_or_path: Name or path of the tokenizer. When
             None, then model_name_or_path is used
         backend: Backend used for model inference. Can be `torch`, `onnx`,
-            or `openvino`. Default is `torch`.
+            `openvino`, or `ipex`. Default is `torch`.
     """

     save_in_root: bool = True
@@ -143,8 +143,10 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
             self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
             self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
+        elif backend == "ipex":
+            self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args)
         else:
-            raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")
+            raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, `openvino`, or `ipex`.")

     def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if is_peft_available():
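For reference, requesting any other backend now fails fast with the updated message. A hedged sketch (the "tensorrt" string is just an arbitrary invalid value):

    from sentence_transformers import SentenceTransformer

    # Raises ValueError: Unsupported backend 'tensorrt'. `backend` should be `torch`, `onnx`, `openvino`, or `ipex`.
    SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="tensorrt")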
@@ -210,6 +212,24 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if export:
             self._backend_warn_to_save(model_name_or_path, is_local, backend_name)

+    def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+        try:
+            from optimum.intel import IPEXModel
+        except ModuleNotFoundError:
+            raise Exception(
+                "Using the IPEX backend requires installing Optimum and IPEX. "
+                "You can install them with pip: `pip install optimum-intel[ipex]`."
+            )
+
+        self.auto_model: IPEXModel = IPEXModel.from_pretrained(
+            model_name_or_path,
+            config=config,
+            cache_dir=cache_dir,
+            **model_args,
+        )
+        # Wrap the save_pretrained method to save the model in the correct subfolder
+        self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, self.backend)
+
     def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         try:
             import onnxruntime as ort
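The new loader only needs optimum-intel's IPEX extras at runtime. A minimal sketch of the happy path (the model name and save path are assumptions for illustration; per the `_save_pretrained_wrapper` call above, the backend files are written into a backend-named subfolder):

    # pip install optimum-intel[ipex]
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", backend="ipex")
    model.save_pretrained("all-MiniLM-L6-v2-ipex")  # IPEX model files land in an "ipex" subfolder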
