diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py
index 5ec3381e..7adf60c7 100644
--- a/semantic_router/encoders/bedrock.py
+++ b/semantic_router/encoders/bedrock.py
@@ -17,7 +17,7 @@
 """
 
 import json
-from typing import List, Optional, Any
+from typing import Dict, List, Optional, Any, Union
 import os
 from time import sleep
 import tiktoken
@@ -138,11 +138,14 @@ def _initialize_client(
             ) from err
         return bedrock_client
 
-    def __call__(self, docs: List[str]) -> List[List[float]]:
+    def __call__(
+        self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None
+    ) -> List[List[float]]:
         """Generates embeddings for the given documents.
 
         Args:
             docs: A list of strings representing the documents to embed.
+            model_kwargs: A dictionary of model-specific inference parameters.
 
         Returns:
             A list of lists, where each inner list contains the embedding values for a
@@ -168,13 +171,29 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
             embeddings = []
             if self.name and "amazon" in self.name:
                 for doc in docs:
-                    embedding_body = json.dumps(
-                        {
-                            "inputText": doc,
-                        }
-                    )
+
+                    embedding_body = {}
+
+                    if isinstance(doc, dict):
+                        embedding_body["inputText"] = doc.get("text")
+                        embedding_body["inputImage"] = doc.get(
+                            "image"
+                        )  # expects a base64-encoded image
+                    else:
+                        embedding_body["inputText"] = doc
+
+                    # Add model-specific inference parameters
+                    if model_kwargs:
+                        embedding_body = embedding_body | model_kwargs
+
+                    # Clean up null values
+                    embedding_body = {k: v for k, v in embedding_body.items() if v}
+
+                    # Format payload
+                    embedding_body_payload: str = json.dumps(embedding_body)
+
                     response = self.client.invoke_model(
-                        body=embedding_body,
+                        body=embedding_body_payload,
                         modelId=self.name,
                         accept="application/json",
                         contentType="application/json",
@@ -184,9 +203,16 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
             elif self.name and "cohere" in self.name:
                 chunked_docs = self.chunk_strings(docs)
                 for chunk in chunked_docs:
-                    chunk = json.dumps(
-                        {"texts": chunk, "input_type": self.input_type}
-                    )
+                    chunk = {"texts": chunk, "input_type": self.input_type}
+
+                    # Add model-specific inference parameters
+                    # Note: if specified, input_type will be overwritten by model_kwargs
+                    if model_kwargs:
+                        chunk = chunk | model_kwargs
+
+                    # Format payload
+                    chunk = json.dumps(chunk)
+
                     response = self.client.invoke_model(
                         body=chunk,
                         modelId=self.name,
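
For reviewers, a minimal sketch of how the new signature could be called. This is not part of the diff: the `BedrockEncoder` import path, the Titan multimodal model ID, and the `embeddingConfig` parameter are illustrative assumptions, not something this patch defines.

```python
import base64

# Assumed import path and class name for the encoder defined in
# semantic_router/encoders/bedrock.py; not shown in this diff.
from semantic_router.encoders import BedrockEncoder

# Assumes AWS credentials with Bedrock access are already configured.
encoder = BedrockEncoder(name="amazon.titan-embed-image-v1")

# Plain strings still work exactly as before this change.
text_embeddings = encoder(["the quick brown fox"])

# Dict docs carry "text" and/or a base64-encoded "image", matching the
# keys the new __call__ reads via doc.get("text") / doc.get("image").
with open("photo.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# model_kwargs is merged into the request body; embeddingConfig here is
# an example Titan inference parameter, not defined by this diff.
multimodal_embeddings = encoder(
    [{"text": "a photo of a cat", "image": image_b64}],
    model_kwargs={"embeddingConfig": {"outputEmbeddingLength": 256}},
)
```

One portability note: the `embedding_body | model_kwargs` and `chunk | model_kwargs` merges use the dict union operator, which requires Python 3.9+.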