Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: SQL COUNT with GROUP BY to prevent incorrect row returns #33380

Merged
merged 13 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.merge.dql.groupby;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;

import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationDistinctProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
Expand Down Expand Up @@ -140,10 +141,18 @@ private boolean getValueCaseSensitiveFromTables(final QueryResult queryResult,

private List<MemoryQueryResultRow> getMemoryResultSetRows(final SelectStatementContext selectStatementContext,
final Map<GroupByValue, MemoryQueryResultRow> dataMap, final List<Boolean> valueCaseSensitive) {
Object[] data = generateReturnData(selectStatementContext);

if (dataMap.isEmpty()) {
Object[] data = generateReturnData(selectStatementContext);
return selectStatementContext.getProjectionsContext().getAggregationProjections().isEmpty() ? Collections.emptyList() : Collections.singletonList(new MemoryQueryResultRow(data));
boolean hasGroupBy = !selectStatementContext.getGroupByContext().getItems().isEmpty();
boolean hasAggregations = !selectStatementContext.getProjectionsContext().getAggregationProjections().isEmpty();

if (hasGroupBy || !hasAggregations) {
return Collections.emptyList();
}
return Collections.singletonList(new MemoryQueryResultRow(data));
}

List<MemoryQueryResultRow> result = new ArrayList<>(dataMap.values());
result.sort(new GroupByRowComparator(selectStatementContext, valueCaseSensitive));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
Expand All @@ -33,7 +34,6 @@
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sharding.merge.dql.ShardingDQLResultMerger;
import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
import org.apache.shardingsphere.sql.parser.statement.core.enums.OrderDirection;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
Expand Down Expand Up @@ -62,7 +62,6 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
Expand All @@ -80,9 +79,6 @@ void assertNextForResultSetsAllEmpty() throws SQLException {
when(database.getName()).thenReturn("db_schema");
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(createQueryResult(), createQueryResult(), createQueryResult()), createSelectStatementContext(), database, mock(ConnectionContext.class));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(0));
assertNull(actual.getValue(2, Object.class));
assertFalse(actual.next());
}

Expand Down Expand Up @@ -217,4 +213,57 @@ void assertNextForDistinctShorthandResultSetsEmpty() throws SQLException {
MergedResult actual = merger.merge(Arrays.asList(queryResult, queryResult, queryResult), createSelectStatementContext(database), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountAndGroupBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult1 = createEmptyQueryResultWithCountGroupBy();
QueryResult queryResult2 = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(queryResult1, queryResult2), createSelectStatementContextForCountGroupBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountGroupByDifferentOrderBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Collections.singletonList(queryResult), createSelectStatementContextForCountGroupByDifferentOrderBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

private QueryResult createEmptyQueryResultWithCountGroupBy() throws SQLException {
QueryResult result = mock(QueryResult.class, RETURNS_DEEP_STUBS);
when(result.getMetaData().getColumnCount()).thenReturn(3);
when(result.getMetaData().getColumnLabel(1)).thenReturn("COUNT(*)");
when(result.getMetaData().getColumnLabel(2)).thenReturn("user_id");
when(result.getMetaData().getColumnLabel(3)).thenReturn("order_id");
when(result.next()).thenReturn(false);
return result;
}

private SelectStatementContext createSelectStatementContextForCountGroupBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}

private SelectStatementContext createSelectStatementContextForCountGroupByDifferentOrderBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 3, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT COUNT(1) FROM t_order WHERE 1 = 2 GROUP BY order_id,user_id ORDER BY user_id" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test GROUP BY with ORDER BY different fields returns no rows when WHERE condition matches no data">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT AVG(DISTINCT order_id) AS avg_id FROM t_order WHERE 1 = 2" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test AVG DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
Expand All @@ -72,7 +77,7 @@
scenario-comments="Test MIN DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT MAX(DISTINCT order_id) AS max_id FROM t_order WHERE 1 = 2" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tbl"
scenario-comments="Test MAX DISTINCT returns NULL when no data matches">
<assertion expected-data-source-name="read_dataset" />
Expand Down