diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index b7a9145de..891ceaf0b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; @@ -49,6 +50,7 @@ import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -571,6 +573,42 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); } + @SneakyThrows + public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSetVectorSupplier() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) + .modelId(MODEL_ID) + .k(K); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(expectedVector); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set vector supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.vectorSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + } + public void testRewrite_whenVectorNull_thenReturnCopy() { Supplier nullSupplier = () -> null; NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) @@ -610,4 +648,19 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); } + + public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> neuralQueryBuilder.doToQuery(queryShardContext) + ); + assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage()); + } }