Skip to content

Commit

Permalink
Added unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 3, 2023
1 parent 9286598 commit 768a024
Showing 1 changed file with 53 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.query;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -49,6 +50,7 @@
import org.opensearch.index.query.MatchNoneQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
Expand Down Expand Up @@ -571,6 +573,42 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() {
assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f);
}

@SneakyThrows
public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSetVectorSupplier() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.queryImage(IMAGE_TEXT)
.modelId(MODEL_ID)
.k(K);
List<Float> expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f);
MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class);
doAnswer(invocation -> {
ActionListener<List<Float>> listener = invocation.getArgument(2);
listener.onResponse(expectedVector);
return null;
}).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any());
NeuralQueryBuilder.initialize(mlCommonsClientAccessor);

final CountDownLatch inProgressLatch = new CountDownLatch(1);
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
doAnswer(invocation -> {
BiConsumer<Client, ActionListener<?>> biConsumer = invocation.getArgument(0);
biConsumer.accept(
null,
ActionListener.wrap(
response -> inProgressLatch.countDown(),
err -> fail("Failed to set vector supplier: " + err.getMessage())
)
);
return null;
}).when(queryRewriteContext).registerAsyncAction(any());

NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext);
assertNotNull(queryBuilder.vectorSupplier());
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS));
assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f);
}

public void testRewrite_whenVectorNull_thenReturnCopy() {
Supplier<float[]> nullSupplier = () -> null;
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME)
Expand Down Expand Up @@ -610,4 +648,19 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() {
KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder;
assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter());
}

public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.k(K)
.vectorSupplier(TEST_VECTOR_SUPPLIER)
.filter(TEST_FILTER);
QueryShardContext queryShardContext = mock(QueryShardContext.class);
UnsupportedOperationException exception = expectThrows(
UnsupportedOperationException.class,
() -> neuralQueryBuilder.doToQuery(queryShardContext)
);
assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage());
}
}

0 comments on commit 768a024

Please sign in to comment.