Skip to content

Commit

Permalink
Merge branch 'main' into py_typed
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam authored Dec 13, 2024
2 parents 430a770 + 2c09ff7 commit 07ece05
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 45 deletions.
2 changes: 1 addition & 1 deletion docs/00-introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
"source": [
"from semantic_router.routers import SemanticRouter\n",
"\n",
"sr = SemanticRouter(encoder=encoder, routes=routes)"
"sr = SemanticRouter(encoder=encoder, routes=routes, auto_sync=\"local\")"
]
},
{
Expand Down
48 changes: 37 additions & 11 deletions semantic_router/encoders/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import json
from typing import List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import os
from time import sleep
import tiktoken
Expand Down Expand Up @@ -138,11 +138,14 @@ def _initialize_client(
) from err
return bedrock_client

def __call__(self, docs: List[str]) -> List[List[float]]:
def __call__(
self, docs: List[Union[str, Dict]], model_kwargs: Optional[Dict] = None
) -> List[List[float]]:
"""Generates embeddings for the given documents.
Args:
docs: A list of strings representing the documents to embed.
model_kwargs: A dictionary of model-specific inference parameters.
Returns:
A list of lists, where each inner list contains the embedding values for a
Expand All @@ -168,13 +171,29 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
embeddings = []
if self.name and "amazon" in self.name:
for doc in docs:
embedding_body = json.dumps(
{
"inputText": doc,
}
)

embedding_body = {}

if isinstance(doc, dict):
embedding_body["inputText"] = doc.get("text")
embedding_body["inputImage"] = doc.get(
"image"
) # expects a base64-encoded image
else:
embedding_body["inputText"] = doc

# Add model-specific inference parameters
if model_kwargs:
embedding_body = embedding_body | model_kwargs

# Clean up null values
embedding_body = {k: v for k, v in embedding_body.items() if v}

# Format payload
embedding_body_payload: str = json.dumps(embedding_body)

response = self.client.invoke_model(
body=embedding_body,
body=embedding_body_payload,
modelId=self.name,
accept="application/json",
contentType="application/json",
Expand All @@ -184,9 +203,16 @@ def __call__(self, docs: List[str]) -> List[List[float]]:
elif self.name and "cohere" in self.name:
chunked_docs = self.chunk_strings(docs)
for chunk in chunked_docs:
chunk = json.dumps(
{"texts": chunk, "input_type": self.input_type}
)
chunk = {"texts": chunk, "input_type": self.input_type}

# Add model-specific inference parameters
# Note: if specified, input_type will be overwritten by model_kwargs
if model_kwargs:
chunk = chunk | model_kwargs

# Format payload
chunk = json.dumps(chunk)

response = self.client.invoke_model(
body=chunk,
modelId=self.name,
Expand Down
85 changes: 76 additions & 9 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime
import time
from typing import Any, List, Optional, Tuple, Union, Dict
import json

Expand Down Expand Up @@ -157,26 +159,91 @@ def delete_index(self):
logger.warning("This method should be implemented by subclasses.")
self.index = None

def _read_hash(self) -> ConfigParameter:
"""
Read the hash of the previously written index.
def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
"""Read a config parameter from the index.
This method should be implemented by subclasses.
:param field: The field to read.
:type field: str
:param scope: The scope to read.
:type scope: str | None
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
logger.warning("This method should be implemented by subclasses.")
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace="",
scope=scope,
)

def _write_config(self, config: ConfigParameter):
def _read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index.
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
Write a config parameter to the index.
return self._read_config(field="sr_hash")

This method should be implemented by subclasses.
def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Write a config parameter to the index.
:param config: The config parameter to write.
:type config: ConfigParameter
:return: The config parameter that was written.
:rtype: ConfigParameter
"""
logger.warning("This method should be implemented by subclasses.")
return config

def lock(
    self, value: bool, wait: int = 0, scope: str | None = None
) -> ConfigParameter:
    """Lock/unlock the index for a given scope (if applicable).

    If the index is already in the requested state, the state is polled
    until it changes or until ``wait`` seconds have elapsed, at which point
    a ValueError is raised.

    :param value: True to lock the index, False to unlock it.
    :type value: bool
    :param wait: The number of seconds to wait for the index to leave the
        requested state; if set to 0, will raise an error immediately if
        the index is already locked/unlocked.
    :type wait: int
    :param scope: The scope to lock.
    :type scope: str | None
    :return: The config parameter that was written to set the lock state.
    :rtype: ConfigParameter
    :raises ValueError: If the index is still in the requested state once
        ``wait`` seconds have elapsed.
    """
    poll_interval = 2.5  # seconds between lock-state checks
    # Use the monotonic clock so the elapsed-time measurement is not
    # affected by wall-clock adjustments while we are waiting.
    deadline = time.monotonic() + wait
    while self._is_locked(scope=scope) == value:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise ValueError(
                f"Index is already {'locked' if value else 'unlocked'}."
            )
        # Never sleep past the deadline, so `wait` is a real upper bound.
        time.sleep(min(poll_interval, remaining))
    # The index is in the opposite state, so the new lock value can be set.
    lock_param = ConfigParameter(
        field="sr_lock",
        value=str(value),
        scope=scope,
    )
    self._write_config(lock_param)
    return lock_param

def _is_locked(self, scope: str | None = None) -> bool:
    """Report whether the index is locked for a given scope (if applicable).

    The state is read from the ``sr_lock`` config parameter.

    :param scope: The scope to check.
    :type scope: str | None
    :return: True if the index is locked, False otherwise.
    :rtype: bool
    :raises ValueError: If the stored lock value is neither "True", "False",
        nor empty.
    """
    raw_value = self._read_config(field="sr_lock", scope=scope).value
    if raw_value == "True":
        return True
    # An empty value is treated the same as an explicit "False".
    if raw_value == "False" or not raw_value:
        return False
    raise ValueError(f"Invalid lock value: {raw_value}")

def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
Expand Down
35 changes: 20 additions & 15 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,39 +405,43 @@ def query(
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names

def _read_hash(self) -> ConfigParameter:
def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
scope = scope or self.namespace
if self.index is None:
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace=self.namespace,
scope=scope,
)
hash_id = f"sr_hash#{self.namespace}"
hash_record = self.index.fetch(
ids=[hash_id],
config_id = f"{field}#{scope}"
config_record = self.index.fetch(
ids=[config_id],
namespace="sr_config",
)
if hash_record["vectors"]:
if config_record["vectors"]:
return ConfigParameter(
field="sr_hash",
value=hash_record["vectors"][hash_id]["metadata"]["value"],
created_at=hash_record["vectors"][hash_id]["metadata"]["created_at"],
namespace=self.namespace,
field=field,
value=config_record["vectors"][config_id]["metadata"]["value"],
created_at=config_record["vectors"][config_id]["metadata"][
"created_at"
],
scope=scope,
)
else:
logger.warning("Configuration for hash parameter not found in index.")
logger.warning(f"Configuration for {field} parameter not found in index.")
return ConfigParameter(
field="sr_hash",
field=field,
value="",
namespace=self.namespace,
scope=scope,
)

def _write_config(self, config: ConfigParameter) -> None:
def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Method to write a config parameter to the remote Pinecone index.
:param config: The config parameter to write to the index.
:type config: ConfigParameter
"""
config.scope = config.scope or self.namespace
if self.index is None:
raise ValueError("Index has not been initialized.")
if self.dimensions is None:
Expand All @@ -446,6 +450,7 @@ def _write_config(self, config: ConfigParameter) -> None:
vectors=[config.to_pinecone(dimensions=self.dimensions)],
namespace="sr_config",
)
return config

async def aquery(
self,
Expand Down
15 changes: 13 additions & 2 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,18 @@ async def _async_retrieve_top_route(
route = self.check_for_matching_routes(top_class)
return route, top_class_scores

def sync(self, sync_mode: str, force: bool = False) -> List[str]:
def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
"""Runs a sync of the local routes with the remote index.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
:param force: Whether to force the sync even if the local and remote
hashes already match. Defaults to False.
:type force: bool, optional
:param wait: The number of seconds to wait for the index to be unlocked
before proceeding with the sync. If set to 0, will raise an error if
index is already locked/unlocked.
:type wait: int
:return: A list of diffs describing the addressed differences between
the local and remote route layers.
:rtype: List[str]
Expand All @@ -565,7 +569,9 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]:
remote_utterances=local_utterances,
)
return diff.to_utterance_str()
# otherwise we continue with the sync, first creating a diff
# otherwise we continue with the sync, first locking the index
_ = self.index.lock(value=True, wait=wait)
# first creating a diff
local_utterances = self.to_config().to_utterances()
remote_utterances = self.index.get_utterances()
diff = UtteranceDiff.from_utterances(
Expand All @@ -576,6 +582,8 @@ def sync(self, sync_mode: str, force: bool = False) -> List[str]:
sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode)
# and execute
self._execute_sync_strategy(sync_strategy)
# unlock index after sync
_ = self.index.lock(value=False)
return diff.to_utterance_str()

def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
Expand Down Expand Up @@ -781,6 +789,9 @@ def delete(self, route_name: str):
:param route_name: the name of the route to be deleted
:type str:
"""
# ensure index is not locked
if self.index._is_locked():
raise ValueError("Index is locked. Cannot delete route.")
current_local_hash = self._get_hash()
current_remote_hash = self.index._read_hash()
if current_remote_hash.value == "":
Expand Down
11 changes: 6 additions & 5 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from difflib import Differ
from enum import Enum
import json
Expand Down Expand Up @@ -62,12 +62,13 @@ def __str__(self):
class ConfigParameter(BaseModel):
field: str
value: str
namespace: Optional[str] = None
created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat())
scope: Optional[str] = None
created_at: str = Field(
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)

def to_pinecone(self, dimensions: int):
if self.namespace is None:
namespace = ""
namespace = self.scope or ""
return {
"id": f"{self.field}#{namespace}",
"values": [0.1] * dimensions,
Expand Down
Loading

0 comments on commit 07ece05

Please sign in to comment.