From f269a24e9e6548a7a2d3a6c6faab50491da0bcf2 Mon Sep 17 00:00:00 2001 From: Ismail Ashraq Date: Sun, 15 Dec 2024 17:54:24 +0800 Subject: [PATCH] fix vector shape for single utterance --- semantic_router/routers/base.py | 4 ++-- semantic_router/routers/semantic.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 328cf2b7..18d71ced 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -522,7 +522,7 @@ def _retrieve_top_route( """ # get relevant results (scores and routes) results = self._retrieve( - xq=np.array(vector), top_k=self.top_k, route_filter=route_filter + xq=vector[0], top_k=self.top_k, route_filter=route_filter ) # decide most relevant routes top_class, top_class_scores = self._semantic_classify(results) @@ -535,7 +535,7 @@ async def _async_retrieve_top_route( ) -> Tuple[Optional[Route], List[float]]: # get relevant results (scores and routes) results = await self._async_retrieve( - xq=np.array(vector), top_k=self.top_k, route_filter=route_filter + xq=vector[0], top_k=self.top_k, route_filter=route_filter ) # decide most relevant routes top_class, top_class_scores = await self._async_semantic_classify(results) diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 33af2a32..41c92d53 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -40,14 +40,12 @@ def _encode(self, text: list[str]) -> Any: """Given some text, encode it.""" # create query vector xq = np.array(self.encoder(text)) - xq = np.squeeze(xq) # Reduce to 1d array. return xq async def _async_encode(self, text: list[str]) -> Any: """Given some text, encode it.""" # create query vector xq = np.array(await self.encoder.acall(docs=text)) - xq = np.squeeze(xq) # Reduce to 1d array. return xq def add(self, routes: List[Route] | Route):