Skip to content

Commit

Permalink
Merge pull request #307 from aurelio-labs/james/bedrock-encoder-merge
Browse files Browse the repository at this point in the history
feat: bedrock encoder merge
  • Loading branch information
jamescalam authored Jun 2, 2024
2 parents 7133f1f + a93b348 commit 41a5874
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 113 deletions.
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ qdrant-client = {version = "^1.8.0", optional = true}
google-cloud-aiplatform = {version = "^1.45.0", optional = true}
requests-mock = "^1.12.1"
boto3 = { version = "^1.34.98", optional = true }
botocore = {version = "^1.34.110", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
Expand All @@ -50,7 +51,7 @@ processing = ["matplotlib"]
mistralai = ["mistralai"]
qdrant = ["qdrant-client"]
google = ["google-cloud-aiplatform"]
bedrock = ["boto3"]
bedrock = ["boto3", "botocore"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand All @@ -75,4 +76,4 @@ build-backend = "poetry.core.masonry.api"
line-length = 88

[tool.mypy]
ignore_missing_imports = true
ignore_missing_imports = true
200 changes: 112 additions & 88 deletions semantic_router/encoders/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import json
from typing import List, Optional, Any
import os
from time import sleep
import tiktoken
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger


class BedrockEncoder(BaseEncoder):
Expand Down Expand Up @@ -67,25 +69,23 @@ def __init__(
Raises:
ValueError: If the Bedrock Platform client fails to initialize.
"""

super().__init__(name=name, score_threshold=score_threshold)
self.access_key_id = self.get_env_variable("access_key_id", access_key_id)
self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id)
self.secret_access_key = self.get_env_variable(
"secret_access_key", secret_access_key
"AWS_SECRET_ACCESS_KEY", secret_access_key
)
self.session_token = self.get_env_variable("AWS_SESSION_TOKEN", session_token)
self.region = self.get_env_variable("AWS_REGION", region, default="us-west-1")

self.region = self.get_env_variable(
"AWS_DEFAULT_REGION", region, default="us-west-1"
)
self.input_type = input_type

try:
self.client = self._initialize_client(
self.access_key_id,
self.secret_access_key,
self.session_token,
self.region,
)

except Exception as e:
raise ValueError(f"Bedrock client failed to initialise. Error: {e}") from e

Expand Down Expand Up @@ -115,30 +115,27 @@ def _initialize_client(
"You can install them with: "
"`pip install boto3`"
)

access_key_id = access_key_id or os.getenv("access_key_id")
aws_secret_key = secret_access_key or os.getenv("secret_access_key")
region = region or os.getenv("AWS_REGION", "us-west-2")

access_key_id = access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_key = secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
region = region or os.getenv("AWS_DEFAULT_REGION", "us-west-2")
if access_key_id is None:
raise ValueError("AWS access key ID cannot be 'None'.")

if aws_secret_key is None:
raise ValueError("AWS secret access key cannot be 'None'.")

session = boto3.Session(
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
)
try:
bedrock_client = boto3.client(
bedrock_client = session.client(
"bedrock-runtime",
aws_access_key_id=access_key_id,
aws_secret_access_key=secret_access_key,
aws_session_token=session_token,
region_name=region,
)
except Exception as err:
raise ValueError(
f"The Bedrock client failed to initialize. Error: {err}"
) from err

return bedrock_client

def __call__(self, docs: List[str]) -> List[List[float]]:
Expand All @@ -155,77 +152,101 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
ValueError: If the Bedrock Platform client is not initialized or if the
API call fails.
"""
try:
from botocore.exceptions import ClientError
except ImportError:
raise ImportError(
"Please install Amazon's Botocore client library to use the BedrockEncoder. "
"You can install them with: "
"`pip install botocore`"
)
if self.client is None:
raise ValueError("Bedrock client is not initialised.")
try:
embeddings = []

def chunk_strings(strings, MAX_WORDS=20):
"""
Breaks up a list of strings into smaller chunks.
Args:
strings (list): A list of strings to be chunked.
max_chunk_size (int): The maximum size of each chunk. Default is 75.
Returns:
list: A list of lists, where each inner list contains a chunk of strings.
"""
encoding = tiktoken.get_encoding("cl100k_base")
chunked_strings = []
current_chunk = []

for text in strings:
encoded_text = encoding.encode(text)

if len(encoded_text) > MAX_WORDS:
current_chunk = [
encoding.decode(encoded_text[i : i + MAX_WORDS])
for i in range(0, len(encoded_text), MAX_WORDS)
]
else:
current_chunk = [encoding.decode(encoded_text)]

chunked_strings.append(current_chunk)
return chunked_strings

if self.name and "amazon" in self.name:
for doc in docs:
embedding_body = json.dumps(
{
"inputText": doc,
}
)
response = self.client.invoke_model(
body=embedding_body,
modelId=self.name,
accept="application/json",
contentType="application/json",
)

response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
elif self.name and "cohere" in self.name:
chunked_docs = chunk_strings(docs)
for chunk in chunked_docs:
chunk = json.dumps({"texts": chunk, "input_type": self.input_type})

response = self.client.invoke_model(
body=chunk,
modelId=self.name,
accept="*/*",
contentType="application/json",
)

response_body = json.loads(response.get("body").read())

chunk_embeddings = response_body.get("embeddings")
embeddings.extend(chunk_embeddings)
else:
raise ValueError("Unknown model name")
return embeddings
except Exception as e:
raise ValueError(f"Bedrock call failed. Error: {e}") from e
max_attempts = 3
for attempt in range(max_attempts):
try:
embeddings = []
if self.name and "amazon" in self.name:
for doc in docs:
embedding_body = json.dumps(
{
"inputText": doc,
}
)
response = self.client.invoke_model(
body=embedding_body,
modelId=self.name,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
elif self.name and "cohere" in self.name:
chunked_docs = self.chunk_strings(docs)
for chunk in chunked_docs:
chunk = json.dumps(
{"texts": chunk, "input_type": self.input_type}
)
response = self.client.invoke_model(
body=chunk,
modelId=self.name,
accept="*/*",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
chunk_embeddings = response_body.get("embeddings")
embeddings.extend(chunk_embeddings)
else:
raise ValueError("Unknown model name")
return embeddings
except ClientError as error:
if attempt < max_attempts - 1:
if error.response["Error"]["Code"] == "ExpiredTokenException":
logger.warning(
"Session token has expired. Retrying initialisation."
)
try:
self.session_token = os.getenv("AWS_SESSION_TOKEN")
self.client = self._initialize_client(
self.access_key_id,
self.secret_access_key,
self.session_token,
self.region,
)
except Exception as e:
raise ValueError(
f"Bedrock client failed to reinitialise. Error: {e}"
) from e
sleep(2**attempt)
logger.warning(f"Retrying in {2**attempt} seconds...")
raise ValueError(
f"Retries exhausted, Bedrock call failed. Error: {error}"
) from error
except Exception as e:
raise ValueError(f"Bedrock call failed. Error: {e}") from e
raise ValueError("Bedrock call failed to return embeddings.")

def chunk_strings(self, strings, MAX_WORDS=20):
"""
Breaks up a list of strings into smaller chunks.
Args:
strings (list): A list of strings to be chunked.
max_chunk_size (int): The maximum size of each chunk. Default is 20.
Returns:
list: A list of lists, where each inner list contains a chunk of strings.
"""
encoding = tiktoken.get_encoding("cl100k_base")
chunked_strings = []
for text in strings:
encoded_text = encoding.encode(text)
chunks = [
encoding.decode(encoded_text[i : i + MAX_WORDS])
for i in range(0, len(encoded_text), MAX_WORDS)
]
chunked_strings.append(chunks)
return chunked_strings

@staticmethod
def get_env_variable(var_name, provided_value, default=None):
Expand All @@ -238,6 +259,7 @@ def get_env_variable(var_name, provided_value, default=None):
Returns:
str: The value of the environment variable or the provided/default value.
None: Where AWS_SESSION_TOKEN is not set or provided
Raises:
ValueError: If no value is provided and the environment variable is not set.
Expand All @@ -246,5 +268,7 @@ def get_env_variable(var_name, provided_value, default=None):
return provided_value
value = os.getenv(var_name, default)
if value is None:
if var_name == "AWS_SESSION_TOKEN":
return None
raise ValueError(f"No {var_name} provided")
return value
Loading

0 comments on commit 41a5874

Please sign in to comment.