From d5d19430671c373e3cf919769255ac5d344b6cab Mon Sep 17 00:00:00 2001 From: Lukasz Soszynski Date: Fri, 19 Jul 2024 13:34:49 +0200 Subject: [PATCH] Corrections related to LookupCommandIT and tests related to OpenSearchIndex Signed-off-by: Lukasz Soszynski --- .../sql/planner/physical/LookupOperator.java | 3 +- .../sql/planner/DefaultImplementorTest.java | 115 +++----- .../opensearch/sql/ppl/LookupCommandIT.java | 148 +++++++---- .../opensearch/storage/OpenSearchIndex.java | 113 ++++---- .../OpenSearchExecutionProtectorTest.java | 103 ++++---- .../storage/OpenSearchIndexTest.java | 101 +++++-- .../storage/SingleRowQueryTest.java | 246 ++++++++++++++++++ 7 files changed, 590 insertions(+), 239 deletions(-) create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/SingleRowQueryTest.java diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/LookupOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/LookupOperator.java index 7117d87f5d..c4b1ccd824 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/LookupOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/LookupOperator.java @@ -9,6 +9,7 @@ import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.BiFunction; @@ -108,7 +109,7 @@ public ExprValue next() { } Map tupleInputValue = ExprValueUtils.getTupleValue(inputValue); - Map resultTupleBuilder = new HashMap<>(); + Map resultTupleBuilder = new LinkedHashMap<>(); resultTupleBuilder.putAll(tupleInputValue); for (Map.Entry sourceOfAdditionalField : lookupResult.entrySet()) { String lookedUpFieldName = sourceOfAdditionalField.getKey(); diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 930eb63a03..45d8f6c03c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -21,7 +21,6 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.eval; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.lookup; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.nested; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.rareTopN; @@ -48,7 +47,6 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.model.ExprBooleanValue; -import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.executor.pagination.PlanSerializer; import org.opensearch.sql.expression.DSL; @@ -60,7 +58,6 @@ import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.expression.window.ranking.RowNumberFunction; import org.opensearch.sql.planner.logical.LogicalCloseCursor; -import org.opensearch.sql.planner.logical.LogicalLookup; import org.opensearch.sql.planner.logical.LogicalPaginate; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; @@ -125,27 +122,22 @@ public void visit_should_return_default_physical_operator() { nested( limit( LogicalPlanDSL.dedupe( - lookup( - rareTopN( - sort( - eval( - remove( - rename( - aggregation( - filter(values(emptyList()), filterExpr), - aggregators, - groupByExprs), - mappings), - exclude), - newEvalField), - sortField), - CommandType.TOP, - topByExprs, - rareTopNField), - "lookup_index_name", - Map.of(), - false, - Map.of()), + rareTopN( + sort( + eval( + remove( + rename( + aggregation( + filter(values(emptyList()), filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField), + CommandType.TOP, + topByExprs, + rareTopNField), dedupeField), limit, offset), @@ -160,30 +152,24 @@ public void visit_should_return_default_physical_operator() { PhysicalPlanDSL.nested( PhysicalPlanDSL.limit( PhysicalPlanDSL.dedupe( - PhysicalPlanDSL.lookup( - PhysicalPlanDSL.rareTopN( - PhysicalPlanDSL.sort( - PhysicalPlanDSL.eval( - PhysicalPlanDSL.remove( - PhysicalPlanDSL.rename( - PhysicalPlanDSL.agg( - PhysicalPlanDSL.filter( - PhysicalPlanDSL.values(emptyList()), - filterExpr), - aggregators, - groupByExprs), - mappings), - exclude), - newEvalField), - sortField), - CommandType.TOP, - topByExprs, - rareTopNField), - "lookup_index_name", - Map.of(), - false, - Map.of(), - null), + PhysicalPlanDSL.rareTopN( + PhysicalPlanDSL.sort( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.remove( + PhysicalPlanDSL.rename( + PhysicalPlanDSL.agg( + PhysicalPlanDSL.filter( + PhysicalPlanDSL.values(emptyList()), + filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField), + CommandType.TOP, + topByExprs, + rareTopNField), dedupeField), limit, offset), @@ -292,37 +278,4 @@ public void visitPaginate_should_remove_it_from_tree() { new ProjectOperator(new ValuesOperator(List.of(List.of())), List.of(), List.of()); assertEquals(physicalPlanTree, logicalPlanTree.accept(implementor, null)); } - - @Test - public void visitLookup_should_build_LookupOperator() { - LogicalPlan values = values(List.of(DSL.literal("to be or not to be"))); - var logicalPlan = lookup(values, "lookup_index_name", Map.of(), false, Map.of()); - var expectedPhysicalPlan = - PhysicalPlanDSL.lookup( - new ValuesOperator(List.of(List.of(DSL.literal("to be or not to be")))), - "lookup_index_name", - Map.of(), - false, - Map.of(), - null); - - PhysicalPlan lookupOperator = logicalPlan.accept(implementor, null); - - assertEquals(expectedPhysicalPlan, lookupOperator); - } - - @Test - public void visitLookup_should_throw_unsupportedOperationException() { - LogicalLookup input = mock(LogicalLookup.class); - LogicalPlan dataSource = mock(LogicalPlan.class); - PhysicalPlan physicalSource = mock(PhysicalPlan.class); - when(dataSource.accept(implementor, null)).thenReturn(physicalSource); - when(input.getChild()).thenReturn(List.of(dataSource)); - PhysicalPlan lookupOperator = implementor.visitLookup(input, null); - when(physicalSource.next()).thenReturn(ExprValueUtils.tupleValue(Map.of("field", "value"))); - - var ex = assertThrows(UnsupportedOperationException.class, () -> lookupOperator.next()); - - assertEquals("Lookup not implemented by DefaultImplementor", ex.getMessage()); - } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/LookupCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/LookupCommandIT.java index ade37b1241..d6a85f4687 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/LookupCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/LookupCommandIT.java @@ -11,6 +11,7 @@ import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import java.io.IOException; +import java.math.BigDecimal; import org.json.JSONObject; import org.junit.jupiter.api.Test; @@ -32,49 +33,49 @@ public void testLookup() throws IOException { verifyDataRows( result, rows( - 28.1, - "2015-01-20 15:31:32.406431", + new BigDecimal("28.1"), 255, + "2015-01-20 15:31:32.406431", "temperature-basement", "meter", 255, "VendorOne"), rows( - 27.8, - "2016-01-20 15:31:33.509334", + new BigDecimal("27.8"), 256, + "2016-01-20 15:31:33.509334", "temperature-living-room", "temperature meter", 256, "VendorTwo"), rows( - 27.4, - "2017-01-20 15:31:35.732436", + new BigDecimal("27.4"), 257, + "2017-01-20 15:31:35.732436", "temperature-bedroom", "camcorder", 257, "VendorThree"), rows( - 28.5, - "2018-01-20 15:32:32.406431", + new BigDecimal("28.5"), 255, + "2018-01-20 15:32:32.406431", "temperature-basement", "meter", 255, "VendorOne"), rows( - 27.9, - "2019-01-20 15:32:33.509334", + new BigDecimal("27.9"), 256, + "2019-01-20 15:32:33.509334", "temperature-living-room", "temperature meter", 256, "VendorTwo"), rows( - 27.4, - "2020-01-20 15:32:35.732436", + new BigDecimal("27.4"), 257, + "2020-01-20 15:32:35.732436", "temperature-bedroom", "camcorder", 257, @@ -90,12 +91,23 @@ public void testLookupSelectedAttribute() throws IOException { TEST_INDEX_IOT_READINGS, TEST_INDEX_IOT_SENSORS)); verifyDataRows( result, - rows(28.1, "2015-01-20 15:31:32.406431", 255, "meter", "VendorOne"), - rows(27.8, "2016-01-20 15:31:33.509334", 256, "temperature meter", "VendorTwo"), - rows(27.4, "2017-01-20 15:31:35.732436", 257, "camcorder", "VendorThree"), - rows(28.5, "2018-01-20 15:32:32.406431", 255, "meter", "VendorOne"), - rows(27.9, "2019-01-20 15:32:33.509334", 256, "temperature meter", "VendorTwo"), - rows(27.4, "2020-01-20 15:32:35.732436", 257, "camcorder", "VendorThree")); + rows(new BigDecimal("28.1"), 255, "2015-01-20 15:31:32.406431", "meter", "VendorOne"), + rows( + new BigDecimal("27.8"), + 256, + "2016-01-20 15:31:33.509334", + "temperature meter", + "VendorTwo"), + rows(new BigDecimal("27.4"), 257, "2017-01-20 15:31:35.732436", "camcorder", "VendorThree"), + rows(new BigDecimal("28.5"), 255, "2018-01-20 15:32:32.406431", "meter", "VendorOne"), + rows( + new BigDecimal("27.9"), + 256, + "2019-01-20 15:32:33.509334", + "temperature meter", + "VendorTwo"), + rows( + new BigDecimal("27.4"), 257, "2020-01-20 15:32:35.732436", "camcorder", "VendorThree")); } @Test @@ -108,12 +120,36 @@ public void testLookupRenameSelectedAttributes() throws IOException { TEST_INDEX_IOT_READINGS, TEST_INDEX_IOT_SENSORS)); verifyDataRows( result, - rows(28.1, "2015-01-20 15:31:32.406431", 255, 255, "meter", "VendorOne"), - rows(27.8, "2016-01-20 15:31:33.509334", 256, 256, "temperature meter", "VendorTwo"), - rows(27.4, "2017-01-20 15:31:35.732436", 257, 257, "camcorder", "VendorThree"), - rows(28.5, "2018-01-20 15:32:32.406431", 255, 255, "meter", "VendorOne"), - rows(27.9, "2019-01-20 15:32:33.509334", 256, 256, "temperature meter", "VendorTwo"), - rows(27.4, "2020-01-20 15:32:35.732436", 257, 257, "camcorder", "VendorThree")); + rows(new BigDecimal("28.1"), 255, "2015-01-20 15:31:32.406431", 255, "meter", "VendorOne"), + rows( + new BigDecimal("27.8"), + 256, + "2016-01-20 15:31:33.509334", + 256, + "temperature meter", + "VendorTwo"), + rows( + new BigDecimal("27.4"), + 257, + "2017-01-20 15:31:35.732436", + 257, + "camcorder", + "VendorThree"), + rows(new BigDecimal("28.5"), 255, "2018-01-20 15:32:32.406431", 255, "meter", "VendorOne"), + rows( + new BigDecimal("27.9"), + 256, + "2019-01-20 15:32:33.509334", + 256, + "temperature meter", + "VendorTwo"), + rows( + new BigDecimal("27.4"), + 257, + "2020-01-20 15:32:35.732436", + 257, + "camcorder", + "VendorThree")); } @Test @@ -125,12 +161,12 @@ public void testLookupSelectedMultipleAttributes() throws IOException { TEST_INDEX_IOT_READINGS, TEST_INDEX_IOT_SENSORS)); verifyDataRows( result, - rows(28.1, "2015-01-20 15:31:32.406431", 255, "meter"), - rows(27.8, "2016-01-20 15:31:33.509334", 256, "temperature meter"), - rows(27.4, "2017-01-20 15:31:35.732436", 257, "camcorder"), - rows(28.5, "2018-01-20 15:32:32.406431", 255, "meter"), - rows(27.9, "2019-01-20 15:32:33.509334", 256, "temperature meter"), - rows(27.4, "2020-01-20 15:32:35.732436", 257, "camcorder")); + rows(new BigDecimal("28.1"), 255, "2015-01-20 15:31:32.406431", "meter"), + rows(new BigDecimal("27.8"), 256, "2016-01-20 15:31:33.509334", "temperature meter"), + rows(new BigDecimal("27.4"), 257, "2017-01-20 15:31:35.732436", "camcorder"), + rows(new BigDecimal("28.5"), 255, "2018-01-20 15:32:32.406431", "meter"), + rows(new BigDecimal("27.9"), 256, "2019-01-20 15:32:33.509334", "temperature meter"), + rows(new BigDecimal("27.4"), 257, "2020-01-20 15:32:35.732436", "camcorder")); } @Test @@ -143,32 +179,32 @@ public void testLookupShouldAppendOnlyShouldBeFalseByDefault() throws IOExceptio TEST_INDEX_IOT_READINGS, TEST_INDEX_IOT_SENSORS)); verifyDataRows( result, - rows("2015-01-20 15:31:32.406431", 255, "VendorOne", "temperature-basement", "meter", 255), + rows(255, "2015-01-20 15:31:32.406431", "VendorOne", "temperature-basement", "meter", 255), rows( - "2016-01-20 15:31:33.509334", 256, + "2016-01-20 15:31:33.509334", "VendorTwo", "temperature-living-room", "temperature meter", 256), rows( - "2017-01-20 15:31:35.732436", 257, + "2017-01-20 15:31:35.732436", "VendorThree", "temperature-bedroom", "camcorder", 257), - rows("2018-01-20 15:32:32.406431", 255, "VendorOne", "temperature-basement", "meter", 255), + rows(255, "2018-01-20 15:32:32.406431", "VendorOne", "temperature-basement", "meter", 255), rows( - "2019-01-20 15:32:33.509334", 256, + "2019-01-20 15:32:33.509334", "VendorTwo", "temperature-living-room", "temperature meter", 256), rows( - "2020-01-20 15:32:35.732436", 257, + "2020-01-20 15:32:35.732436", "VendorThree", "temperature-bedroom", "camcorder", @@ -185,23 +221,47 @@ public void testLookupWithAppendOnlyFalse() throws IOException { TEST_INDEX_IOT_READINGS, TEST_INDEX_IOT_SENSORS)); verifyDataRows( result, - rows("2015-01-20 15:31:32.406431", 255, 28.1, "temperature-basement", "meter", 255), rows( - "2016-01-20 15:31:33.509334", + 255, + "2015-01-20 15:31:32.406431", + new BigDecimal("28.1"), + "temperature-basement", + "meter", + 255), + rows( 256, - 27.8, + "2016-01-20 15:31:33.509334", + new BigDecimal("27.8"), "temperature-living-room", "temperature meter", 256), - rows("2017-01-20 15:31:35.732436", 257, 27.4, "temperature-bedroom", "camcorder", 257), - rows("2018-01-20 15:32:32.406431", 255, 28.5, "temperature-basement", "meter", 255), rows( - "2019-01-20 15:32:33.509334", + 257, + "2017-01-20 15:31:35.732436", + new BigDecimal("27.4"), + "temperature-bedroom", + "camcorder", + 257), + rows( + 255, + "2018-01-20 15:32:32.406431", + new BigDecimal("28.5"), + "temperature-basement", + "meter", + 255), + rows( 256, - 27.9, + "2019-01-20 15:32:33.509334", + new BigDecimal("27.9"), "temperature-living-room", "temperature meter", 256), - rows("2020-01-20 15:32:35.732436", 257, 27.4, "temperature-bedroom", "camcorder", 257)); + rows( + 257, + "2020-01-20 15:32:35.732436", + new BigDecimal("27.4"), + "temperature-bedroom", + "camcorder", + 257)); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index f250bbea15..640d49355f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -9,16 +9,20 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; import lombok.RequiredArgsConstructor; +import org.jetbrains.annotations.Nullable; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.unit.TimeValue; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.type.ExprCoreType; @@ -217,66 +221,79 @@ public PhysicalPlan visitML(LogicalML node, OpenSearchIndexScan context) { @Override public PhysicalPlan visitLookup(LogicalLookup node, OpenSearchIndexScan context) { + SingleRowQuery singleRowQuery = new SingleRowQuery(client); return new LookupOperator( visitChild(node, context), node.getIndexName(), node.getMatchFieldMap(), node.getAppendOnly(), node.getCopyFieldMap(), - lookup()); + lookup(singleRowQuery)); } - BiFunction, Map> lookup() { - - if (client.getNodeClient() == null) { - throw new RuntimeException( - "Can not perform lookup because openSearchClient was null. This is likely a bug."); - } - + BiFunction, Map> lookup( + SingleRowQuery singleRowQuery) { + Objects.requireNonNull(singleRowQuery, "SingleRowQuery is required to perform lookup"); return (indexName, inputMap) -> { Map matchMap = (Map) inputMap.get("_match"); Set copySet = (Set) inputMap.get("_copy"); - - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - - for (Map.Entry f : matchMap.entrySet()) { - BoolQueryBuilder orQueryBuilder = new BoolQueryBuilder(); - - // Todo: Search with term and a match query? Or terms only? - orQueryBuilder.should(new TermQueryBuilder(f.getKey(), f.getValue().toString())); - orQueryBuilder.should(new MatchQueryBuilder(f.getKey(), f.getValue().toString())); - orQueryBuilder.minimumShouldMatch(1); - - // filter is the same as "must" but ignores scoring - boolQueryBuilder.filter(orQueryBuilder); - } - - SearchResponse result = - client - .getNodeClient() - .search( - new SearchRequest(indexName) - .source( - SearchSourceBuilder.searchSource() - .fetchSource( - copySet == null ? null : copySet.toArray(new String[0]), null) - .query(boolQueryBuilder) - .size(2))) - .actionGet(); - - int hits = result.getHits().getHits().length; - - if (hits == 0) { - // null indicates no hits for the lookup found - return null; - } - - if (hits != 1) { - throw new RuntimeException("too many hits for " + indexName + " (" + hits + ")"); - } - - return result.getHits().getHits()[0].getSourceAsMap(); + return singleRowQuery.executeQuery(indexName, matchMap, copySet); }; } } + + static class SingleRowQuery { + + private final NodeClient nodeClient; + + public SingleRowQuery(OpenSearchClient openSearchClient) { + Objects.requireNonNull(openSearchClient, "Opensearch client is required to perform lookup"); + this.nodeClient = + Objects.requireNonNull( + openSearchClient.getNodeClient(), + "Can not perform lookup because openSearchClient was null. This is likely a bug."); + } + + public @Nullable Map executeQuery( + String indexName, Map matchMap, Set copySet) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + for (Map.Entry f : matchMap.entrySet()) { + BoolQueryBuilder orQueryBuilder = new BoolQueryBuilder(); + + // Todo: Search with term and a match query? Or terms only? + orQueryBuilder.should(new TermQueryBuilder(f.getKey(), f.getValue().toString())); + orQueryBuilder.should(new MatchQueryBuilder(f.getKey(), f.getValue().toString())); + orQueryBuilder.minimumShouldMatch(1); + + // filter is the same as "must" but ignores scoring + boolQueryBuilder.filter(orQueryBuilder); + } + + SearchResponse result = + nodeClient + .search( + new SearchRequest(indexName) + .source( + SearchSourceBuilder.searchSource() + .fetchSource( + copySet == null ? null : copySet.toArray(new String[0]), null) + .query(boolQueryBuilder) + .size(2))) + .actionGet(); + + SearchHit[] searchHits = result.getHits().getHits(); + int hits = searchHits.length; + if (hits == 0) { + // null indicates no hits for the lookup found + return null; + } + + if (hits != 1) { + throw new RuntimeException("too many hits for " + indexName + " (" + hits + ")"); + } + + return searchHits[0].getSourceAsMap(); + } + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index b2dc042110..7d5ec81127 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; @@ -81,6 +82,8 @@ class OpenSearchExecutionProtectorTest { @Mock private OpenSearchSettings settings; + @Mock private BiFunction, Map> lookupFunction; + private OpenSearchExecutionProtector executionProtector; @BeforeEach @@ -120,59 +123,71 @@ void test_protect_indexScan() { settings.getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE)); assertEquals( PhysicalPlanDSL.project( - PhysicalPlanDSL.limit( - PhysicalPlanDSL.dedupe( - PhysicalPlanDSL.rareTopN( - resourceMonitor( - PhysicalPlanDSL.sort( - PhysicalPlanDSL.eval( - PhysicalPlanDSL.remove( - PhysicalPlanDSL.rename( - PhysicalPlanDSL.agg( - filter( - resourceMonitor( - new OpenSearchIndexScan( - client, maxResultWindow, request)), - filterExpr), - aggregators, - groupByExprs), - mappings), - exclude), - newEvalField), - sortField)), - CommandType.TOP, - topExprs, - topField), - dedupeField), - limit, - offset), - include), - executionProtector.protect( - PhysicalPlanDSL.project( + PhysicalPlanDSL.lookup( PhysicalPlanDSL.limit( PhysicalPlanDSL.dedupe( PhysicalPlanDSL.rareTopN( - PhysicalPlanDSL.sort( - PhysicalPlanDSL.eval( - PhysicalPlanDSL.remove( - PhysicalPlanDSL.rename( - PhysicalPlanDSL.agg( - filter( - new OpenSearchIndexScan( - client, maxResultWindow, request), - filterExpr), - aggregators, - groupByExprs), - mappings), - exclude), - newEvalField), - sortField), + resourceMonitor( + PhysicalPlanDSL.sort( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.remove( + PhysicalPlanDSL.rename( + PhysicalPlanDSL.agg( + filter( + resourceMonitor( + new OpenSearchIndexScan( + client, maxResultWindow, request)), + filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField)), CommandType.TOP, topExprs, topField), dedupeField), limit, offset), + "lookup_index_name", + Map.of(), + false, + Map.of(), + lookupFunction), + include), + executionProtector.protect( + PhysicalPlanDSL.project( + PhysicalPlanDSL.lookup( + PhysicalPlanDSL.limit( + PhysicalPlanDSL.dedupe( + PhysicalPlanDSL.rareTopN( + PhysicalPlanDSL.sort( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.remove( + PhysicalPlanDSL.rename( + PhysicalPlanDSL.agg( + filter( + new OpenSearchIndexScan( + client, maxResultWindow, request), + filterExpr), + aggregators, + groupByExprs), + mappings), + exclude), + newEvalField), + sortField), + CommandType.TOP, + topExprs, + topField), + dedupeField), + limit, + offset), + "lookup_index_name", + Map.of(), + false, + Map.of(), + lookupFunction), include))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 3ddb07d86a..d7ce2dc7b5 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -28,8 +28,11 @@ import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; @@ -37,7 +40,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.node.NodeClient; import org.opensearch.common.unit.TimeValue; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; @@ -53,6 +58,8 @@ import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex.OpenSearchDefaultImplementor; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex.SingleRowQuery; import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; @@ -65,6 +72,9 @@ class OpenSearchIndexTest { public static final TimeValue SCROLL_TIMEOUT = new TimeValue(1); public static final OpenSearchRequest.IndexName INDEX_NAME = new OpenSearchRequest.IndexName("test"); + public static final String LOOKUP_INDEX_NAME = "lookup-index-name"; + public static final String LOOKUP_TABLE_FIELD = "lookup_table_field"; + public static final String QUERY_FIELD = "query_field"; @Mock private OpenSearchClient client; @@ -74,6 +84,8 @@ class OpenSearchIndexTest { @Mock private IndexMapping mapping; + @Mock private NodeClient nodeClient; + private OpenSearchIndex index; @BeforeEach @@ -222,6 +234,7 @@ void implementRelationOperatorWithOptimization() { void implementOtherLogicalOperators() { when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + when(client.getNodeClient()).thenReturn(nodeClient); NamedExpression include = named("age", ref("age", INTEGER)); ReferenceExpression exclude = ref("name", STRING); ReferenceExpression dedupeField = ref("name", STRING); @@ -234,34 +247,80 @@ void implementOtherLogicalOperators() { LogicalPlan plan = project( - LogicalPlanDSL.dedupe( - sort( - eval( - remove(rename(index.createScanBuilder(), mappings), exclude), newEvalField), - sortField), - dedupeField), + LogicalPlanDSL.lookup( + LogicalPlanDSL.dedupe( + sort( + eval( + remove(rename(index.createScanBuilder(), mappings), exclude), + newEvalField), + sortField), + dedupeField), + LOOKUP_INDEX_NAME, + Map.of( + new ReferenceExpression(LOOKUP_TABLE_FIELD, STRING), + new ReferenceExpression(QUERY_FIELD, STRING)), + true, + Collections.emptyMap()), include); Integer maxResultWindow = index.getMaxResultWindow(); final var requestBuilder = new OpenSearchRequestBuilder(QUERY_SIZE_LIMIT, exprValueFactory); + + BiFunction, Map> anyBifunction = + new BiFunction<>() { + @Override + public Map apply(String s, Map stringObjectMap) { + return Map.of(); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof BiFunction; + } + }; assertEquals( PhysicalPlanDSL.project( - PhysicalPlanDSL.dedupe( - PhysicalPlanDSL.sort( - PhysicalPlanDSL.eval( - PhysicalPlanDSL.remove( - PhysicalPlanDSL.rename( - new OpenSearchIndexScan( - client, - QUERY_SIZE_LIMIT, - requestBuilder.build( - INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), - mappings), - exclude), - newEvalField), - sortField), - dedupeField), + PhysicalPlanDSL.lookup( + PhysicalPlanDSL.dedupe( + PhysicalPlanDSL.sort( + PhysicalPlanDSL.eval( + PhysicalPlanDSL.remove( + PhysicalPlanDSL.rename( + new OpenSearchIndexScan( + client, + QUERY_SIZE_LIMIT, + requestBuilder.build( + INDEX_NAME, maxResultWindow, SCROLL_TIMEOUT)), + mappings), + exclude), + newEvalField), + sortField), + dedupeField), + LOOKUP_INDEX_NAME, + Map.of( + new ReferenceExpression(LOOKUP_TABLE_FIELD, STRING), + new ReferenceExpression(QUERY_FIELD, STRING)), + true, + Collections.emptyMap(), + anyBifunction), include), index.implement(plan)); } + + @Test + public void lookupShouldExecuteQuery() { + OpenSearchDefaultImplementor implementor = new OpenSearchDefaultImplementor(client); + Map matchMap = Map.of("column name", "required value"); + Set copySet = Set.of("column_1", "column_2"); + Map parameters = Map.of("_match", matchMap, "_copy", copySet); + SingleRowQuery singleRowQuery = Mockito.mock(SingleRowQuery.class); + Map resultRow = Map.of("column_1", 1, "column_2", 2); + when(singleRowQuery.executeQuery("lookup_index_name", matchMap, copySet)).thenReturn(resultRow); + BiFunction, Map> lookup = + implementor.lookup(singleRowQuery); + + Map givenResult = lookup.apply("lookup_index_name", parameters); + + assertEquals(resultRow, givenResult); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/SingleRowQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/SingleRowQueryTest.java new file mode 100644 index 0000000000..dbbb09f705 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/SingleRowQueryTest.java @@ -0,0 +1,246 @@ +package org.opensearch.sql.opensearch.storage; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.emptyArray; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.storage.OpenSearchIndex.SingleRowQuery; + +@ExtendWith(MockitoExtension.class) +class SingleRowQueryTest { + + @Mock private OpenSearchClient openSearchClient; + @Mock private NodeClient nodeClient; + @Mock private ActionFuture searchFuture; + @Mock private SearchResponse response; + @Mock private SearchHits searchHits; + @Mock private SearchHit hit; + @Captor private ArgumentCaptor searchRequestCaptor; + + private SingleRowQuery singleRowQuery; + + @BeforeEach + public void beforeEach() { + when(openSearchClient.getNodeClient()).thenReturn(nodeClient); + singleRowQuery = new SingleRowQuery(openSearchClient); + } + + @AfterEach + public void shouldUseExactlyOneSearch() { + verify(nodeClient, times(1)).search(any()); + } + + @Test + void shouldReturnNullWhenRowDoesNotExist() { + Map predicates = Map.of("column_name", "value 1"); + Set projection = Set.of("returned column name"); + mockSearch(new SearchHit[0]); + + Map row = singleRowQuery.executeQuery("index_name", predicates, projection); + + assertThat(row, nullValue()); + } + + @Test + void shouldThrowExceptionWhenMoreThanOneRowIsFound() { + Map predicates = Map.of("column_name", "value 1"); + Set projection = Set.of("returned column name"); + mockSearch(new SearchHit[2]); + + RuntimeException ex = + assertThrows( + RuntimeException.class, + () -> singleRowQuery.executeQuery("index_name", predicates, projection)); + + assertThat(ex.getMessage(), Matchers.containsString("too many hits")); + } + + @ParameterizedTest + @ValueSource(strings = {"table_1", "other_table", "yet_another_table"}) + void shouldQueryCorrectTable(String tableName) { + Map predicates = Map.of("column_name", "value row criteria"); + Set projection = Set.of("returned column name"); + + mockSearch(new SearchHit[0]); + + Map row = singleRowQuery.executeQuery(tableName, predicates, projection); + + SearchRequest searchRequest = searchRequestCaptor.getValue(); + assertThat(searchRequest.indices(), equalTo(new String[] {tableName})); + assertThat(row, nullValue()); + } + + @ParameterizedTest + @CsvSource({ + "field,value row criteria,fetched_column_name_one", + "column,another value row criteria,another_fetched_column_name", + "attribute,yet another value row criteria,third_fetched_column_name", + "regular_name,one_2_three,this_is_the_fetched_column_name", + "extra_ordinary_column_name,I am an expected value,my_name_is_the_fetched_column", + "nice_column_name,last value row criteria,my_nice_fetched_column", + }) + public void shouldBuildOpenSearchQueryForOnePredicate( + String columnName, String valuePredicate, String projection) { + Map predicates = Map.of(columnName, valuePredicate); + mockSearch(new SearchHit[0]); + + Map row = + singleRowQuery.executeQuery("index_name", predicates, Set.of(projection)); + + SearchRequest searchRequest = searchRequestCaptor.getValue(); + BoolQueryBuilder filterQueryForSinglePredicate = new BoolQueryBuilder(); + filterQueryForSinglePredicate.should(new TermQueryBuilder(columnName, valuePredicate)); + filterQueryForSinglePredicate.should(new MatchQueryBuilder(columnName, valuePredicate)); + filterQueryForSinglePredicate.minimumShouldMatch(1); + BoolQueryBuilder expectedQuery = new BoolQueryBuilder(); + expectedQuery.filter(filterQueryForSinglePredicate); + + assertThat(searchRequest.source().query(), equalTo(expectedQuery)); + assertThat(searchRequest.source().size(), equalTo(2)); + assertThat(searchRequest.source().fetchSource().includes(), equalTo(new String[] {projection})); + assertThat(searchRequest.source().fetchSource().excludes(), emptyArray()); + assertThat(row, nullValue()); + } + + @ParameterizedTest + @CsvSource({ + "columnName1,columnName2,columnName3", + "columnName4,columnName5,columnName6", + "columnName,columnName8,columnName9", + "extraOrdinaryOne,eXtraOrdinaryTwo,extraOrdinaryThree" + }) + void shouldFetchVariousColumns(String columnOne, String columnTwo, String columnThree) { + Map predicates = Map.of("find_only_row", "with_value_abc"); + mockSearch(new SearchHit[0]); + + Map row = + singleRowQuery.executeQuery( + "index_name", predicates, Set.of(columnOne, columnTwo, columnThree)); + + assertThat(row, nullValue()); + } + + @ParameterizedTest + @MethodSource("variousPredicates") + void shouldUseComplexPredicate(Map predicates) { + Set projection = Set.of("returned column name"); + mockSearch(new SearchHit[0]); + + Map row = singleRowQuery.executeQuery("index_name", predicates, projection); + + SearchRequest searchRequest = searchRequestCaptor.getValue(); + BoolQueryBuilder expectedQuery = new BoolQueryBuilder(); + predicates.entrySet().stream() + .map( + entry -> { + String columnName = entry.getKey(); + Object value = entry.getValue(); + BoolQueryBuilder filterQueryForSinglePredicate = new BoolQueryBuilder(); + filterQueryForSinglePredicate.should(new TermQueryBuilder(columnName, value)); + filterQueryForSinglePredicate.should(new MatchQueryBuilder(columnName, value)); + filterQueryForSinglePredicate.minimumShouldMatch(1); + return filterQueryForSinglePredicate; + }) + .forEach(expectedQuery::filter); + + assertThat(searchRequest.source().query(), equalTo(expectedQuery)); + assertThat(row, nullValue()); + } + + static Stream variousPredicates() { + return Stream.of( + Arguments.of(Map.of("column_name_1", "value row criteria_12")), + Arguments.of( + Map.of( + "column_name_2", "value row criteria_23", "another_column_5", "another value_8")), + Arguments.of( + Map.of( + "column_name_3", + "value row criteria_34", + "another_column_6", + "another value_8", + "yet_another_column_11", + "yet another value_13")), + Arguments.of( + Map.of( + "column_name_4", + "value row criteria_45", + "another_column_7", + "another value_10", + "yet_another_column_12", + "yet another value_14", + "extra_column_15", + "extra value_16"))); + } + + @Test + public void shouldReturnRow() { + Map predicates = Map.of("column_name", "value row criteria"); + Set projection = Set.of("returned column name"); + Map searchResult = + Map.of("column_name", "value row criteria", "returned column name", "value 2"); + mockSearch(searchResult); + + Map row = singleRowQuery.executeQuery("index_name", predicates, projection); + + assertThat(row, equalTo(searchResult)); + } + + @Test + public void shouldTreatProjectionAsOptionalParameter() { + Map predicates = Map.of("column_name", "value row criteria"); + Set projection = null; + Map searchResult = + Map.of("column_name", "value row criteria", "returned column name", "value 2"); + mockSearch(searchResult); + + Map row = singleRowQuery.executeQuery("index_name", predicates, projection); + + assertThat(row, equalTo(searchResult)); + } + + private void mockSearch(Map searchResult) { + when(hit.getSourceAsMap()).thenReturn(searchResult); + mockSearch(new SearchHit[] {hit}); + } + + private void mockSearch(SearchHit[] searchResult) { + when(nodeClient.search(searchRequestCaptor.capture())).thenReturn(searchFuture); + when(searchFuture.actionGet()).thenReturn(response); + when(response.getHits()).thenReturn(searchHits); + when(searchHits.getHits()).thenReturn(searchResult); + } +}