Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Nov 3, 2023
1 parent 73d13a0 commit dbda2c4
Showing 1 changed file with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

import lombok.SneakyThrows;

import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -38,9 +41,11 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -497,4 +502,23 @@ public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
queryBuilder = sparseEncodingQueryBuilder.doRewrite(null);
assertTrue(queryBuilder == sparseEncodingQueryBuilder);
}

@SneakyThrows
public void testDoToQuery_successfulDoToQuery() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.maxTokenScore(MAX_TOKEN_SCORE)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class);
MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class);
doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName();
doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any());

BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder();
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD);
targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD);

assertEquals(sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext), targetQueryBuilder.build());
}
}

0 comments on commit dbda2c4

Please sign in to comment.