From 617dff766d6a9cfd19ab6dd5a13c080e66d30299 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 9 Oct 2024 09:32:37 +0100 Subject: [PATCH] add validation for queryVector and queryVectorBuilder --- .../search/retriever/KnnRetrieverBuilder.java | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 5974dea23b577..facda1a30a5ac 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -114,8 +114,25 @@ public KnnRetrieverBuilder( int numCands, Float similarity ) { + if (queryVector == null && queryVectorBuilder == null) { + throw new IllegalArgumentException( + format( + "either [%s] or [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } else if (queryVector != null && queryVectorBuilder != null) { + throw new IllegalArgumentException( + format( + "only one of [%s] and [%s] must be provided", + QUERY_VECTOR_FIELD.getPreferredName(), + QUERY_VECTOR_BUILDER_FIELD.getPreferredName() + ) + ); + } this.field = field; - this.queryVector = () -> queryVector; + this.queryVector = queryVector != null ? () -> queryVector : null; this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCands; @@ -175,6 +192,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { @Override public QueryBuilder topDocsQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true); if (preFilterQueryBuilders.isEmpty()) { @@ -187,6 +205,7 @@ public QueryBuilder topDocsQuery() { @Override public QueryBuilder explainQuery() { + assert queryVector != null : "query vector must be materialized at this point"; assert rankDocs != null : "rankDocs should have been materialized by now"; var rankDocsQuery = new RankDocsQueryBuilder( rankDocs, @@ -203,10 +222,11 @@ public QueryBuilder explainQuery() { @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + assert queryVector != null : "query vector must be materialized at this point."; KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder( field, VectorData.fromFloats(queryVector.get()), - queryVectorBuilder, + null, k, numCands, similarity @@ -223,6 +243,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } // ---- FOR TESTING XCONTENT PARSING ---- + @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { builder.field(FIELD_FIELD.getPreferredName(), field); @@ -230,7 +251,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands); if (queryVector != null) { - builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get()); } if (queryVectorBuilder != null) { @@ -248,7 +269,8 @@ public boolean doEquals(Object o) { return k == that.k && numCands == that.numCands && Objects.equals(field, that.field) - && Arrays.equals(queryVector.get(), that.queryVector.get()) + && ((queryVector == null && that.queryVector == null) + || (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get()))) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) && Objects.equals(similarity, that.similarity); } @@ -256,7 +278,7 @@ public boolean doEquals(Object o) { @Override public int doHashCode() { int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity); - result = 31 * result + Arrays.hashCode(queryVector.get()); + result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null); return result; }