Skip to content

Commit

Permalink
Add nomic embedding model provider
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoice committed Sep 22, 2024
1 parent 6df7703 commit 074f800
Show file tree
Hide file tree
Showing 16 changed files with 484 additions and 2 deletions.
Empty file.
13 changes: 13 additions & 0 deletions api/core/model_runtime/model_providers/nomic/_assets/icon_l_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions api/core/model_runtime/model_providers/nomic/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)


class _CommonNomic:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}
26 changes: 26 additions & 0 deletions api/core/model_runtime/model_providers/nomic/nomic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class NomicAtlasProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
model_instance.validate_credentials(model="nomic-embed-text-v1.5", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
29 changes: 29 additions & 0 deletions api/core/model_runtime/model_providers/nomic/nomic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
provider: nomic
label:
zh_Hans: Nomic Atlas
en_US: Nomic Atlas
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.svg
background: "#EFF1FE"
help:
title:
en_US: Get your API key from Nomic Atlas
zh_Hans: 从Nomic Atlas获取 API Key
url:
en_US: https://atlas.nomic.ai/data
supported_model_types:
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: nomic_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
model: nomic-embed-text-v1.5
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: "0.1"
unit: "0.000001"
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
model: nomic-embed-text-v1
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: "0.1"
unit: "0.000001"
currency: USD
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import time
from functools import wraps
from typing import Optional

from nomic import embed
from nomic import login as nomic_login

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from core.model_runtime.model_providers.nomic._common import _CommonNomic


def nomic_login_required(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
if not kwargs.get("credentials"):
raise ValueError("missing credentials parameters")
credentials = kwargs.get("credentials")
if "nomic_api_key" not in credentials:
raise ValueError("missing nomic_api_key in credentials parameters")
# nomic login
nomic_login(credentials["nomic_api_key"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
return func(*args, **kwargs)

return wrapper


class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel):
"""
Model class for nomic text embedding model.
"""

def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
embeddings, prompt_tokens, total_tokens = self.embed_text(
model=model,
credentials=credentials,
texts=texts,
)

# calc usage
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=prompt_tokens, total_tokens=total_tokens
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0

_, prompt_tokens, _ = self.embed_text(
model=model,
credentials=credentials,
texts=texts,
)
return prompt_tokens

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# call embedding model
self.embed_text(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@nomic_login_required
def embed_text(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int, int]:
"""Call out to Nomic's embedding endpoint.
Args:
model: The model to use for embedding.
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text, and tokens usage.
"""
embeddings: list[list[float]] = []
prompt_tokens = 0
total_tokens = 0

response = embed.text(
model=model,
texts=texts,
)

if not (response and "embeddings" in response):
raise ValueError("Embedding data is missing in the response.")

if not (response and "usage" in response):
raise ValueError("Response usage is missing.")

if "prompt_tokens" not in response["usage"]:
raise ValueError("Response usage does not contain prompt tokens.")

if "total_tokens" not in response["usage"]:
raise ValueError("Response usage does not contain total tokens.")

embeddings = [list(map(float, e)) for e in response["embeddings"]]
total_tokens = response["usage"]["total_tokens"]
prompt_tokens = response["usage"]["prompt_tokens"]
return embeddings, prompt_tokens, total_tokens

def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: prompt tokens
:param total_tokens: total tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens,
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=total_tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)

return usage
78 changes: 77 additions & 1 deletion api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 074f800

Please sign in to comment.