diff --git a/README.md b/README.md
index c33a043a..5baf8fe9 100644
--- a/README.md
+++ b/README.md
@@ -44,4 +44,3 @@ Not all features of TEI are currently supported as this is still a work in progress.
 > The license to use TEI on Habana Gaudi is the one of TEI: https://github.com/huggingface/text-embeddings-inference/blob/main/LICENSE
 >
 > Please reach out to api-enterprise@huggingface.co if you have any question.
-
diff --git a/backends/grpc-client/src/client.rs b/backends/grpc-client/src/client.rs
index 524a03b3..2f4868f5 100644
--- a/backends/grpc-client/src/client.rs
+++ b/backends/grpc-client/src/client.rs
@@ -64,4 +64,25 @@ impl Client {
         let response = self.stub.embed(request).await?.into_inner();
         Ok(response.embeddings)
     }
+
+    #[instrument(skip_all)]
+    pub async fn predict(
+        &mut self,
+        input_ids: Vec<u32>,
+        token_type_ids: Vec<u32>,
+        position_ids: Vec<u32>,
+        cu_seq_lengths: Vec<u32>,
+        max_length: u32,
+    ) -> Result<Vec<Score>> {
+        let request = tonic::Request::new(EmbedRequest {
+            input_ids,
+            token_type_ids,
+            position_ids,
+            max_length,
+            cu_seq_lengths,
+        })
+        .inject_context();
+        let response = self.stub.predict(request).await?.into_inner();
+        Ok(response.scores)
+    }
 }
diff --git a/backends/proto/embed.proto b/backends/proto/embed.proto
index b245fd57..036f3db4 100644
--- a/backends/proto/embed.proto
+++ b/backends/proto/embed.proto
@@ -7,6 +7,8 @@ service EmbeddingService {
     rpc Embed (EmbedRequest) returns (EmbedResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
+    /// Predict
+    rpc Predict (EmbedRequest) returns (PredictResponse);
 }
 
 message HealthRequest {}
@@ -28,3 +30,11 @@ message Embedding {
 message EmbedResponse {
     repeated Embedding embeddings = 1;
 }
+
+message Score {
+    repeated float values = 1;
+}
+
+message PredictResponse {
+    repeated Score scores = 1;
+}
diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
index 360a2d40..b43fe18c 100644
--- a/backends/python/server/text_embeddings_server/models/__init__.py
+++ b/backends/python/server/text_embeddings_server/models/__init__.py
@@ -6,9 +6,11 @@ from typing import Optional
 
 from transformers import AutoConfig
 from transformers.models.bert import BertConfig
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 
 from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.models.classification_model import ClassificationModel
 
 __all__ = ["Model"]
 
@@ -66,10 +68,15 @@ def get_model(model_path: Path, dtype: Optional[str]):
         ):
             return FlashBert(model_path, device, dtype)
         else:
-            return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
+            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+                return ClassificationModel(model_path, device, dtype)
+            else:
+                return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
     else:
         try:
-            return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
+            if config.architectures[0] in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values():
+                return ClassificationModel(model_path, device, dtype)
+            else:
+                return DefaultModel(model_path, device, dtype, trust_remote=TRUST_REMOTE_CODE)
         except:
-            raise RuntimeError(f"Unknown model_type {config.model_type}")
-
+            raise RuntimeError(f"Unsupported model_type {config.model_type}")
diff --git a/backends/python/server/text_embeddings_server/models/classification_model.py b/backends/python/server/text_embeddings_server/models/classification_model.py
new file mode 100644
index 00000000..e0660207
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/classification_model.py
@@ -0,0 +1,74 @@
+import inspect
+import torch
+
+from loguru import logger
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModelForSequenceClassification
+from opentelemetry import trace
+
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+class ClassificationModel(Model):
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+        if device == torch.device("hpu"):
+            adapt_transformers_to_gaudi()
+
+        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        model = model.to(dtype).to(device)
+        if device == torch.device("hpu"):
+            logger.info("Use graph mode for HPU")
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+
+        self.hidden_size = model.config.hidden_size
+        position_offset = 0
+        model_type = model.config.model_type
+        if model_type in ["xlm-roberta", "camembert", "roberta"]:
+            position_offset = model.config.pad_token_id + 1
+        max_input_length = 0
+        if hasattr(model.config, "max_seq_length"):
+            max_input_length = model.config.max_seq_length
+        else:
+            max_input_length = model.config.max_position_embeddings - position_offset
+        self.max_input_length = max_input_length
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(ClassificationModel, self).__init__(model=model, dtype=dtype, device=device)
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        pass
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+
+        output = self.model(**kwargs, return_dict=True)
+        scores = output.logits.view(-1, ).tolist()
+        return [
+            Score(
+                values=scores[i:i+1]
+            )
+            for i in range(len(batch))
+        ]
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
index 75f5d8a4..c24ae4ce 100644
--- a/backends/python/server/text_embeddings_server/models/default_model.py
+++ b/backends/python/server/text_embeddings_server/models/default_model.py
@@ -11,7 +11,7 @@
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import PaddedBatch, Embedding
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
 
 tracer = trace.get_tracer(__name__)
 
@@ -72,3 +72,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             )
             for i in range(len(batch))
         ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass
diff --git a/backends/python/server/text_embeddings_server/models/model.py b/backends/python/server/text_embeddings_server/models/model.py
index 3bf7b4dc..0a44c8de 100644
--- a/backends/python/server/text_embeddings_server/models/model.py
+++ b/backends/python/server/text_embeddings_server/models/model.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from typing import List, TypeVar, Type
 
-from text_embeddings_server.models.types import Batch, Embedding
+from text_embeddings_server.models.types import Batch, Embedding, Score
 
 B = TypeVar("B", bound=Batch)
 
@@ -27,3 +27,7 @@ def batch_type(self) -> Type[B]:
     @abstractmethod
     def embed(self, batch: B) -> List[Embedding]:
         raise NotImplementedError
+
+    @abstractmethod
+    def predict(self, batch: B) -> List[Score]:
+        raise NotImplementedError
diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py
index be38f462..2c7b968a 100644
--- a/backends/python/server/text_embeddings_server/models/types.py
+++ b/backends/python/server/text_embeddings_server/models/types.py
@@ -6,7 +6,7 @@
 from opentelemetry import trace
 
 from text_embeddings_server.pb import embed_pb2
-from text_embeddings_server.pb.embed_pb2 import Embedding
+from text_embeddings_server.pb.embed_pb2 import Embedding, Score
 
 tracer = trace.get_tracer(__name__)
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get('PAD_SEQUENCE_TO_MULTIPLE_OF', 128))
diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py
index 4523d1b2..5b381237 100644
--- a/backends/python/server/text_embeddings_server/server.py
+++ b/backends/python/server/text_embeddings_server/server.py
@@ -33,6 +33,14 @@ async def Embed(self, request, context):
 
         return embed_pb2.EmbedResponse(embeddings=embeddings)
 
+    async def Predict(self, request, context):
+        max_input_length = self.model.max_input_length
+        batch = self.model.batch_type.from_pb(request, self.model.device, max_input_length)
+
+        scores = self.model.predict(batch)
+
+        return embed_pb2.PredictResponse(scores=scores)
+
 
 def serve(
     model_path: Path,
diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs
index 195f1d37..fae80c61 100644
--- a/backends/python/src/lib.rs
+++ b/backends/python/src/lib.rs
@@ -26,15 +26,13 @@ impl PythonBackend {
     ) -> Result<Self, BackendError> {
         match model_type {
             ModelType::Classifier => {
-                return Err(BackendError::Start(
-                    "`classifier` model type is not supported".to_string(),
-                ))
+                None
             }
             ModelType::Embedding(pool) => {
                 if pool != Pool::Cls {
                     return Err(BackendError::Start(format!("{pool:?} is not supported")));
                 }
-                pool
+                Some(pool)
             }
         };
 
@@ -109,9 +107,32 @@ impl Backend for PythonBackend {
         Ok(embeddings)
     }
 
-    fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
-        Err(BackendError::Inference(
-            "`predict` is not implemented".to_string(),
-        ))
+    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
+        if !batch.raw_indices.is_empty() {
+            return Err(BackendError::Inference(
+                "raw embeddings are not supported for the Python backend.".to_string(),
+            ));
+        }
+        let batch_size = batch.len();
+        let results = self
+            .tokio_runtime
+            .block_on(self.backend_client.clone().predict(
+                batch.input_ids,
+                batch.token_type_ids,
+                batch.position_ids,
+                batch.cumulative_seq_lengths,
+                batch.max_length,
+            ))
+            .map_err(|err| BackendError::Inference(err.to_string()))?;
+        let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
+
+        let mut predictions =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+
+        for (i, r) in raw_results.into_iter().enumerate() {
+            predictions.insert(i, r);
+        }
+
+        Ok(predictions)
     }
 }
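
Below is a minimal, hypothetical smoke test for the new `Predict` RPC, written against the Python stubs generated from `embed.proto`. It assumes a generated `embed_pb2_grpc` module, an example Unix socket path, and made-up token ids; none of these values come from the patch itself. Because `Predict` reuses `EmbedRequest`, the payload has the same shape as for `Embed`.

```python
# Hypothetical example: call the new Predict RPC over the server's Unix socket.
# The socket path and token ids are illustrative placeholders, not part of this change.
import grpc

from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc

channel = grpc.insecure_channel("unix:///tmp/text-embeddings-server-0")
stub = embed_pb2_grpc.EmbeddingServiceStub(channel)

# One sequence of 6 tokens; cu_seq_lengths marks sequence boundaries within the batch.
request = embed_pb2.EmbedRequest(
    input_ids=[101, 7592, 1010, 2088, 999, 102],
    token_type_ids=[0] * 6,
    position_ids=list(range(6)),
    cu_seq_lengths=[0, 6],
    max_length=6,
)

response = stub.Predict(request)
for score in response.scores:
    # One Score message per input sequence; `values` holds the classification logits.
    print(score.values)
```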