fix: SQL COUNT with GROUP BY to prevent incorrect row returns (#33380)
* fix: SQL COUNT with GROUP BY to prevent incorrect row returns

* test: Add test cases for empty result with GROUP BY and ORDER BY

* fix: update db types and scenario type for e2e test case

* fix: update column names for e2e test

* fix: fix unit tests for empty result set

* test: add e2e tests for issue #4680

* fix: fix e2e tests for issue #4680

* update e2e tests for issue #4680

* fix: fix failing checks

* fix: update conditions for group by and aggregate functions
Malaydewangan09 authored Oct 31, 2024
1 parent 68fe634 commit f7a126a
Showing 3 changed files with 71 additions and 8 deletions.
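
For context, the row-count semantics this commit enforces can be sketched as follows. This is an illustration only, not code from the change: the t_order table and the WHERE 1 = 2 predicate mirror the e2e cases added below, and dataSource is assumed to be a ShardingSphere data source routing to sharded t_order tables.

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import javax.sql.DataSource;

// Illustration only: expected results when a query matches no data on any shard.
final class EmptyResultCountExample {

    static void verify(final DataSource dataSource) throws SQLException {
        try (Connection connection = dataSource.getConnection(); Statement statement = connection.createStatement()) {
            // Plain COUNT without GROUP BY: exactly one merged row whose value is 0.
            try (ResultSet resultSet = statement.executeQuery("SELECT COUNT(*) FROM t_order WHERE 1 = 2")) {
                if (!resultSet.next() || 0L != resultSet.getLong(1)) {
                    throw new IllegalStateException("Expected a single row with COUNT(*) = 0");
                }
            }
            // COUNT with GROUP BY: zero merged rows; before this fix a spurious row could be returned.
            try (ResultSet resultSet = statement.executeQuery("SELECT COUNT(*) FROM t_order WHERE 1 = 2 GROUP BY order_id")) {
                if (resultSet.next()) {
                    throw new IllegalStateException("Expected an empty result for COUNT with GROUP BY");
                }
            }
        }
    }
}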
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.merge.dql.groupby;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationDistinctProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
@@ -140,10 +141,18 @@ private boolean getValueCaseSensitiveFromTables(final QueryResult queryResult,
 
     private List<MemoryQueryResultRow> getMemoryResultSetRows(final SelectStatementContext selectStatementContext,
                                                                final Map<GroupByValue, MemoryQueryResultRow> dataMap, final List<Boolean> valueCaseSensitive) {
+        Object[] data = generateReturnData(selectStatementContext);
+
         if (dataMap.isEmpty()) {
-            Object[] data = generateReturnData(selectStatementContext);
-            return selectStatementContext.getProjectionsContext().getAggregationProjections().isEmpty() ? Collections.emptyList() : Collections.singletonList(new MemoryQueryResultRow(data));
+            boolean hasGroupBy = !selectStatementContext.getGroupByContext().getItems().isEmpty();
+            boolean hasAggregations = !selectStatementContext.getProjectionsContext().getAggregationProjections().isEmpty();
+
+            if (hasGroupBy || !hasAggregations) {
+                return Collections.emptyList();
+            }
+            return Collections.singletonList(new MemoryQueryResultRow(data));
         }
+
         List<MemoryQueryResultRow> result = new ArrayList<>(dataMap.values());
         result.sort(new GroupByRowComparator(selectStatementContext, valueCaseSensitive));
         return result;
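
Restating the branch above as a sketch (hypothetical helper name, not code from the patch): with an empty dataMap the merger now emits a synthetic row only for an aggregation query without GROUP BY; a grouped query, or a query with no aggregation projections at all, merges to an empty result.

// Equivalent decision as a standalone predicate; the comments show representative queries.
static boolean emitsSyntheticRowOnEmptyShards(final boolean hasGroupBy, final boolean hasAggregations) {
    // SELECT COUNT(*) FROM t_order WHERE 1 = 2                    -> true  (one row: COUNT(*) = 0)
    // SELECT COUNT(*) FROM t_order WHERE 1 = 2 GROUP BY user_id   -> false (zero rows)
    // SELECT order_id FROM t_order WHERE 1 = 2                    -> false (zero rows)
    return hasAggregations && !hasGroupBy;
}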
@@ -20,6 +20,7 @@
 import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
 import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
 import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
+import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
 import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
 import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
 import org.apache.shardingsphere.infra.merge.result.MergedResult;
@@ -33,7 +34,6 @@
 import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
 import org.apache.shardingsphere.sharding.merge.dql.ShardingDQLResultMerger;
 import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType;
-import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
 import org.apache.shardingsphere.sql.parser.statement.core.enums.OrderDirection;
 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
 import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
@@ -62,7 +62,6 @@
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
 import static org.mockito.Mockito.mock;
@@ -80,9 +79,6 @@ void assertNextForResultSetsAllEmpty() throws SQLException {
         when(database.getName()).thenReturn("db_schema");
         ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
         MergedResult actual = resultMerger.merge(Arrays.asList(createQueryResult(), createQueryResult(), createQueryResult()), createSelectStatementContext(), database, mock(ConnectionContext.class));
-        assertTrue(actual.next());
-        assertThat(actual.getValue(1, Object.class), is(0));
-        assertNull(actual.getValue(2, Object.class));
         assertFalse(actual.next());
     }

@@ -217,4 +213,57 @@ void assertNextForDistinctShorthandResultSetsEmpty() throws SQLException {
MergedResult actual = merger.merge(Arrays.asList(queryResult, queryResult, queryResult), createSelectStatementContext(database), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountAndGroupBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult1 = createEmptyQueryResultWithCountGroupBy();
QueryResult queryResult2 = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(queryResult1, queryResult2), createSelectStatementContextForCountGroupBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountGroupByDifferentOrderBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Collections.singletonList(queryResult), createSelectStatementContextForCountGroupByDifferentOrderBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

private QueryResult createEmptyQueryResultWithCountGroupBy() throws SQLException {
QueryResult result = mock(QueryResult.class, RETURNS_DEEP_STUBS);
when(result.getMetaData().getColumnCount()).thenReturn(3);
when(result.getMetaData().getColumnLabel(1)).thenReturn("COUNT(*)");
when(result.getMetaData().getColumnLabel(2)).thenReturn("user_id");
when(result.getMetaData().getColumnLabel(3)).thenReturn("order_id");
when(result.next()).thenReturn(false);
return result;
}

private SelectStatementContext createSelectStatementContextForCountGroupBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}

private SelectStatementContext createSelectStatementContextForCountGroupByDifferentOrderBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 3, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}
}
@@ -53,6 +53,11 @@
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT COUNT(1) FROM t_order WHERE 1 = 2 GROUP BY order_id,user_id ORDER BY user_id" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test GROUP BY with ORDER BY different fields returns no rows when WHERE condition matches no data">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT AVG(DISTINCT order_id) AS avg_id FROM t_order WHERE 1 = 2" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test AVG DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
@@ -72,7 +77,7 @@
scenario-comments="Test MIN DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT MAX(DISTINCT order_id) AS max_id FROM t_order WHERE 1 = 2" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test MAX DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
