diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 6b73f2709..72ea8fa7b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -9,12 +9,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; -import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.index.query.QueryBuilderVisitor; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; @@ -48,6 +42,11 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNQuery; @@ -713,7 +712,7 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() { public void testVisit() { HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder()); List visitedQueries = new ArrayList<>(); - hybridQueryBuilder.visit(QueryBuilderVisitor.NO_OP_VISITOR); + hybridQueryBuilder.visit(createTestVisitor(visitedQueries)); assertEquals(3, visitedQueries.size()); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index 97e1f7c5d..a1e8210e6 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -7,6 +7,14 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonMap; import static java.util.stream.Collectors.toList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.BooleanClause; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilderVisitor; import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; import java.io.IOException; @@ -20,11 +28,6 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.Weight; import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.CheckedConsumer; @@ -230,4 +233,18 @@ public float getMaxScore(int upTo) { protected static void initFeatureFlags() { System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey(), "true"); } + + protected static QueryBuilderVisitor createTestVisitor(List visitedQueries) { + return new QueryBuilderVisitor() { + @Override + public void accept(QueryBuilder qb) { + visitedQueries.add(qb); + } + + @Override + public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) { + return this; + } + }; + } }