From 83995998605ef53040355532b0bc55219cfd8b6a Mon Sep 17 00:00:00 2001 From: kaixuanliu Date: Fri, 12 Jul 2024 21:21:55 +0800 Subject: [PATCH] Add `Alibaba-NLP/gte-large-en-v1.5` model support (#22) Signed-off-by: kaixuanliu --- .../server/text_embeddings_server/models/__init__.py | 9 ++++++--- .../text_embeddings_server/models/default_model.py | 8 ++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 8ff0dd25..360a2d40 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -1,3 +1,4 @@ +import os import torch from loguru import logger @@ -12,6 +13,8 @@ __all__ = ["Model"] HTCORE_AVAILABLE = True +TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] + try: import habana_frameworks.torch.core as htcore except ImportError as e: @@ -51,7 +54,7 @@ def get_model(model_path: Path, dtype: Optional[str]): raise ValueError("CPU device only supports float32 dtype") device = torch.device("cpu") - config = AutoConfig.from_pretrained(model_path) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) if config.model_type == "bert": config: BertConfig @@ -63,10 +66,10 @@ def get_model(model_path: Path, dtype: Optional[str]): ): return FlashBert(model_path, device, dtype) else: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE) else: try: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE) except: raise RuntimeError(f"Unknown model_type {config.model_type}") diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index a4ae0c43..75f5d8a4 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -17,10 +17,14 @@ class DefaultModel(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + trust_remote: bool=False): if device == torch.device("hpu"): adapt_transformers_to_gaudi() - model = AutoModel.from_pretrained(model_path).to(dtype).to(device) + model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote).to(dtype).to(device) if device == torch.device("hpu"): logger.info("Use graph mode for HPU") model = wrap_in_hpu_graph(model, disable_tensor_cache=True)