Skip to content

Commit

Permalink
Merge branch 'main' into fix_is_valid_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam authored May 15, 2024
2 parents c6e9f85 + b781441 commit 1e51c35
Show file tree
Hide file tree
Showing 11 changed files with 1,865 additions and 468 deletions.
1,323 changes: 1,323 additions & 0 deletions docs/encoders/bedrock.ipynb

Large diffs are not rendered by default.

602 changes: 144 additions & 458 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.39"
version = "0.0.40"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <[email protected]>",
Expand All @@ -18,7 +18,7 @@ packages = [{include = "semantic_router"}]
python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
cohere = "^5.00"
mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
Expand All @@ -38,6 +38,7 @@ matplotlib = { version = "^3.8.3", optional = true}
qdrant-client = {version = "^1.8.0", optional = true}
google-cloud-aiplatform = {version = "^1.45.0", optional = true}
requests-mock = "^1.12.1"
boto3 = { version = "^1.34.98", optional = true }

[tool.poetry.extras]
hybrid = ["pinecone-text"]
Expand All @@ -49,6 +50,7 @@ processing = ["matplotlib"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
google = ["google-cloud-aiplatform"]
bedrock = ["boto3"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]

__version__ = "0.0.39"
__version__ = "0.0.40"
4 changes: 4 additions & 0 deletions semantic_router/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List, Optional

from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.bedrock import BedrockEncoder
from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.clip import CLIPEncoder
from semantic_router.encoders.cohere import CohereEncoder
Expand Down Expand Up @@ -29,6 +30,7 @@
"VitEncoder",
"CLIPEncoder",
"GoogleEncoder",
"BedrockEncoder",
]


Expand Down Expand Up @@ -67,6 +69,8 @@ def __init__(self, type: str, name: Optional[str]):
self.model = CLIPEncoder(name=name)
elif self.type == EncoderType.GOOGLE:
self.model = GoogleEncoder(name=name)
elif self.type == EncoderType.BEDROCK:
self.model = BedrockEncoder(name=name) # type: ignore
else:
raise ValueError(f"Encoder type '{type}' not supported")

Expand Down
250 changes: 250 additions & 0 deletions semantic_router/encoders/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""
This module provides the BedrockEncoder class for generating embeddings using Amazon's Bedrock Platform.
The BedrockEncoder class is a subclass of BaseEncoder and utilizes the TextEmbeddingModel from the
Amazon's Bedrock Platform to generate embeddings for given documents. It requires an AWS Access Key ID
and AWS Secret Access Key and supports customization of the pre-trained model, score threshold, and region.
Example usage:
from semantic_router.encoders.bedrock import BedrockEncoder
encoder = BedrockEncoder(access_key_id="your-access-key-id", secret_access_key="your-secret-key", region="your-region")
embeddings = encoder(["document1", "document2"])
Classes:
BedrockEncoder: A class for generating embeddings using the Bedrock Platform.
"""

import json
from typing import List, Optional, Any
import os
import tiktoken
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault


class BedrockEncoder(BaseEncoder):
    """Encoder that generates embeddings via Amazon Bedrock runtime models.

    Supports Amazon Titan embedding models ("amazon" in the model name) and
    Cohere embedding models ("cohere" in the model name) hosted on Bedrock.
    """

    client: Any = None  # boto3 "bedrock-runtime" client, set in __init__
    type: str = "bedrock"
    input_type: Optional[str] = "search_query"  # Cohere-only embed input type
    name: str
    access_key_id: Optional[str] = None
    secret_access_key: Optional[str] = None
    session_token: Optional[str] = None
    region: Optional[str] = None

    def __init__(
        self,
        name: str = EncoderDefault.BEDROCK.value["embedding_model"],
        input_type: Optional[str] = "search_query",
        score_threshold: float = 0.3,
        access_key_id: Optional[str] = None,
        secret_access_key: Optional[str] = None,
        session_token: Optional[str] = None,
        region: Optional[str] = None,
    ):
        """Initializes the BedrockEncoder.

        Args:
            name: The name of the pre-trained model to use for embedding.
                If not provided, the default model specified in EncoderDefault
                will be used.
            input_type: Embedding input type passed to Cohere models (e.g.
                "search_query" or "search_document"). Ignored by Titan models.
            score_threshold: The threshold for similarity scores.
            access_key_id: The AWS access key ID for an IAM principal.
                If not provided, it is read from the ``access_key_id``
                environment variable.
            secret_access_key: The secret access key for an IAM principal.
                If not provided, it is read from the ``secret_access_key``
                environment variable.
            session_token: The session token for an IAM principal.
                If not provided, it is read from the ``AWS_SESSION_TOKEN``
                environment variable.
            region: The AWS region hosting the Bedrock resources.
                If not provided, it is read from the ``AWS_REGION``
                environment variable, defaulting to "us-west-1".

        Raises:
            ValueError: If a required credential is missing or the Bedrock
                client fails to initialize.
        """
        super().__init__(name=name, score_threshold=score_threshold)
        self.access_key_id = self.get_env_variable("access_key_id", access_key_id)
        self.secret_access_key = self.get_env_variable(
            "secret_access_key", secret_access_key
        )
        # NOTE(review): get_env_variable raises when AWS_SESSION_TOKEN is
        # unset, making the session token effectively mandatory — confirm
        # whether long-lived (non-STS) credentials should be supported.
        self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token)
        self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1")

        self.input_type = input_type

        try:
            self.client = self._initialize_client(
                self.access_key_id,
                self.secret_access_key,
                self.session_token,
                self.region,
            )
        except Exception as e:
            raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e

    def _initialize_client(
        self, access_key_id, secret_access_key, session_token, region
    ):
        """Initializes the Bedrock runtime client.

        Args:
            access_key_id: The AWS access key ID.
            secret_access_key: The AWS secret access key.
            session_token: The AWS session token.
            region: The AWS region hosting the Bedrock resources.

        Returns:
            A boto3 "bedrock-runtime" client instance.

        Raises:
            ImportError: If boto3 is not installed.
            ValueError: If credentials are missing or the client fails to
                initialize.
        """
        try:
            import boto3
        except ImportError:
            raise ImportError(
                "Please install Amazon's Boto3 client library to use the BedrockEncoder. "
                "You can install them with: "
                "`pip install boto3`"
            )

        access_key_id = access_key_id or os.getenv("access_key_id")
        # Fix: reassign so the env-var fallback is the value actually passed
        # to boto3 below (previously the fallback was validated but the raw
        # argument — possibly None — was handed to the client).
        secret_access_key = secret_access_key or os.getenv("secret_access_key")
        # Fix: default region unified with __init__ ("us-west-1"; this method
        # previously fell back to "us-west-2").
        region = region or os.getenv("AWS_REGION", "us-west-1")

        if access_key_id is None:
            raise ValueError("AWS access key ID cannot be 'None'.")

        if secret_access_key is None:
            raise ValueError("AWS secret access key cannot be 'None'.")

        try:
            bedrock_client = boto3.client(
                "bedrock-runtime",
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
                aws_session_token=session_token,
                region_name=region,
            )
        except Exception as err:
            raise ValueError(
                f"The Bedrock client failed to initialize. Error: {err}"
            ) from err

        return bedrock_client

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Generates embeddings for the given documents.

        Args:
            docs: A list of strings representing the documents to embed.

        Returns:
            A list of lists, where each inner list contains the embedding
            values for a document. For Cohere models, documents split into
            multiple chunks yield one embedding per chunk.

        Raises:
            ValueError: If the Bedrock client is not initialized, the model
                name is unrecognized, or the API call fails.
        """
        if self.client is None:
            raise ValueError("Bedrock client is not initialised.")
        try:
            embeddings = []

            def chunk_strings(strings, max_tokens=20):
                """Breaks up a list of strings into smaller token-bounded chunks.

                Args:
                    strings (list): A list of strings to be chunked.
                    max_tokens (int): Maximum number of cl100k_base tokens per
                        chunk. Default is 20.

                Returns:
                    list: A list of lists, where each inner list contains the
                    chunk(s) of one input string.
                """
                encoding = tiktoken.get_encoding("cl100k_base")
                chunked_strings = []

                for text in strings:
                    encoded_text = encoding.encode(text)

                    if len(encoded_text) > max_tokens:
                        # Split the token stream into max_tokens-sized windows
                        # and decode each window back to text.
                        current_chunk = [
                            encoding.decode(encoded_text[i : i + max_tokens])
                            for i in range(0, len(encoded_text), max_tokens)
                        ]
                    else:
                        current_chunk = [encoding.decode(encoded_text)]

                    chunked_strings.append(current_chunk)
                return chunked_strings

            if self.name and "amazon" in self.name:
                # Titan models accept one document per request.
                for doc in docs:
                    embedding_body = json.dumps(
                        {
                            "inputText": doc,
                        }
                    )
                    response = self.client.invoke_model(
                        body=embedding_body,
                        modelId=self.name,
                        accept="application/json",
                        contentType="application/json",
                    )

                    response_body = json.loads(response.get("body").read())
                    embeddings.append(response_body.get("embedding"))
            elif self.name and "cohere" in self.name:
                # Cohere models accept batches but cap input length, so each
                # document is pre-chunked to stay under the token limit.
                chunked_docs = chunk_strings(docs)
                for chunk in chunked_docs:
                    chunk = json.dumps({"texts": chunk, "input_type": self.input_type})

                    response = self.client.invoke_model(
                        body=chunk,
                        modelId=self.name,
                        accept="*/*",
                        contentType="application/json",
                    )

                    response_body = json.loads(response.get("body").read())

                    chunk_embeddings = response_body.get("embeddings")
                    embeddings.extend(chunk_embeddings)
            else:
                raise ValueError("Unknown model name")
            return embeddings
        except Exception as e:
            raise ValueError(f"Bedrock call failed. Error: {e}") from e

    @staticmethod
    def get_env_variable(var_name, provided_value, default=None):
        """Retrieves an environment variable or uses a provided value.

        Args:
            var_name (str): The name of the environment variable.
            provided_value (Optional[str]): The provided value to use if not None.
            default (Optional[str]): The default value if the environment
                variable is not set.

        Returns:
            str: The provided value, or else the environment variable's value,
            or else the default.

        Raises:
            ValueError: If no value is provided, the environment variable is
                not set, and no default is given.
        """
        if provided_value is not None:
            return provided_value
        value = os.getenv(var_name, default)
        if value is None:
            raise ValueError(f"No {var_name} provided")
        return value
11 changes: 9 additions & 2 deletions semantic_router/encoders/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import cohere
from cohere.types.embed_response import EmbedResponse_EmbeddingsByType

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
Expand Down Expand Up @@ -42,8 +43,14 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
raise ValueError("Cohere client is not initialized.")
try:
embeds = self.client.embed(
docs, input_type=self.input_type, model=self.name
texts=docs, input_type=self.input_type, model=self.name
)
return embeds.embeddings
# Check for unsupported type.
if isinstance(embeds, EmbedResponse_EmbeddingsByType):
raise NotImplementedError(
"Handling of EmbedByTypeResponseEmbeddings is not implemented."
)
else:
return embeds.embeddings
except Exception as e:
raise ValueError(f"Cohere API call failed. Error: {e}") from e
13 changes: 8 additions & 5 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def to_dict(self):

class PineconeIndex(BaseIndex):
index_prefix: str = "semantic-router--"
api_key: Optional[str] = None
index_name: str = "index"
dimensions: Union[int, None] = None
metric: str = "cosine"
Expand Down Expand Up @@ -69,7 +70,12 @@ def __init__(
self.host = host
self.namespace = namespace
self.type = "pinecone"
self.client = self._initialize_client(api_key=api_key)
self.api_key = api_key or os.getenv("PINECONE_API_KEY")

if self.api_key is None:
raise ValueError("Pinecone API key is required.")

self.client = self._initialize_client(api_key=self.api_key)

def _initialize_client(self, api_key: Optional[str] = None):
try:
Expand All @@ -82,9 +88,6 @@ def _initialize_client(self, api_key: Optional[str] = None):
"You can install it with: "
"`pip install 'semantic-router[pinecone]'`"
)
api_key = api_key or os.getenv("PINECONE_API_KEY")
if api_key is None:
raise ValueError("Pinecone API key is required.")
pinecone_args = {"api_key": api_key, "source_tag": "semantic-router"}
if self.namespace:
pinecone_args["namespace"] = self.namespace
Expand Down Expand Up @@ -190,7 +193,7 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False)
params: Dict = {}
if self.namespace:
params["namespace"] = self.namespace
headers = {"Api-Key": os.environ["PINECONE_API_KEY"]}
headers = {"Api-Key": self.api_key}
metadata = []

while True:
Expand Down
1 change: 1 addition & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class EncoderType(Enum):
VIT = "vit"
CLIP = "clip"
GOOGLE = "google"
BEDROCK = "bedrock"


class EncoderInfo(BaseModel):
Expand Down
5 changes: 5 additions & 0 deletions semantic_router/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ class EncoderDefault(Enum):
"GOOGLE_EMBEDDING_MODEL", "textembedding-gecko@003"
),
}
BEDROCK = {
"embedding_model": os.environ.get(
"BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1"
)
}
Loading

0 comments on commit 1e51c35

Please sign in to comment.