Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Feb 27, 2024
1 parent 7176ebb commit ca5ea46
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
Expand All @@ -38,6 +41,8 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.test.OpenSearchTestCase;

import lombok.SneakyThrows;
Expand All @@ -51,6 +56,11 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
private static final String QUERY_NAME = "queryName";
private static final Supplier<Map<String, Float>> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f);

@Before
public void setupClusterServiceToCurrentVersion() {
setUpClusterService(Version.CURRENT);
}

@SneakyThrows
public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {
/*
Expand Down Expand Up @@ -162,7 +172,7 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() {
}

@SneakyThrows
public void testFromXContent_whenBuildWithMissingModelId_thenFail() {
public void testFromXContent_whenBuildWithMissingModelIdInCurrentVersion_thenSuccess() {
/*
{
"VECTOR_FIELD": {
Expand All @@ -177,6 +187,30 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() {
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertNull(sparseEncodingQueryBuilder.modelId());
}

@SneakyThrows
public void testFromXContent_whenBuildWithMissingModelIdInOldVersion_thenFail() {
/*
{
"VECTOR_FIELD": {
"query_text": "string"
}
}
*/
setUpClusterService(Version.V_2_12_0);
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
Expand Down Expand Up @@ -241,6 +275,11 @@ public void testToXContent() {
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
}

public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
setUpClusterService(Version.V_2_12_0);
testStreams();
}

@SneakyThrows
public void testStreams() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
Expand Down Expand Up @@ -436,10 +475,15 @@ public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null);
assertTrue(queryBuilder == sparseEncodingQueryBuilder);
assertSame(queryBuilder, sparseEncodingQueryBuilder);

sparseEncodingQueryBuilder.queryTokensSupplier(() -> null);
queryBuilder = sparseEncodingQueryBuilder.doRewrite(null);
assertTrue(queryBuilder == sparseEncodingQueryBuilder);
assertSame(queryBuilder, sparseEncodingQueryBuilder);
}

private void setUpClusterService(Version version) {
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version);
NeuralSearchClusterUtil.instance().initialize(clusterService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.search.BooleanClause;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralSearchQueryVisitorTests extends OpenSearchTestCase {
Expand Down Expand Up @@ -38,11 +39,36 @@ public void testAccept_whenNeuralQueryBuilderWithoutFieldModelId_thenSetFieldMod
assertEquals("bdcvjkcdjvkddcjxdjsc", neuralQueryBuilder.modelId());
}

public void testAccept_whenNeuralSparseQueryBuilderWithoutModelId_thenSetModelId() {
String modelId = "bdcvjkcdjvkddcjxdjsc";
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
neuralSparseQueryBuilder.fieldName("passage_text");

NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(modelId, null);
neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder);

assertEquals(modelId, neuralSparseQueryBuilder.modelId());
}

public void testAccept_whenNeuralSparseQueryBuilderWithoutFieldModelId_thenSetFieldModelId() {
Map<String, Object> neuralInfoMap = new HashMap<>();
neuralInfoMap.put("passage_text", "bdcvjkcdjvkddcjxdjsc");
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
neuralSparseQueryBuilder.fieldName("passage_text");

NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, neuralInfoMap);
neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder);

assertEquals("bdcvjkcdjvkddcjxdjsc", neuralSparseQueryBuilder.modelId());
}

public void testAccept_whenNullValuesInVisitor_thenFail() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, null);

expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralQueryBuilder));
expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder));
}

public void testGetChildVisitor() {
Expand Down

0 comments on commit ca5ea46

Please sign in to comment.