From bc4d09717b77797196ed7af05e3127123f871a8f Mon Sep 17 00:00:00 2001
From: Jared McQueen
Date: Thu, 1 Aug 2024 10:39:02 -0400
Subject: [PATCH] Update bedrock.py

Skip elements whose text is empty or whitespace-only when requesting
Bedrock embeddings, and return all elements in their original order.

Fixes https://github.com/Unstructured-IO/unstructured/issues/3461
---
 unstructured/embed/bedrock.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/unstructured/embed/bedrock.py b/unstructured/embed/bedrock.py
index dba52e7760..be30c49358 100644
--- a/unstructured/embed/bedrock.py
+++ b/unstructured/embed/bedrock.py
@@ -56,9 +56,19 @@ def embed_query(self, query):
         return np.array(self.bedrock_client.embed_query(query))
 
     def embed_documents(self, elements: List[Element]) -> List[Element]:
-        embeddings = self.bedrock_client.embed_documents([str(e) for e in elements])
-        elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
-        return elements_with_embeddings
+        """Embed every element that has non-blank text, preserving input order.
+
+        Elements whose text is empty or whitespace-only are left out of the
+        embedding request and are returned untouched, at their original position.
+        """
+        has_text = [bool(e.text.strip()) for e in elements]
+        non_empty_elements = [e for e, keep in zip(elements, has_text) if keep]
+        embeddings = self.bedrock_client.embed_documents([str(e) for e in non_empty_elements])
+        # Assumes the helper returns one element per input, in the same order.
+        embedded = iter(self._add_embeddings_to_elements(non_empty_elements, embeddings))
+        # Re-interleave so blank elements keep their position instead of being
+        # moved to the end of the returned list.
+        return [next(embedded) if keep else e for e, keep in zip(elements, has_text)]
 
     def _add_embeddings_to_elements(self, elements, embeddings) -> List[Element]:
         assert len(elements) == len(embeddings)