From c40dfcaa31515e569893a48bf664ae0b1b5c4121 Mon Sep 17 00:00:00 2001 From: mkr Date: Wed, 22 Mar 2023 15:00:01 +0100 Subject: [PATCH 1/2] Using the LuceneRawQuery inside the BoostQuery to avoid the need to re-parse the embedding query String in Elasticsearch. --- .../querqy/embeddings/EmbeddingsRewriter.java | 54 +++---------------- 1 file changed, 8 insertions(+), 46 deletions(-) diff --git a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java index fbe4cd0..b392d60 100644 --- a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java +++ b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java @@ -10,7 +10,6 @@ import querqy.model.Node; import querqy.model.QuerqyQuery; import querqy.model.Query; -import querqy.model.StringRawQuery; import querqy.model.Term; import querqy.rewrite.QueryRewriter; import querqy.rewrite.RewriterOutput; @@ -77,20 +76,19 @@ public RewriterOutput rewrite(final ExpandedQuery query, } protected ExpandedQuery applyEmbedding(final Embedding embedding, final ExpandedQuery inputQuery) { - + KnnVectorQuery knnVectorQuery = new KnnVectorQuery(vectorField, embedding.asVector(), topK); switch (queryMode) { case BOOST: - inputQuery.addBoostUpQuery(new BoostQuery(new StringRawQuery(null, makeEmbeddingQueryString(embedding), - Clause.Occur.SHOULD, true), boost)); + inputQuery.addBoostUpQuery(new BoostQuery( + new LuceneRawQuery(null, Clause.Occur.SHOULD,true, knnVectorQuery), + boost + )); break; case MAIN_QUERY: - // this is a workaround to avoid changing Querqy's query object model for now: - // as we cant set a StringRawQuery as the userQuery, we use a match all for that, add a vector query - // as a filter query (retrieve only knn) and a boost query (rank by distance) - //inputQuery.setUserQuery(new MatchAllQuery()); - inputQuery.setUserQuery(new LuceneRawQuery(null, Clause.Occur.MUST, - true, new KnnVectorQuery(vectorField, embedding.asVector(), topK))); + inputQuery.setUserQuery( + 
new LuceneRawQuery(null, Clause.Occur.MUST,true, knnVectorQuery) + ); break; default: throw new IllegalStateException("Unknown query mode: " + queryMode); @@ -100,42 +98,6 @@ protected ExpandedQuery applyEmbedding(final Embedding embedding, final Expanded return inputQuery; } - protected String makeEmbeddingQueryString(final Embedding embedding) { - return "{!func}sum(100,query({!knn f=" + vectorField + " topK=" + topK + " v='[" + embedding.asCommaSeparatedString() + "]'}))"; - } - - protected String embeddingToString(final float[] embedding) { - final StringBuilder sb = new StringBuilder(embedding.length * 16); - for (int i = 0; i < embedding.length; i++) { - if (i > 0) { - sb.append(", "); - } - sb.append(embedding[i]); - } - return sb.toString(); - } - - protected ExpandedQuery applyVectorQuery(final String embeddingQueryString, final ExpandedQuery inputQuery) { - - - final StringRawQuery embeddingsQuery = new StringRawQuery(null, embeddingQueryString, Clause.Occur.SHOULD, true); - switch (queryMode) { - case BOOST: - inputQuery.addBoostUpQuery(new BoostQuery(embeddingsQuery, boost)); - break; - case MAIN_QUERY: - // this is a workaround to avoid changing Querqy's query object model for now: - // as we cant set a StringRawQuery as the userQuery, we use a match all for that, add a vector query - // as a filter query (retrieve only knn) and a boost query (rank by distance) - inputQuery.setUserQuery(new StringRawQuery(null, embeddingQueryString, Clause.Occur.MUST, true)); - break; - default: - throw new IllegalStateException("Unknown query mode: " + queryMode); - - } - - return inputQuery; - } /** * Traverse the query graph, collect all the terms and join them into a string */ From 479a366614746d038999c9a3e1a253668553eeaf Mon Sep 17 00:00:00 2001 From: mkr Date: Wed, 22 Mar 2023 15:02:38 +0100 Subject: [PATCH 2/2] Fix for supporting embedding vectors that don't contain decimals (previously threw a ClassCastException). 
--- .../embeddings/ChorusEmbeddingModel.java | 18 +++++++++++++++++- .../querqy/embeddings/EmbeddingsRewriter.java | 13 ++++--------- .../embeddings/ChorusEmbeddingModelTest.java | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java diff --git a/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java b/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java index b6bc42c..5c6a90b 100644 --- a/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java +++ b/src/main/java/querqy/embeddings/ChorusEmbeddingModel.java @@ -1,8 +1,12 @@ package querqy.embeddings; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import querqy.solr.utils.JsonUtil; import java.io.IOException; +import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.MalformedURLException; @@ -17,6 +21,8 @@ public class ChorusEmbeddingModel implements EmbeddingModel { private static final String CONTENT_TYPE_JSON = "application/json"; + private static final ObjectMapper objectMapper = new ObjectMapper(); + private URL url; private boolean normalize = true; @@ -65,14 +71,23 @@ public Embedding getEmbedding(final String text) { os.write(input, 0, input.length); } - embedding = Embedding.of((List) JsonUtil.readJson(con.getInputStream(), Map.class).get("embedding")); + embedding = parseEmbeddingFromResponse(con.getInputStream()); embeddingsCache.putEmbedding(cacheKey, embedding); return embedding; } catch (final IOException e) { throw new RuntimeException(e); } + } + public Embedding parseEmbeddingFromResponse(InputStream is) { + try { + JsonNode responseTree = objectMapper.readTree(is); + List embedding = objectMapper.convertValue(responseTree.path("embedding"), new TypeReference<>() {}); + return Embedding.of(embedding); + } catch (IOException e) 
{ + throw new RuntimeException(e); + } } protected String toJsonString(final String text) { @@ -86,4 +101,5 @@ protected String toJsonString(final String text) { ))); } + } diff --git a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java index b392d60..de3558d 100644 --- a/src/main/java/querqy/embeddings/EmbeddingsRewriter.java +++ b/src/main/java/querqy/embeddings/EmbeddingsRewriter.java @@ -76,23 +76,18 @@ public RewriterOutput rewrite(final ExpandedQuery query, } protected ExpandedQuery applyEmbedding(final Embedding embedding, final ExpandedQuery inputQuery) { - KnnVectorQuery knnVectorQuery = new KnnVectorQuery(vectorField, embedding.asVector(), topK); + KnnVectorQuery knnVectorQuery = new KnnVectorQuery(vectorField, embedding.asVector(), topK); + LuceneRawQuery luceneRawQuery = new LuceneRawQuery(null, Clause.Occur.MUST,true, knnVectorQuery); switch (queryMode) { case BOOST: - inputQuery.addBoostUpQuery(new BoostQuery( - new LuceneRawQuery(null, Clause.Occur.SHOULD,true, knnVectorQuery), - boost - )); + inputQuery.addBoostUpQuery(new BoostQuery(luceneRawQuery, boost)); break; case MAIN_QUERY: - inputQuery.setUserQuery( - new LuceneRawQuery(null, Clause.Occur.MUST,true, knnVectorQuery) - ); + inputQuery.setUserQuery(luceneRawQuery); break; default: throw new IllegalStateException("Unknown query mode: " + queryMode); - } return inputQuery; diff --git a/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java b/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java new file mode 100644 index 0000000..f87e0fe --- /dev/null +++ b/src/test/java/querqy/solr/embeddings/ChorusEmbeddingModelTest.java @@ -0,0 +1,19 @@ +package querqy.solr.embeddings; + +import org.junit.Assert; +import org.junit.Test; +import querqy.embeddings.ChorusEmbeddingModel; +import querqy.embeddings.Embedding; + +import java.io.ByteArrayInputStream; +import java.nio.charset.StandardCharsets; + +public class 
ChorusEmbeddingModelTest { + + @Test + public void testParseJson() { + String embeddingJson = "{ \"embedding\": [0.3, 1, 5] }"; + Embedding e = new ChorusEmbeddingModel().parseEmbeddingFromResponse(new ByteArrayInputStream(embeddingJson.getBytes(StandardCharsets.UTF_8))); + Assert.assertArrayEquals(e.asVector(), new float[] { 0.3f, 1f, 5f}, 0f); + } +}