diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b791742a7e737..401b60abed3b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -194,7 +195,7 @@ protected void doAssertLuceneQuery(SemanticQueryBuilder queryBuilder, Query quer } private void assertSparseEmbeddingLuceneQuery(Query query) { - Query innerQuery = assertOuterBooleanQuery(query); + Query innerQuery = assertOuterSparseVectorQuery(query); assertThat(innerQuery, instanceOf(BooleanQuery.class)); BooleanQuery innerBooleanQuery = (BooleanQuery) innerQuery; @@ -207,7 +208,7 @@ private void assertSparseEmbeddingLuceneQuery(Query query) { } private void assertTextEmbeddingLuceneQuery(Query query) { - Query innerQuery = assertOuterBooleanQuery(query); + Query innerQuery = assertOuterSparseVectorQuery(query); Class expectedKnnQueryClass = switch (denseVectorElementType) { case FLOAT -> KnnFloatVectorQuery.class; @@ -217,9 +218,11 @@ private void assertTextEmbeddingLuceneQuery(Query query) { assertThat(innerQuery, instanceOf(expectedKnnQueryClass)); } - private Query assertOuterBooleanQuery(Query query) { - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery outerBooleanQuery = (BooleanQuery) query; + private Query assertOuterSparseVectorQuery(Query query) { + assertThat(query, instanceOf(SparseVectorQueryWrapper.class)); + var wrapper = (SparseVectorQueryWrapper) query; + assertThat(wrapper.getTermsQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery outerBooleanQuery = (BooleanQuery) wrapper.getTermsQuery(); List outerMustClauses = new ArrayList<>(); List outerFilterClauses = new ArrayList<>();