Skip to content

Commit

Permalink
Add Alibaba-NLP/gte-large-en-v1.5 model support (#22)
Browse files — browse the repository at this point in the history
Signed-off-by: kaixuanliu <[email protected]>
  • Branch information
kaixuanliu authored Jul 12, 2024
1 parent 3db1796 commit 8399599
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch

from loguru import logger
Expand All @@ -12,6 +13,8 @@
__all__ = ["Model"]

HTCORE_AVAILABLE = True
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]

try:
import habana_frameworks.torch.core as htcore
except ImportError as e:
Expand Down Expand Up @@ -51,7 +54,7 @@ def get_model(model_path: Path, dtype: Optional[str]):
raise ValueError("CPU device only supports float32 dtype")
device = torch.device("cpu")

config = AutoConfig.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

if config.model_type == "bert":
config: BertConfig
Expand All @@ -63,10 +66,10 @@ def get_model(model_path: Path, dtype: Optional[str]):
):
return FlashBert(model_path, device, dtype)
else:
return DefaultModel(model_path, device, dtype)
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
else:
try:
return DefaultModel(model_path, device, dtype)
return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
except:
raise RuntimeError(f"Unknown model_type {config.model_type}")

Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@


class DefaultModel(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
def __init__(self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
trust_remote: bool=False):
if device == torch.device("hpu"):
adapt_transformers_to_gaudi()
model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote).to(dtype).to(device)
if device == torch.device("hpu"):
logger.info("Use graph mode for HPU")
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
Expand Down

0 comments on commit 8399599

Please sign in to comment.