Skip to content

Commit

Permalink
add embedding input type parameter (#8724)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Sep 24, 2024
1 parent debe595 commit 4669eb2
Show file tree
Hide file tree
Showing 33 changed files with 239 additions and 35 deletions.
9 changes: 7 additions & 2 deletions api/core/embedding/cached_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from sqlalchemy.exc import IntegrityError

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
Expand Down Expand Up @@ -56,7 +57,9 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
for i in range(0, len(embedding_queue_texts), max_chunks):
batch_texts = embedding_queue_texts[i : i + max_chunks]

embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT
)

for vector in embedding_result.embeddings:
try:
Expand Down Expand Up @@ -100,7 +103,9 @@ def embed_query(self, text: str) -> list[float]:
redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try:
embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
)

embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
Expand Down
10 changes: 10 additions & 0 deletions api/core/embedding/embedding_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from enum import Enum


class EmbeddingInputType(Enum):
"""
Enum for embedding input type.
"""

DOCUMENT = "document"
QUERY = "query"
7 changes: 6 additions & 1 deletion api/core/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable, Generator, Sequence
from typing import IO, Optional, Union, cast

from core.embedding.embedding_constant import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
Expand Down Expand Up @@ -158,12 +159,15 @@ def get_llm_num_tokens(
tools=tools,
)

def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
def invoke_text_embedding(
self, texts: list[str], user: Optional[str] = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
) -> TextEmbeddingResult:
"""
Invoke large language model
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
Expand All @@ -176,6 +180,7 @@ def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) ->
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
)

def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import ConfigDict

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
Expand All @@ -20,35 +21,47 @@ class TextEmbeddingModel(AIModel):
model_config = ConfigDict(protected_namespaces=())

def invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke large language model
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
self.started_at = time.perf_counter()

try:
return self._invoke(model, credentials, texts, user)
return self._invoke(model, credentials, texts, user, input_type)
except Exception as e:
raise self._transform_invoke_error(e)

@abstractmethod
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke large language model
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:param input_type: input type
:return: embeddings result
"""
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tiktoken
from openai import AzureOpenAI

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import AIModelEntity, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
Expand All @@ -17,7 +18,12 @@

class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
base_model_name = credentials["base_model_name"]
credentials_kwargs = self._to_credential_kwargs(credentials)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from requests import post

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
Expand Down Expand Up @@ -35,7 +36,12 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
api_base: str = "http://api.baichuan-ai.com/v1/embeddings"

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UnknownServiceError,
)

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
Expand All @@ -30,7 +31,12 @@

class BedrockTextEmbeddingModel(TextEmbeddingModel):
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from cohere.core import RequestOptions

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
Expand All @@ -25,7 +26,12 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests
from huggingface_hub import HfApi, InferenceClient

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
Expand All @@ -18,7 +19,12 @@

class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
client = InferenceClient(token=credentials["huggingfacehub_api_token"])

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import time
from typing import Optional

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
Expand All @@ -23,7 +24,12 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
"""

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
Expand All @@ -26,7 +27,12 @@ class HunyuanTextEmbeddingModel(TextEmbeddingModel):
"""

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from requests import post

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
Expand Down Expand Up @@ -60,7 +61,12 @@ def transform_jina_input_text(model, text):
return data

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from requests import post
from yarl import URL

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
Expand All @@ -26,7 +27,12 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
"""

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from requests import post

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
Expand Down Expand Up @@ -34,7 +35,12 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
api_base: str = "https://api.minimax.chat/v1/embeddings"

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import requests

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
Expand All @@ -27,7 +28,12 @@ class MixedBreadTextEmbeddingModel(TextEmbeddingModel):
api_base: str = "https://api.mixedbread.ai/v1"

def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from nomic import embed
from nomic import login as nomic_login

from core.embedding.embedding_constant import EmbeddingInputType
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
Expand Down Expand Up @@ -46,6 +47,7 @@ def _invoke(
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down
Loading

0 comments on commit 4669eb2

Please sign in to comment.