From 3275439b5aef154c3da5c7672e553bc3ae44ea82 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 27 Feb 2024 10:31:27 +0800 Subject: [PATCH] add ut Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilderTests.java | 50 +++++++++++++++++-- .../NeuralSearchQueryVisitorTests.java | 26 ++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 2220ce326..89bcd57d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -22,7 +22,10 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import org.junit.Before; +import org.opensearch.Version; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.SetOnce; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -38,6 +41,8 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -51,6 +56,11 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase { private static final String QUERY_NAME = "queryName"; private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); + @Before + public void setupClusterServiceToCurrentVersion() { + setUpClusterService(Version.CURRENT); + } + @SneakyThrows public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { /* @@ -162,7 +172,7 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() { } @SneakyThrows - public void testFromXContent_whenBuildWithMissingModelId_thenFail() { + public void testFromXContent_whenBuildWithMissingModelIdInCurrentVersion_thenSuccess() { /* { "VECTOR_FIELD": { @@ -177,6 +187,30 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() { .endObject() .endObject(); + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); + + assertNull(sparseEncodingQueryBuilder.modelId()); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingModelIdInOldVersion_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string" + } + } + */ + setUpClusterService(Version.V_2_12_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .endObject() + .endObject(); + XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); @@ -241,6 +275,11 @@ public void testToXContent() { assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); } + public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() { + setUpClusterService(Version.V_2_12_0); + testStreams(); + } + @SneakyThrows public void testStreams() { NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); @@ -436,10 +475,15 @@ public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { .modelId(MODEL_ID) .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assertTrue(queryBuilder == sparseEncodingQueryBuilder); + assertSame(queryBuilder, sparseEncodingQueryBuilder); sparseEncodingQueryBuilder.queryTokensSupplier(() -> null); queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assertTrue(queryBuilder == sparseEncodingQueryBuilder); + assertSame(queryBuilder, sparseEncodingQueryBuilder); + } + + private void setUpClusterService(Version version) { + ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); + NeuralSearchClusterUtil.instance().initialize(clusterService); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java index e513ab035..ba890959c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.search.BooleanClause; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.test.OpenSearchTestCase; public class NeuralSearchQueryVisitorTests extends OpenSearchTestCase { @@ -38,11 +39,36 @@ public void testAccept_whenNeuralQueryBuilderWithoutFieldModelId_thenSetFieldMod assertEquals("bdcvjkcdjvkddcjxdjsc", neuralQueryBuilder.modelId()); } + public void testAccept_whenNeuralSparseQueryBuilderWithoutModelId_thenSetModelId() { + String modelId = "bdcvjkcdjvkddcjxdjsc"; + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("passage_text"); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(modelId, null); + neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder); + + assertEquals(modelId, neuralSparseQueryBuilder.modelId()); + } + + public void testAccept_whenNeuralSparseQueryBuilderWithoutFieldModelId_thenSetFieldModelId() { + Map neuralInfoMap = new HashMap<>(); + neuralInfoMap.put("passage_text", "bdcvjkcdjvkddcjxdjsc"); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("passage_text"); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, neuralInfoMap); + neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder); + + assertEquals("bdcvjkcdjvkddcjxdjsc", neuralSparseQueryBuilder.modelId()); + } + public void testAccept_whenNullValuesInVisitor_thenFail() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, null); expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralQueryBuilder)); + expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder)); } public void testGetChildVisitor() {