Skip to content

Commit

Permalink
Refactor BroadcastRule (#33475)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Oct 30, 2024
1 parent d0aabdc commit 68fe634
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,58 +64,58 @@
public final class BroadcastSQLRouter implements EntranceSQLRouter<BroadcastRule>, DecorateSQLRouter<BroadcastRule> {

@Override
public RouteContext createRouteContext(final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database,
final BroadcastRule rule, final ConfigurationProperties props) {
public RouteContext createRouteContext(final QueryContext queryContext, final RuleMetaData globalRuleMetaData,
final ShardingSphereDatabase database, final BroadcastRule rule, final ConfigurationProperties props) {
RouteContext result = new RouteContext();
BroadcastRouteEngineFactory.newInstance(rule, database, queryContext).route(result, rule);
return result;
}

@Override
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database, final BroadcastRule broadcastRule,
final ConfigurationProperties props) {
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext,
final ShardingSphereDatabase database, final BroadcastRule rule, final ConfigurationProperties props) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof TCLStatement) {
routeToAllDatabase(routeContext, broadcastRule);
routeToAllDatabase(routeContext, rule);
}
if (sqlStatement instanceof DDLStatement) {
decorateRouteContextWhenDDLStatement(routeContext, queryContext, database, broadcastRule);
decorateRouteContextWhenDDLStatement(routeContext, queryContext, database, rule);
}
if (sqlStatement instanceof DALStatement && isResourceGroupStatement(sqlStatement)) {
routeToAllDatabaseInstance(routeContext, database, broadcastRule);
routeToAllDatabaseInstance(routeContext, database, rule);
}
if (sqlStatement instanceof DCLStatement && !isDCLForSingleTable(queryContext.getSqlStatementContext())) {
routeToAllDatabaseInstance(routeContext, database, broadcastRule);
routeToAllDatabaseInstance(routeContext, database, rule);
}
}

private void decorateRouteContextWhenDDLStatement(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database, final BroadcastRule broadcastRule) {
private void decorateRouteContextWhenDDLStatement(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database, final BroadcastRule rule) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
if (sqlStatementContext instanceof CursorAvailable) {
if (sqlStatementContext instanceof CloseStatementContext && ((CloseStatementContext) sqlStatementContext).getSqlStatement().isCloseAll()) {
routeToAllDatabase(routeContext, broadcastRule);
routeToAllDatabase(routeContext, rule);
}
return;
}
if (sqlStatementContext instanceof IndexAvailable && !routeContext.getRouteUnits().isEmpty()) {
putAllBroadcastTables(routeContext, broadcastRule, sqlStatementContext);
putAllBroadcastTables(routeContext, rule, sqlStatementContext);
}
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
boolean functionStatement = sqlStatement instanceof CreateFunctionStatement || sqlStatement instanceof AlterFunctionStatement || sqlStatement instanceof DropFunctionStatement;
boolean procedureStatement = sqlStatement instanceof CreateProcedureStatement || sqlStatement instanceof AlterProcedureStatement || sqlStatement instanceof DropProcedureStatement;
if (functionStatement || procedureStatement) {
routeToAllDatabase(routeContext, broadcastRule);
routeToAllDatabase(routeContext, rule);
return;
}
// TODO BEGIN extract db route logic to common database router, eg: DCL in instance route @duanzhengqiang
if (sqlStatement instanceof CreateTablespaceStatement || sqlStatement instanceof AlterTablespaceStatement || sqlStatement instanceof DropTablespaceStatement) {
routeToAllDatabaseInstance(routeContext, database, broadcastRule);
routeToAllDatabaseInstance(routeContext, database, rule);
}
// TODO END extract db route logic to common database router, eg: DCL in instance route
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? getTableNames((TableAvailable) sqlStatementContext) : Collections.emptyList();
if (broadcastRule.isAllBroadcastTables(tableNames)) {
routeToAllDatabaseInstance(routeContext, database, broadcastRule);
if (rule.isAllBroadcastTables(tableNames)) {
routeToAllDatabaseInstance(routeContext, database, rule);
}
}

Expand All @@ -128,9 +128,9 @@ private Collection<String> getTableNames(final TableAvailable sqlStatementContex
return result;
}

private void putAllBroadcastTables(final RouteContext routeContext, final BroadcastRule broadcastRule, final SQLStatementContext sqlStatementContext) {
private void putAllBroadcastTables(final RouteContext routeContext, final BroadcastRule rule, final SQLStatementContext sqlStatementContext) {
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames() : Collections.emptyList();
for (String each : broadcastRule.getBroadcastRuleTableNames(tableNames)) {
for (String each : rule.filterBroadcastTableNames(tableNames)) {
for (RouteUnit routeUnit : routeContext.getRouteUnits()) {
routeUnit.getTableMappers().add(new RouteMapper(each, each));
}
Expand All @@ -151,18 +151,18 @@ private boolean isDCLForSingleTable(final SQLStatementContext sqlStatementContex
return false;
}

private void routeToAllDatabaseInstance(final RouteContext routeContext, final ShardingSphereDatabase database, final BroadcastRule broadcastRule) {
private void routeToAllDatabaseInstance(final RouteContext routeContext, final ShardingSphereDatabase database, final BroadcastRule rule) {
routeContext.getRouteUnits().clear();
for (String each : broadcastRule.getDataSourceNames()) {
for (String each : rule.getDataSourceNames()) {
if (database.getResourceMetaData().getAllInstanceDataSourceNames().contains(each)) {
routeContext.getRouteUnits().add(new RouteUnit(new RouteMapper(each, each), Collections.emptyList()));
}
}
}

private void routeToAllDatabase(final RouteContext routeContext, final BroadcastRule broadcastRule) {
private void routeToAllDatabase(final RouteContext routeContext, final BroadcastRule rule) {
routeContext.getRouteUnits().clear();
for (String each : broadcastRule.getDataSourceNames()) {
for (String each : rule.getDataSourceNames()) {
routeContext.getRouteUnits().add(new RouteUnit(new RouteMapper(each, each), Collections.emptyList()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,45 +59,45 @@ public final class BroadcastRouteEngineFactory {
/**
* Create new instance of broadcast routing engine.
*
* @param broadcastRule broadcast rule
* @param rule broadcast rule
* @param database database
* @param queryContext query context
* @return broadcast route engine
*/
public static BroadcastRouteEngine newInstance(final BroadcastRule broadcastRule, final ShardingSphereDatabase database, final QueryContext queryContext) {
public static BroadcastRouteEngine newInstance(final BroadcastRule rule, final ShardingSphereDatabase database, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof TCLStatement) {
return new BroadcastDatabaseBroadcastRoutingEngine();
}
if (sqlStatement instanceof DDLStatement) {
return sqlStatementContext instanceof CursorAvailable
? getCursorRouteEngine(broadcastRule, sqlStatementContext, queryContext.getConnectionContext())
: getDDLRoutingEngine(broadcastRule, database, queryContext);
? getCursorRouteEngine(rule, sqlStatementContext, queryContext.getConnectionContext())
: getDDLRoutingEngine(rule, database, queryContext);
}
if (sqlStatement instanceof DALStatement) {
return getDALRoutingEngine(broadcastRule, queryContext);
return getDALRoutingEngine(rule, queryContext);
}
if (sqlStatement instanceof DCLStatement) {
return getDCLRoutingEngine(broadcastRule, queryContext);
return getDCLRoutingEngine(rule, queryContext);
}
return getDQLRoutingEngine(broadcastRule, queryContext);
return getDQLRoutingEngine(rule, queryContext);
}

private static BroadcastRouteEngine getCursorRouteEngine(final BroadcastRule broadcastRule, final SQLStatementContext sqlStatementContext, final ConnectionContext connectionContext) {
private static BroadcastRouteEngine getCursorRouteEngine(final BroadcastRule rule, final SQLStatementContext sqlStatementContext, final ConnectionContext connectionContext) {
if (sqlStatementContext instanceof CloseStatementContext && ((CloseStatementContext) sqlStatementContext).getSqlStatement().isCloseAll()) {
return new BroadcastDatabaseBroadcastRoutingEngine();
}
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable
? ((TableAvailable) sqlStatementContext).getTablesContext().getSimpleTables().stream().map(each -> each.getTableName().getIdentifier().getValue()).collect(Collectors.toSet())
: Collections.emptyList();
return broadcastRule.isAllBroadcastTables(tableNames) ? new BroadcastUnicastRoutingEngine(sqlStatementContext, tableNames, connectionContext) : new BroadcastIgnoreRoutingEngine();
return rule.isAllBroadcastTables(tableNames) ? new BroadcastUnicastRoutingEngine(sqlStatementContext, tableNames, connectionContext) : new BroadcastIgnoreRoutingEngine();
}

private static BroadcastRouteEngine getDDLRoutingEngine(final BroadcastRule broadcastRule, final ShardingSphereDatabase database, final QueryContext queryContext) {
private static BroadcastRouteEngine getDDLRoutingEngine(final BroadcastRule rule, final ShardingSphereDatabase database, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
Collection<String> tableNames = getTableNames(database, sqlStatementContext);
if (broadcastRule.isAllBroadcastTables(tableNames)) {
if (rule.isAllBroadcastTables(tableNames)) {
return new BroadcastTableBroadcastRoutingEngine(tableNames);
}
return new BroadcastIgnoreRoutingEngine();
Expand All @@ -123,26 +123,26 @@ private static Collection<String> getTableNames(final ShardingSphereDatabase dat
return result;
}

private static BroadcastRouteEngine getDALRoutingEngine(final BroadcastRule broadcastRule, final QueryContext queryContext) {
private static BroadcastRouteEngine getDALRoutingEngine(final BroadcastRule rule, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof MySQLUseStatement) {
return new BroadcastIgnoreRoutingEngine();
}
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames() : Collections.emptyList();
Collection<String> broadcastRuleTableNames = broadcastRule.getBroadcastRuleTableNames(tableNames);
if (broadcastRule.isAllBroadcastTables(broadcastRuleTableNames)) {
return new BroadcastTableBroadcastRoutingEngine(broadcastRuleTableNames);
Collection<String> broadcastTableNames = rule.filterBroadcastTableNames(tableNames);
if (rule.isAllBroadcastTables(broadcastTableNames)) {
return new BroadcastTableBroadcastRoutingEngine(broadcastTableNames);
}
return new BroadcastIgnoreRoutingEngine();
}

private static BroadcastRouteEngine getDCLRoutingEngine(final BroadcastRule broadcastRule, final QueryContext queryContext) {
private static BroadcastRouteEngine getDCLRoutingEngine(final BroadcastRule rule, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames() : Collections.emptyList();
Collection<String> broadcastRuleTableNames = broadcastRule.getBroadcastRuleTableNames(tableNames);
if (isDCLForSingleTable(sqlStatementContext) && !broadcastRuleTableNames.isEmpty() || broadcastRule.isAllBroadcastTables(broadcastRuleTableNames)) {
return new BroadcastTableBroadcastRoutingEngine(broadcastRuleTableNames);
Collection<String> broadcastTableNames = rule.filterBroadcastTableNames(tableNames);
if (isDCLForSingleTable(sqlStatementContext) && !broadcastTableNames.isEmpty() || rule.isAllBroadcastTables(broadcastTableNames)) {
return new BroadcastTableBroadcastRoutingEngine(broadcastTableNames);
}
return new BroadcastIgnoreRoutingEngine();
}
Expand All @@ -156,10 +156,10 @@ private static boolean isDCLForSingleTable(final SQLStatementContext sqlStatemen
return false;
}

private static BroadcastRouteEngine getDQLRoutingEngine(final BroadcastRule broadcastRule, final QueryContext queryContext) {
private static BroadcastRouteEngine getDQLRoutingEngine(final BroadcastRule rule, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames() : Collections.emptyList();
if (broadcastRule.isAllBroadcastTables(tableNames)) {
if (rule.isAllBroadcastTables(tableNames)) {
return sqlStatementContext.getSqlStatement() instanceof SelectStatement
? new BroadcastUnicastRoutingEngine(sqlStatementContext, tableNames, queryContext.getConnectionContext())
: new BroadcastDatabaseBroadcastRoutingEngine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,55 +20,54 @@
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.broadcast.route.engine.type.BroadcastRouteEngine;
import org.apache.shardingsphere.broadcast.rule.BroadcastRule;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;

/**
* Broadcast routing engine for table.
*/
@HighFrequencyInvocation
@RequiredArgsConstructor
public final class BroadcastTableBroadcastRoutingEngine implements BroadcastRouteEngine {

private final Collection<String> broadcastRuleTableNames;
private final Collection<String> broadcastTableNames;

@Override
public RouteContext route(final RouteContext routeContext, final BroadcastRule broadcastRule) {
Collection<String> logicTableNames = broadcastRule.getBroadcastRuleTableNames(broadcastRuleTableNames);
if (logicTableNames.isEmpty()) {
routeContext.getRouteUnits().addAll(getRouteContext(broadcastRule).getRouteUnits());
} else {
routeContext.getRouteUnits().addAll(getRouteContext(broadcastRule, logicTableNames).getRouteUnits());
}
public RouteContext route(final RouteContext routeContext, final BroadcastRule rule) {
Collection<String> logicTableNames = rule.filterBroadcastTableNames(broadcastTableNames);
RouteContext toBeAddedRouteContext = logicTableNames.isEmpty() ? getRouteContext(rule) : getRouteContext(rule, logicTableNames);
routeContext.getRouteUnits().addAll(toBeAddedRouteContext.getRouteUnits());
return routeContext;
}

private RouteContext getRouteContext(final BroadcastRule broadcastRule) {
private RouteContext getRouteContext(final BroadcastRule rule) {
RouteContext result = new RouteContext();
for (String each : broadcastRule.getDataSourceNames()) {
for (String each : rule.getDataSourceNames()) {
result.getRouteUnits().add(new RouteUnit(new RouteMapper(each, each), Collections.singletonList(new RouteMapper("", ""))));
}
return result;
}

private RouteContext getRouteContext(final BroadcastRule broadcastRule, final Collection<String> logicTableNames) {
private RouteContext getRouteContext(final BroadcastRule rule, final Collection<String> logicTableNames) {
RouteContext result = new RouteContext();
Collection<RouteMapper> tableRouteMappers = getTableRouteMappers(logicTableNames);
for (String each : broadcastRule.getDataSourceNames()) {
for (String each : rule.getDataSourceNames()) {
RouteMapper dataSourceMapper = new RouteMapper(each, each);
result.getRouteUnits().add(new RouteUnit(dataSourceMapper, tableRouteMappers));
}
return result;
}

private Collection<RouteMapper> getTableRouteMappers(final Collection<String> logicTableNames) {
Collection<RouteMapper> result = new ArrayList<>(logicTableNames.size());
for (String logicTableName : logicTableNames) {
result.add(new RouteMapper(logicTableName, logicTableName));
Collection<RouteMapper> result = new LinkedList<>();
for (String each : logicTableNames) {
result.add(new RouteMapper(each, each));
}
return result;
}
Expand Down
Loading

0 comments on commit 68fe634

Please sign in to comment.