diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java index 9a416c9683..e4bbfde34a 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/CsvFormatResponseIT.java @@ -706,6 +706,103 @@ public void selectFunctionAsFieldTest() throws IOException { Assert.assertEquals(1, headers.size()); } + @Test + public void unionTest() throws IOException { + String query = + String.format( + Locale.ROOT, + "SELECT firstname, lastname FROM %s LIMIT 3 " + + "UNION ALL SELECT firstname, lastname FROM %s LIMIT 3", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT); + CSVResult result = executeCsvRequest(query, false, false, false); + List headers = result.getHeaders(); + Assert.assertEquals(2, headers.size()); + Assert.assertTrue(headers.contains("firstname")); + Assert.assertTrue(headers.contains("lastname")); + + List lines = result.getLines(); + Assert.assertEquals(6, lines.size()); + assertEquals(lines.get(0), "Amber,Duke"); + assertEquals(lines.get(1), "Hattie,Bond"); + assertEquals(lines.get(2), "Nanette,Bates"); + assertEquals(lines.get(0), "Amber,Duke"); + assertEquals(lines.get(1), "Hattie,Bond"); + assertEquals(lines.get(2), "Nanette,Bates"); + } + + @Test + public void unionWithAliasLeftTest() throws IOException { + String query = + String.format( + Locale.ROOT, + "SELECT lastname AS firstname FROM %s LIMIT 3 " + + "UNION ALL SELECT firstname FROM %s LIMIT 3", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT); + CSVResult result = executeCsvRequest(query, false, false, false); + List headers = result.getHeaders(); + Assert.assertEquals(1, headers.size()); + Assert.assertTrue(headers.contains("firstname")); + + List lines = result.getLines(); + Assert.assertEquals(6, lines.size()); + assertEquals(lines.get(0), "Duke"); + assertEquals(lines.get(1), "Bond"); + assertEquals(lines.get(2), "Bates"); + assertEquals(lines.get(3), "Amber"); + assertEquals(lines.get(4), "Hattie"); + assertEquals(lines.get(5), "Nanette"); + } + + @Test + public void unionWithAliasRightTest() throws IOException { + String query = + String.format( + Locale.ROOT, + "SELECT firstname FROM %s LIMIT 3 " + + "UNION ALL SELECT lastname AS firstname FROM %s LIMIT 3", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT); + CSVResult result = executeCsvRequest(query, false, false, false); + List headers = result.getHeaders(); + Assert.assertEquals(1, headers.size()); + Assert.assertTrue(headers.contains("firstname")); + + List lines = result.getLines(); + Assert.assertEquals(6, lines.size()); + assertEquals(lines.get(0), "Amber"); + assertEquals(lines.get(1), "Hattie"); + assertEquals(lines.get(2), "Nanette"); + assertEquals(lines.get(3), "Duke"); + assertEquals(lines.get(4), "Bond"); + assertEquals(lines.get(5), "Bates"); + } + + @Test + public void unionWithAliasBothSideTest() throws IOException { + String query = + String.format( + Locale.ROOT, + "SELECT firstname AS name FROM %s LIMIT 3 " + + "UNION ALL SELECT lastname AS name FROM %s LIMIT 3", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT); + CSVResult result = executeCsvRequest(query, false, false, false); + List headers = result.getHeaders(); + Assert.assertEquals(1, headers.size()); + Assert.assertTrue(headers.contains("name")); + + List lines = result.getLines(); + Assert.assertEquals(6, lines.size()); + assertEquals(lines.get(0), "Amber"); + assertEquals(lines.get(1), "Hattie"); + assertEquals(lines.get(2), "Nanette"); + assertEquals(lines.get(3), "Duke"); + assertEquals(lines.get(4), "Bond"); + assertEquals(lines.get(5), "Bates"); + } + private void verifyFieldOrder(final String[] expectedFields) throws IOException { final String fields = String.join(", ", expectedFields); diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/MultiQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/MultiQueryIT.java index 84750f8a27..5a6071ea67 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/MultiQueryIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/MultiQueryIT.java @@ -58,7 +58,31 @@ public void unionAllSameRequestOnlyOneRecordTwice() throws IOException { } @Test - public void unionAllOnlyOneRecordEachWithAlias() throws IOException { + public void unionAllOnlyOneRecordWithAliasLeft() throws IOException { + String query = + String.format( + "SELECT firstname as dog_name FROM %s WHERE firstname = 'Amber' " + + "UNION ALL " + + "SELECT dog_name FROM %s WHERE dog_name = 'rex'", + TestsConstants.TEST_INDEX_ACCOUNT, TestsConstants.TEST_INDEX_DOG); + + JSONObject response = executeQuery(query); + assertThat(getHits(response).length(), equalTo(2)); + + Set names = new HashSet<>(); + JSONArray hits = getHits(response); + for (int i = 0; i < hits.length(); i++) { + JSONObject hit = hits.getJSONObject(i); + JSONObject source = getSource(hit); + + names.add(source.getString("dog_name")); + } + + assertThat(names, hasItems("Amber", "rex")); + } + + @Test + public void unionAllOnlyOneRecordWithAliasRight() throws IOException { String query = String.format( "SELECT firstname FROM %s WHERE firstname = 'Amber' " @@ -81,6 +105,30 @@ public void unionAllOnlyOneRecordEachWithAlias() throws IOException { assertThat(names, hasItems("Amber", "rex")); } + @Test + public void unionAllOnlyOneRecordWithAliasBothSide() throws IOException { + String query = + String.format( + "SELECT firstname AS name FROM %s WHERE firstname = 'Amber' " + + "UNION ALL " + + "SELECT dog_name AS name FROM %s WHERE dog_name = 'rex'", + TestsConstants.TEST_INDEX_ACCOUNT, TestsConstants.TEST_INDEX_DOG); + + JSONObject response = executeQuery(query); + assertThat(getHits(response).length(), equalTo(2)); + + Set names = new HashSet<>(); + JSONArray hits = getHits(response); + for (int i = 0; i < hits.length(); i++) { + JSONObject hit = hits.getJSONObject(i); + JSONObject source = getSource(hit); + + names.add(source.getString("name")); + } + + assertThat(names, hasItems("Amber", "rex")); + } + @Test public void unionAllOnlyOneRecordEachWithComplexAlias() throws IOException { String query = diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java index 07883d92f4..c1be9f50bc 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/PrettyFormatResponseIT.java @@ -578,6 +578,80 @@ public void fieldOrderOther() throws IOException { testFieldOrder(expectedFields, expectedValues); } + @Test + public void unionQuery() throws IOException { + JSONObject response = + executeQuery( + String.format( + Locale.ROOT, + "SELECT firstname, lastname FROM %s " + + "UNION ALL SELECT firstname, lastname FROM %s", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT)); + + List fields = Arrays.asList("firstname", "lastname"); + JSONArray dataRows = getDataRows(response); + assertContainsColumns(getSchema(response), fields); + assertContainsData(dataRows, fields); + } + + @Test + public void unionQueryWithAliasLeft() throws IOException { + JSONObject response = + executeQuery( + String.format( + Locale.ROOT, + "SELECT lastname AS firstname FROM %s UNION ALL SELECT firstname FROM %s", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT)); + List fields = List.of("lastname"); + Map aliases = new HashMap<>(); + aliases.put("lastname", "firstname"); + JSONArray schema = getSchema(response); + JSONArray dataRows = getDataRows(response); + assertContainsColumns(schema, fields); + assertContainsAliases(schema, aliases); + assertContainsData(dataRows, fields); + } + + @Test + public void unionQueryWithAliasRight() throws IOException { + JSONObject response = + executeQuery( + String.format( + Locale.ROOT, + "SELECT firstname FROM %s UNION ALL SELECT lastname AS firstname FROM %s", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT)); + List fields = List.of("firstname"); + JSONArray schema = getSchema(response); + JSONArray dataRows = getDataRows(response); + assertContainsColumns(schema, fields); + // Query schema uses first subquery schema, so alias in second subquery doesn't count in. + assertNoAlias(schema); + assertContainsData(dataRows, fields); + } + + @Test + public void unionQueryWithAliasBothSide() throws IOException { + JSONObject response = + executeQuery( + String.format( + Locale.ROOT, + "SELECT firstname AS name FROM %s UNION ALL SELECT lastname AS name FROM %s", + TestsConstants.TEST_INDEX_ACCOUNT, + TestsConstants.TEST_INDEX_ACCOUNT)); + List fields = List.of("firstname"); + Map aliases = new HashMap<>(); + aliases.put("firstname", "name"); + aliases.put("lastname", "name"); + JSONArray schema = getSchema(response); + JSONArray dataRows = getDataRows(response); + assertContainsColumns(schema, fields); + assertContainsAliases(schema, aliases); + assertContainsData(dataRows, fields); + } + private void testFieldOrder(final String[] expectedFields, final Object[] expectedValues) throws IOException { @@ -644,6 +718,13 @@ private void assertContainsAliases(JSONArray schema, Map aliases } } + private void assertNoAlias(JSONArray schema) { + for (int i = 0; i < schema.length(); i++) { + JSONObject column = schema.getJSONObject(i); + assertFalse(column.has("alias")); + } + } + private void assertContainsData(JSONArray dataRows, Collection fields) { assertThat(dataRows.length(), greaterThan(0)); JSONArray row = dataRows.getJSONArray(0); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java index 2faa8cc6e5..d0a4dbb919 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Select.java @@ -47,6 +47,7 @@ public class Select extends Query { private List subQueries; private boolean selectAll = false; private JoinType nestedJoinType = JoinType.COMMA; + private boolean partOfUnion = false; public boolean isQuery = false; public boolean isAggregate = false; @@ -187,4 +188,13 @@ public boolean isOrderdSelect() { public boolean isSelectAll() { return selectAll; } + + public void setPartOfUnion(boolean partOfUnion) { + this.partOfUnion = partOfUnion; + } + + /** Return true is this SELECT is used in UNION */ + public boolean isPartOfUnion() { + return partOfUnion; + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/domain/Union.java b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Union.java new file mode 100644 index 0000000000..7309b93ba9 --- /dev/null +++ b/legacy/src/main/java/org/opensearch/sql/legacy/domain/Union.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.legacy.domain; + +import lombok.Getter; + +@Getter +public class Union extends Query { + private final Select firstTable; + private final Select secondTable; + + public Union(Select firstTable, Select secondTable) { + this.firstTable = firstTable; + this.secondTable = secondTable; + this.firstTable.setPartOfUnion(true); + this.secondTable.setPartOfUnion(true); + } +} diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java index b29369f713..b90e7f031f 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/Schema.java @@ -15,7 +15,6 @@ import org.opensearch.sql.legacy.domain.IndexStatement; public class Schema implements Iterable { - private String indexName; private List columns; @@ -44,7 +43,7 @@ public String getIndexName() { } public List getHeaders() { - return columns.stream().map(column -> column.getName()).collect(Collectors.toList()); + return columns.stream().map(Column::getIdentifier).collect(Collectors.toList()); } public List getColumns() { @@ -166,5 +165,21 @@ public String getIdentifier() { public Type getEnumType() { return type; } + + @Override + public String toString() { + return "Column{" + + "name='" + + name + + '\'' + + ", alias='" + + alias + + '\'' + + ", type=" + + type + + ", identifiedByAlias=" + + identifiedByAlias + + '}'; + } } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java index c60691cb7c..d5cab119ae 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/SelectResultSet.java @@ -49,6 +49,7 @@ import org.opensearch.sql.legacy.domain.Query; import org.opensearch.sql.legacy.domain.Select; import org.opensearch.sql.legacy.domain.TableOnJoinSelect; +import org.opensearch.sql.legacy.domain.Union; import org.opensearch.sql.legacy.esdomain.mapping.FieldMapping; import org.opensearch.sql.legacy.exception.SqlFeatureNotImplementedException; import org.opensearch.sql.legacy.executor.Format; @@ -101,6 +102,10 @@ public SelectResultSet( JoinSelect joinQuery = (JoinSelect) query; loadFromEsState(joinQuery.getFirstTable()); loadFromEsState(joinQuery.getSecondTable()); + } else if (isUnionQuery()) { + Union unionQuery = (Union) query; + loadFromEsState(unionQuery.getFirstTable()); + loadFromEsState(unionQuery.getSecondTable()); } else { loadFromEsState(query); } @@ -180,8 +185,12 @@ private void loadFromEsState(Query query) { Map typeMappings = mappings.get(indexName); this.indexName = this.indexName == null ? indexName : (this.indexName + "|" + indexName); - this.columns.addAll( - renameColumnWithTableAlias(query, populateColumns(query, fieldNames, typeMappings))); + if (isPartOfUnion(query) && !this.columns.isEmpty()) { + // Skip the second SELECT schema for Union query + } else { + this.columns.addAll( + renameColumnWithTableAlias(query, populateColumns(query, fieldNames, typeMappings))); + } } /** Rename column name with table alias as prefix for join query */ @@ -449,12 +458,15 @@ private List populateColumns( if (Schema.hasType(type)) { // If the current field is a group key, we should use alias as the identifier - boolean isGroupKey = false; + boolean identifiedByAlias = false; Select select = (Select) query; if (null != select.getGroupBys() && !select.getGroupBys().isEmpty() && select.getGroupBys().get(0).contains(fieldMap.get(fieldName))) { - isGroupKey = true; + identifiedByAlias = true; + } + if (select.isPartOfUnion()) { + identifiedByAlias = true; } columns.add( @@ -462,7 +474,7 @@ private List populateColumns( fieldName, fetchAlias(fieldName, fieldMap), Schema.Type.valueOf(type), - isGroupKey)); + identifiedByAlias)); } else if (!isSelectAll()) { throw new IllegalArgumentException( String.format("%s fieldName types are currently not supported.", type)); @@ -882,4 +894,12 @@ private Map addMap(String field, Object term) { private boolean isJoinQuery() { return query instanceof JoinSelect; } + + private boolean isUnionQuery() { + return query instanceof Union; + } + + private boolean isPartOfUnion(Query query) { + return query instanceof Select && ((Select) query).isPartOfUnion(); + } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/MultiQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/MultiQueryAction.java index a9eb6113f7..af1ef724ee 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/MultiQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/query/multi/MultiQueryAction.java @@ -12,6 +12,7 @@ import org.opensearch.client.Client; import org.opensearch.sql.legacy.domain.Field; import org.opensearch.sql.legacy.domain.Select; +import org.opensearch.sql.legacy.domain.Union; import org.opensearch.sql.legacy.exception.SqlParseException; import org.opensearch.sql.legacy.query.DefaultQueryAction; import org.opensearch.sql.legacy.query.QueryAction; @@ -22,7 +23,7 @@ public class MultiQueryAction extends QueryAction { private MultiQuerySelect multiQuerySelect; public MultiQueryAction(Client client, MultiQuerySelect multiSelect) { - super(client, null); + super(client, new Union(multiSelect.getFirstSelect(), multiSelect.getSecondSelect())); this.multiQuerySelect = multiSelect; } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/SemanticAnalyzerMultiQueryTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/SemanticAnalyzerMultiQueryTest.java index 319f6c5cfa..ed842c029c 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/SemanticAnalyzerMultiQueryTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/antlr/semantic/SemanticAnalyzerMultiQueryTest.java @@ -86,4 +86,13 @@ public void unionSelectFieldWithExtraStarOfTwoQueriesShouldFail() { expectValidationFailWithErrorMessages( "SELECT age FROM semantics UNION SELECT *, age FROM semantics"); } + + @Test + public void unionSelectWithAliasOfTwoQueriesShouldPass() { + validate( + "SELECT balance AS numeric FROM semantics UNION SELECT balance AS numeric FROM semantics"); + validate("SELECT balance AS numeric FROM semantics UNION SELECT age AS numeric FROM semantics"); + validate("SELECT balance AS age FROM semantics UNION SELECT age FROM semantics"); + validate("SELECT balance FROM semantics UNION SELECT age AS balance FROM semantics"); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java index 7f495935ca..7b96fe85e4 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerExplainTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.legacy.unittest.planner; +import org.junit.Assert; import org.junit.Test; import org.opensearch.sql.legacy.query.planner.core.QueryPlanner; @@ -40,4 +41,11 @@ public void explainInJsonWithDuplicateColumnsPushedDown() { + " WHERE d.region = 'US' AND e.age > 30"); planner.explain(); } + + @Test + public void explainInUnion() { + String explain = + explainUnion("SELECT lastname as name FROM employee UNION SELECT name FROM department"); + Assert.assertTrue(explain.contains("performing UNION on")); + } } diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java index 521b225893..ef5cac3712 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/planner/QueryPlannerTest.java @@ -15,6 +15,7 @@ import com.alibaba.druid.sql.ast.SQLExpr; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; +import com.alibaba.druid.sql.ast.statement.SQLUnionQuery; import com.alibaba.druid.sql.parser.ParserException; import com.alibaba.druid.sql.parser.SQLExprParser; import com.alibaba.druid.sql.parser.Token; @@ -52,6 +53,8 @@ import org.opensearch.sql.legacy.query.SqlElasticRequestBuilder; import org.opensearch.sql.legacy.query.join.BackOffRetryStrategy; import org.opensearch.sql.legacy.query.join.OpenSearchJoinQueryActionFactory; +import org.opensearch.sql.legacy.query.multi.MultiQuerySelect; +import org.opensearch.sql.legacy.query.multi.OpenSearchMultiQueryActionFactory; import org.opensearch.sql.legacy.query.planner.HashJoinQueryPlanRequestBuilder; import org.opensearch.sql.legacy.query.planner.core.QueryPlanner; import org.opensearch.sql.legacy.request.SqlRequest; @@ -195,6 +198,22 @@ protected SqlElasticRequestBuilder createRequestBuilder(String sql) { } } + protected String explainUnion(String sql) { + try { + SQLQueryExpr sqlExpr = (SQLQueryExpr) toSqlExpr(sql); + SQLUnionQuery select = (SQLUnionQuery) sqlExpr.getSubQuery().getQuery(); + MultiQuerySelect multiSelect = + new SqlParser().parseMultiSelect(select); // Ignore handleSubquery() + QueryAction queryAction = + OpenSearchMultiQueryActionFactory.createMultiQueryAction(client, multiSelect); + queryAction.setSqlRequest(new SqlRequest(sql, null)); + SqlElasticRequestBuilder request = queryAction.explain(); + return request.explain(); + } catch (SqlParseException e) { + throw new IllegalStateException("Invalid query: " + sql, e); + } + } + private SQLExpr toSqlExpr(String sql) { SQLExprParser parser = new ElasticSqlExprParser(sql); SQLExpr expr = parser.expr();