Skip to content

Commit

Permalink
fix: bug in synced dynamic routes
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Sep 5, 2024
1 parent 40f7ec1 commit 42fa288
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 62 deletions.
115 changes: 62 additions & 53 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,31 +499,16 @@ def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False)
return all_vector_ids, metadata

def get_routes(self) -> List[Tuple]:
"""
Gets a list of route and utterance objects currently stored in the index, including additional metadata.
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Returns:
List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
"""
_, metadata = self._get_all(include_metadata=True)
route_tuples = [
(
data.get("sr_route", ""),
data.get("sr_utterance", ""),
(
json.loads(data["sr_function_schema"])
if data.get("sr_function_schema", "")
else {}
),
{
key: value
for key, value in data.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
},
)
for data in metadata
]
return route_tuples # type: ignore
route_tuples = parse_route_info(metadata=metadata)
return route_tuples

def delete(self, route_name: str):
route_vec_ids = self._get_route_ids(route_name=route_name)
Expand Down Expand Up @@ -553,8 +538,7 @@ def query(
route_filter: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[np.ndarray, List[str]]:
"""
Search the index for the query vector and return the top_k results.
"""Search the index for the query vector and return the top_k results.
:param vector: The query vector to search for.
:type vector: np.ndarray
Expand Down Expand Up @@ -633,11 +617,11 @@ async def aquery(
return np.array(scores), route_names

async def aget_routes(self) -> list[tuple]:
"""
Asynchronously get a list of route and utterance objects currently stored in the index.
"""Asynchronously get a list of route and utterance objects currently
stored in the index.
Returns:
List[Tuple]: A list of (route_name, utterance) objects.
:return: A list of (route_name, utterance) objects.
:rtype: List[Tuple]
"""
if self.async_client is None or self.host is None:
raise ValueError("Async client or host are not initialized.")
Expand Down Expand Up @@ -703,8 +687,15 @@ async def _async_describe_index(self, name: str):
async def _async_get_all(
self, prefix: Optional[str] = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
"""
Retrieves all vector IDs from the Pinecone index using pagination asynchronously.
"""Retrieves all vector IDs from the Pinecone index using pagination
asynchronously.
:param prefix: The prefix to filter the vectors by.
:type prefix: Optional[str]
:param include_metadata: Whether to include metadata in the response.
:type include_metadata: bool
:return: A tuple containing a list of vector IDs and a list of metadata dictionaries.
:rtype: tuple[list[str], list[dict]]
"""
if self.index is None:
raise ValueError("Index is None, could not retrieve vector IDs.")
Expand Down Expand Up @@ -754,8 +745,13 @@ async def _async_get_all(
return all_vector_ids, metadata

async def _async_fetch_metadata(self, vector_id: str) -> dict:
"""
Fetch metadata for a single vector ID asynchronously using the async_client.
"""Fetch metadata for a single vector ID asynchronously using the
async_client.
:param vector_id: The ID of the vector to fetch metadata for.
:type vector_id: str
:return: A dictionary containing the metadata for the vector.
:rtype: dict
"""
url = f"https://{self.host}/vectors/fetch"

Expand Down Expand Up @@ -786,31 +782,44 @@ async def _async_fetch_metadata(self, vector_id: str) -> dict:
)

async def _async_get_routes(self) -> List[Tuple]:
"""
Asynchronously gets a list of route and utterance objects currently stored in the index, including additional metadata.
"""Asynchronously gets a list of route and utterance objects currently
stored in the index, including additional metadata.
Returns:
List[Tuple]: A list of tuples, each containing route, utterance, function schema and additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
"""
_, metadata = await self._async_get_all(include_metadata=True)
route_info = [
(
data.get("sr_route", ""),
data.get("sr_utterance", ""),
(
json.loads(data["sr_function_schema"])
if data.get("sr_function_schema", "")
else {}
),
{
key: value
for key, value in data.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
},
)
for data in metadata
]
route_info = parse_route_info(metadata=metadata)
return route_info # type: ignore

def __len__(self):
return self.index.describe_index_stats()["total_vector_count"]



def parse_route_info(metadata: List[Dict[str, Any]]) -> List[Tuple]:
"""Parses metadata from Pinecone index to extract route, utterance, function
schema and additional metadata.
:param metadata: List of metadata dictionaries.
:type metadata: List[Dict[str, Any]]
:return: A list of tuples, each containing route, utterance, function schema and additional metadata.
:rtype: List[Tuple]
"""
route_info = []
for record in metadata:
sr_route = record.get("sr_route", "")
sr_utterance = record.get("sr_utterance", "")
sr_function_schema = json.loads(record.get("sr_function_schema", "{}"))
if sr_function_schema == {}:
sr_function_schema = None

additional_metadata = {
key: value
for key, value in record.items()
if key not in ["sr_route", "sr_utterance", "sr_function_schema"]
}
# TODO: Not a fan of tuple packing here
route_info.append((sr_route, sr_utterance, sr_function_schema, additional_metadata))
return route_info
22 changes: 14 additions & 8 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(
# if routes list has been passed, we initialize index now
if self.index.sync:
# initialize index now
logger.info(f"JB TEMP: {self.routes=}")
if len(self.routes) > 0:
self._add_and_sync_routes(routes=self.routes)
else:
Expand Down Expand Up @@ -544,15 +545,20 @@ def _add_and_sync_routes(self, routes: List[Route]):
)

# Update local route layer state
self.routes = [
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=[data.get("function_schemas", None)],
metadata=data.get("metadata", {}),
logger.info([data.get("function_schemas", None) for _, data in layer_routes_dict.items()])
self.routes = []
for route, data in layer_routes_dict.items():
function_schemas = data.get("function_schemas", None)
if function_schemas is not None:
function_schemas = [function_schemas]
self.routes.append(
Route(
name=route,
utterances=data.get("utterances", []),
function_schemas=function_schemas,
metadata=data.get("metadata", {}),
)
)
for route, data in layer_routes_dict.items()
]

def _extract_routes_details(
self, routes: List[Route], include_metadata: bool = False
Expand Down
5 changes: 4 additions & 1 deletion semantic_router/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ def __call__(
raise ValueError("OpenAI client is not initialized.")
try:
tools: Union[List[Dict[str, Any]], NotGiven] = (
function_schemas if function_schemas is not None else NOT_GIVEN
function_schemas if function_schemas else NOT_GIVEN
)
logger.info(f"{function_schemas=}")
logger.info(f"{function_schemas is None=}")
logger.info(f"{tools=}")

completion = self.client.chat.completions.create(
model=self.name,
Expand Down

0 comments on commit 42fa288

Please sign in to comment.