前言
之前研究过如何使用ShardingJdbc进行分库分表,但对其原理没有深入了解。如果只停留在会用的层面,对于改造和排查问题其实都是不够的,所以有必要跟踪源码,了解其运行原理。
Demo
依赖
<!-- 分库分表 -->
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>shardingsphere-jdbc-core-spring-boot-starter</artifactId>
<version>5.1.2</version>
</dependency>
配置文件
spring:
  shardingsphere:
    mode:
      type: Memory # in-memory mode: metadata is kept in the current process
    datasource:
      names: test0,test1 # data source names; two are declared here
      test0: # must match one of the names listed above
        type: com.alibaba.druid.pool.DruidDataSource # connection pool
        url: jdbc:mysql://127.0.0.1:3306/test0 # connection URL
        username: root
        password: root
      test1: # must match one of the names listed above
        type: com.alibaba.druid.pool.DruidDataSource
        url: jdbc:mysql://127.0.0.1:3306/test1
        username: root
        password: root
    rules:
      sharding:
        tables:
          user: # logical table name; it is NOT arbitrary - it must match the table name used in SQL (e.g. INSERT INTO user ...)
            # Actual data nodes, in the form dbName$->{0..n1}.tableName$->{0..n2},
            # where n1 and n2 are (number of databases - 1) and (number of tables - 1).
            # The ${0..n1} form also works but clashes with Spring property placeholders,
            # so the $->{0..n1} form is used instead.
            actual-data-nodes: test$->{0..1}.user$->{0..2}
            database-strategy: # database sharding strategy
              standard: # standard sharding strategy
                sharding-column: age # database sharding column
                sharding-algorithm-name: age-mod # database sharding algorithm name
            table-strategy: # table sharding strategy
              standard: # standard sharding strategy
                sharding-column: id # table sharding column
                sharding-algorithm-name: id-mod # table sharding algorithm name
        sharding-algorithms: # algorithms referenced by the strategies above
          age-mod: # database sharding algorithm name
            type: MOD # modulo algorithm
            props: # every algorithm setting must be placed under props
              sharding-count: 2 # number of shards
          id-mod: # table sharding algorithm name
            type: MOD # modulo algorithm
            props: # every algorithm setting must be placed under props
              sharding-count: 3 # number of shards
    props:
      sql-show: true # print the logical and actual SQL
创表语句
-- auto-generated definition
-- Logical table used by the demo. In the sharding configuration `age` is the database-sharding column
-- and `id` is the table-sharding column.
-- NOTE(review): actual-data-nodes is test$->{0..1}.user$->{0..2}, so the physical tables are expected to be
-- user0..user2 in both test0 and test1 - confirm this DDL is run through ShardingSphere, or create them manually.
create table user
(
id bigint not null
primary key,
name varchar(200) null,
age int null
);
实体类
/**
 * Entity for the demo's logical {@code user} table.
 *
 * <p>{@code id} is the table-sharding column and {@code age} is the database-sharding column in the
 * sharding configuration of this demo.
 */
@Data
@Builder
public class User {
// Primary key. The demo controller never sets it, so it is presumably assigned by MyBatis-Plus
// (IdType.ASSIGN_ID) at insert time - confirm against the MyBatis-Plus documentation.
@TableId(type = IdType.ASSIGN_ID)
private Long id;
private String name;
private Integer age;
}
Mapper
/**
 * Mapper for {@link User}. It declares no methods of its own; everything the demo calls
 * ({@code insert}, {@code selectList}) is inherited from {@code BaseMapper}.
 */
public interface UserMapper extends BaseMapper<User> {
}
控制器
/**
 * Demo endpoints under {@code /user} that write to and read from the sharded {@code user} table
 * through {@link UserMapper}.
 */
@RestController
@RequestMapping("/user")
public class UserController {

    @Autowired
    private UserMapper userMapper;

    /**
     * Inserts one user whose name is {@code "name"} and whose age is a random value in {@code [1, 100]}.
     *
     * @return always {@code true}
     */
    @GetMapping("/insert")
    public boolean insert() {
        int randomAge = new Random().nextInt(100) + 1;
        User newUser = User.builder()
                .name("name")
                .age(randomAge)
                .build();
        userMapper.insert(newUser);
        return true;
    }

    /**
     * Selects users with an empty (condition-less) query wrapper.
     *
     * @return the list returned by the mapper
     */
    @GetMapping("/select")
    public List<User> select() {
        QueryWrapper<User> noConditions = new QueryWrapper<>();
        return userMapper.selectList(noConditions);
    }
}
源码解析
解析
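PreparedStatementHandler#instantiateStatement
,MyBatis在这里通过connection.prepareStatement创建Statement
;使用ShardingSphere时创建出的是ShardingSpherePreparedStatement,其构造方法中会进行SQL解析。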
/**
 * Creates the JDBC prepared statement for the bound SQL.
 *
 * <p>When the mapped statement uses a {@code Jdbc3KeyGenerator}, the statement is prepared so that generated
 * keys can be read back: by key column names when they are configured, otherwise via
 * {@link Statement#RETURN_GENERATED_KEYS}. Otherwise the statement is prepared with the default result-set
 * settings, or with the configured result-set type and read-only concurrency.
 *
 * @param connection connection used to prepare the statement
 * @return the prepared statement
 * @throws SQLException if the driver fails to prepare the statement
 */
protected Statement instantiateStatement(Connection connection) throws SQLException {
    String sql = this.boundSql.getSql();
    if (this.mappedStatement.getKeyGenerator() instanceof Jdbc3KeyGenerator) {
        String[] keyColumnNames = this.mappedStatement.getKeyColumns();
        if (keyColumnNames == null) {
            // Named constant instead of the literal 1 left behind by the decompiler.
            return connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
        }
        return connection.prepareStatement(sql, keyColumnNames);
    }
    if (this.mappedStatement.getResultSetType() == ResultSetType.DEFAULT) {
        return connection.prepareStatement(sql);
    }
    // Named constant instead of the literal 1007 left behind by the decompiler.
    return connection.prepareStatement(sql, this.mappedStatement.getResultSetType().getValue(), java.sql.ResultSet.CONCUR_READ_ONLY);
}
ShardingSpherePreparedStatement
/**
 * Creates a prepared statement for {@code sql} with forward-only, read-only, hold-cursors-over-commit
 * result-set options and without requesting generated keys; all work is delegated to the private constructor.
 *
 * @param connection ShardingSphere connection this statement belongs to
 * @param sql logical SQL text
 * @throws SQLException propagated from the private constructor
 */
public ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql) throws SQLException {
this(connection, sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT, false);
}
/**
 * Validates the SQL text, parses it, and initialises the helpers this statement uses for execution.
 *
 * @param connection ShardingSphere connection this statement belongs to
 * @param sql logical SQL text; must not be null or empty
 * @param resultSetType result-set type, used only when {@code returnGeneratedKeys} is false
 * @param resultSetConcurrency result-set concurrency, used only when {@code returnGeneratedKeys} is false
 * @param resultSetHoldability result-set holdability, used only when {@code returnGeneratedKeys} is false
 * @param returnGeneratedKeys whether the statement option should request generated keys
 * @throws SQLException if {@code sql} is null or empty
 */
private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql,
final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys) throws SQLException {
// Reject null/empty SQL up front with the dedicated error code.
if (Strings.isNullOrEmpty(sql)) {
SQLExceptionErrorCode errorCode = SQLExceptionErrorCode.SQL_STRING_NULL_OR_EMPTY;
throw new SQLException(errorCode.getErrorMessage(), errorCode.getSqlState(), errorCode.getErrorCode());
}
this.connection = connection;
metaDataContexts = connection.getContextManager().getMetaDataContexts();
this.sql = sql;
statements = new ArrayList<>();
parameterSets = new ArrayList<>();
// The SQL parser rule is looked up among the global rules; checkState fails if it is absent.
Optional<SQLParserRule> sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().findSingleRule(SQLParserRule.class);
Preconditions.checkState(sqlParserRule.isPresent());
// The parser engine is chosen by the trunk database type name of the connection's current database.
ShardingSphereSQLParserEngine sqlParserEngine = sqlParserRule.get().getSQLParserEngine(
DatabaseTypeEngine.getTrunkDatabaseTypeName(metaDataContexts.getMetaData().getDatabases().get(connection.getDatabaseName()).getResource().getDatabaseType()));
// The SQL is parsed here, at construction time; the second argument is presumably the use-cache flag.
sqlStatement = sqlParserEngine.parse(sql, true);
sqlStatementContext = SQLStatementContextFactory.newInstance(metaDataContexts.getMetaData().getDatabases(), sqlStatement, connection.getDatabaseName());
parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement);
// Generated-keys mode ignores the three result-set options.
statementOption = returnGeneratedKeys ? new StatementOption(true) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.isHoldTransaction());
batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaDataContexts, jdbcExecutor, connection.getDatabaseName());
kernelProcessor = new KernelProcessor();
statementsCacheable = isStatementsCacheable(metaDataContexts.getMetaData().getDatabases().get(connection.getDatabaseName()).getRuleMetaData().getConfigurations());
// The traffic rule is optional; null when none is configured.
trafficRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().findSingleRule(TrafficRule.class).orElse(null);
statementManager = new StatementManager();
}
SQLStatementParserEngine#parse
,根据useCache参数决定是从缓存中获取解析结果,还是直接执行解析。
/**
 * Returns the parsed statement for {@code sql}.
 *
 * @param sql SQL text to parse
 * @param useCache when {@code true} the statement is obtained from the statement cache,
 *                 otherwise the parser executor is invoked directly
 * @return the parsed SQL statement
 */
public SQLStatement parse(String sql, boolean useCache) {
    if (!useCache) {
        return this.sqlStatementParserExecutor.parse(sql);
    }
    return (SQLStatement) this.sqlStatementCache.get(sql);
}
SQLParserExecutor#parse
,解析sql,生成抽象语法树(ParseASTNode);如果根节点是ErrorNode,则抛出SQLParsingException。
/**
 * Parses {@code sql} with the two-phase parse and returns the resulting AST node.
 *
 * @param sql SQL text to parse
 * @return the parse result
 * @throws SQLParsingException if the root node of the parse result is an {@code ErrorNode}
 */
public ParseASTNode parse(String sql) {
    ParseASTNode ast = this.twoPhaseParse(sql);
    if (!(ast.getRootNode() instanceof ErrorNode)) {
        return ast;
    }
    throw new SQLParsingException("Unsupported SQL of `%s`", new Object[]{sql});
}
路由
PreparedStatementHandler#query
,Mybatis执行查询获取到的PreparedStatement
是ShardingSpherePreparedStatement
。
/**
 * Executes the given statement as a {@link PreparedStatement} and maps its result sets.
 *
 * @param statement statement to execute; it is cast to {@code PreparedStatement}
 * @param resultHandler not used by this method body
 * @return the rows produced by the result-set handler
 * @throws SQLException if execution or result handling fails
 */
public <E> List<E> query(Statement statement, ResultHandler resultHandler) throws SQLException {
    PreparedStatement preparedStatement = (PreparedStatement) statement;
    preparedStatement.execute();
    return this.resultSetHandler.handleResultSets(preparedStatement);
}
ShardingSpherePreparedStatement#execute
,获取逻辑SQL,调试获取到的sql是INSERT INTO user ( id,name,age ) VALUES ( ?,?,? )
。
/**
 * Executes this prepared statement, choosing between the cached-statement, traffic, raw, federated and
 * regular execution paths in that order.
 *
 * @return the value of the chosen path: the delegate's {@code execute()} result; for the raw path, whether
 *         the first execute result is a {@code QueryResult}; for the federated path, whether a result set
 *         was produced
 * @throws SQLException rethrown after {@code handleExceptionInTransaction} has been invoked
 */
@Override
public boolean execute() throws SQLException {
try {
// Fast path: statements are cacheable and already created, so reset the parameters and re-run the first one.
if (statementsCacheable && !statements.isEmpty()) {
resetParameters();
return statements.iterator().next().execute();
}
clearPrevious();
LogicSQL logicSQL = createLogicSQL();
// Traffic path: when the traffic context matches, the traffic executor runs the statement.
trafficContext = getTrafficContext(logicSQL);
if (trafficContext.isMatchTraffic()) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficContext, logicSQL);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).execute());
}
// NOTE(review): presumably delegates to KernelProcessor#generateExecutionContext (route + rewrite) - confirm.
executionContext = createExecutionContext(logicSQL);
// Raw path: handled by the raw executor when a raw execution rule is present.
if (hasRawExecutionRule()) {
// TODO process getStatement
Collection<ExecuteResult> executeResults = executor.getRawExecutor().execute(createRawExecutionGroupContext(), executionContext.getLogicSQL(), new RawSQLExecutorCallback());
return executeResults.iterator().next() instanceof QueryResult;
}
// Federated path: the route context is marked as federated.
if (executionContext.getRouteContext().isFederated()) {
ResultSet resultSet = executeFederationQuery(logicSQL);
return null != resultSet;
}
// Regular path: build the execution groups, cache their statements, then execute against the route units.
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext();
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().execute(executionGroupContext,
executionContext.getLogicSQL(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback());
} catch (SQLException ex) {
handleExceptionInTransaction(connection, metaDataContexts);
throw ex;
} finally {
// Batch state is always cleared, whether execution succeeded or failed.
clearBatch();
}
}
KernelProcessor#generateExecutionContext
,处理sql主要是进行路由,改写
/**
 * Builds the execution context for a logical SQL in three steps: route, rewrite, then assemble.
 * The resulting context is passed to {@code logSQL} before being returned.
 *
 * @param logicSQL logical SQL to process
 * @param database database the SQL targets
 * @param props configuration properties
 * @return the assembled execution context
 */
public ExecutionContext generateExecutionContext(LogicSQL logicSQL, ShardingSphereDatabase database, ConfigurationProperties props) {
    RouteContext routed = this.route(logicSQL, database, props);
    SQLRewriteResult rewritten = this.rewrite(logicSQL, database, props, routed);
    ExecutionContext executionContext = this.createExecutionContext(logicSQL, database, routed, rewritten);
    this.logSQL(logicSQL, props, executionContext);
    return executionContext;
}
SQLRouteEngine#route
,SQLRouteEngine
调用SQLRouteExecutor
进行路由,获取到的是PartialSQLRouteExecutor
。
/**
 * Routes the logical SQL with an {@code AllSQLRouteExecutor} when the statement needs all schemas,
 * otherwise with a {@code PartialSQLRouteExecutor} built from this engine's rules and properties.
 *
 * @param logicSQL logical SQL to route
 * @param database database the SQL targets
 * @return the route context produced by the chosen executor
 */
public RouteContext route(final LogicSQL logicSQL, final ShardingSphereDatabase database) {
    SQLRouteExecutor executor;
    if (isNeedAllSchemas(logicSQL.getSqlStatementContext().getSqlStatement())) {
        executor = new AllSQLRouteExecutor();
    } else {
        executor = new PartialSQLRouteExecutor(rules, props);
    }
    return executor.route(logicSQL, database);
}
PartialSQLRouteExecutor
,根据SQLRouterFactory
获取routers
。
/**
 * Creates the executor, resolving the routers for the given rules through {@code SQLRouterFactory}.
 *
 * @param rules rules for which routers are looked up
 * @param props configuration properties kept for later routing calls
 */
public PartialSQLRouteExecutor(final Collection<ShardingSphereRule> rules, final ConfigurationProperties props) {
    this.routers = SQLRouterFactory.getInstances(rules);
    this.props = props;
}
SQLRouterFactory#getInstances
,通过SPI的方式获取到SQLRouter
。
/**
 * Looks up the {@code SQLRouter} registered for each of the given rules via {@code OrderedSPIRegistry}.
 *
 * @param rules rules to find routers for
 * @return map from rule to its router, as returned by the registry
 */
public static Map<ShardingSphereRule, SQLRouter> getInstances(final Collection<ShardingSphereRule> rules) {
    Map<ShardingSphereRule, SQLRouter> routersByRule = OrderedSPIRegistry.getRegisteredServices(SQLRouter.class, rules);
    return routersByRule;
}
PartialSQLRouteExecutor#route
,根据ShardingSphereRule
和SQLRouter
获取结果。
/**
 * Routes the logical SQL.
 *
 * <p>A data source found by hint short-circuits everything else. Otherwise each router is applied in
 * iteration order: while the context has no route units the router creates a new context, after that
 * routers only decorate it. If no route unit was produced and exactly one data source exists, that
 * data source is used.
 *
 * @param logicSQL logical SQL to route
 * @param database database the SQL targets
 * @return the resulting route context
 */
public RouteContext route(final LogicSQL logicSQL, final ShardingSphereDatabase database) {
    RouteContext result = new RouteContext();
    Optional<String> hintedDataSource = findDataSourceByHint(logicSQL.getSqlStatementContext(), database.getResource().getDataSources());
    if (hintedDataSource.isPresent()) {
        RouteMapper hintedMapper = new RouteMapper(hintedDataSource.get(), hintedDataSource.get());
        result.getRouteUnits().add(new RouteUnit(hintedMapper, Collections.emptyList()));
        return result;
    }
    for (Entry<ShardingSphereRule, SQLRouter> entry : routers.entrySet()) {
        if (!result.getRouteUnits().isEmpty()) {
            entry.getValue().decorateRouteContext(result, logicSQL, database, entry.getKey(), props);
            continue;
        }
        result = entry.getValue().createRouteContext(logicSQL, database, entry.getKey(), props);
    }
    if (result.getRouteUnits().isEmpty() && 1 == database.getResource().getDataSources().size()) {
        String onlyDataSource = database.getResource().getDataSources().keySet().iterator().next();
        RouteMapper onlyMapper = new RouteMapper(onlyDataSource, onlyDataSource);
        result.getRouteUnits().add(new RouteUnit(onlyMapper, Collections.emptyList()));
    }
    return result;
}
ShardingSQLRouter#createRouteContext
,根据对应的分库分表规则计算结果。
/**
 * Creates the route context for a sharding rule.
 *
 * <p>Steps, in order: build the sharding conditions, run the optional validator's pre-validation, merge
 * the conditions for DML statements that need it, route with the engine chosen by
 * {@code ShardingRouteEngineFactory}, then run the validator's post-validation on the result.
 *
 * @param logicSQL logical SQL to route
 * @param database database the SQL targets
 * @param rule sharding rule to apply
 * @param props configuration properties
 * @return the route context produced by the routing engine
 */
public RouteContext createRouteContext(LogicSQL logicSQL, ShardingSphereDatabase database, ShardingRule rule, ConfigurationProperties props) {
    SQLStatement sqlStatement = logicSQL.getSqlStatementContext().getSqlStatement();
    ShardingConditions conditions = this.createShardingConditions(logicSQL, database, rule);
    Optional<ShardingStatementValidator> validator = ShardingStatementValidatorFactory.newInstance(sqlStatement, conditions);
    validator.ifPresent(each -> each.preValidate(rule, logicSQL.getSqlStatementContext(), logicSQL.getParameters(), database));
    if (sqlStatement instanceof DMLStatement) {
        if (conditions.isNeedMerge()) {
            conditions.merge();
        }
    }
    RouteContext result = ShardingRouteEngineFactory.newInstance(rule, database, logicSQL.getSqlStatementContext(), conditions, props).route(rule);
    validator.ifPresent(each -> each.postValidate(rule, logicSQL.getSqlStatementContext(), logicSQL.getParameters(), database, props, result));
    return result;
}
ShardingStandardRoutingEngine#route
,获取到DataNode
的集合。
/**
 * Routes the logic table to its data nodes and turns each node into a route unit.
 *
 * @param shardingRule sharding rule that supplies the table rule for the logic table
 * @return route context holding the original data nodes and one route unit per routed data node
 */
public RouteContext route(ShardingRule shardingRule) {
    RouteContext result = new RouteContext();
    // Must run before originalDataNodes is copied: routing by sharding conditions appends to it.
    Collection<DataNode> dataNodes = this.getDataNodes(shardingRule, shardingRule.getTableRule(this.logicTableName));
    result.getOriginalDataNodes().addAll(this.originalDataNodes);
    for (DataNode each : dataNodes) {
        RouteMapper dataSourceMapper = new RouteMapper(each.getDataSourceName(), each.getDataSourceName());
        RouteMapper tableMapper = new RouteMapper(this.logicTableName, each.getTableName());
        result.getRouteUnits().add(new RouteUnit(dataSourceMapper, Collections.singleton(tableMapper)));
    }
    return result;
}
ShardingStandardRoutingEngine#getDataNodes
,根据shardingCondition的数据进行路由,根据分库策略获取库名,再根据库名和分表策略获取表名。
/**
 * Resolves the data nodes for a table rule, routing by hint, by sharding conditions, or by a mix of both.
 *
 * @param shardingRule sharding rule supplying strategy configurations and algorithms
 * @param tableRule table rule being routed
 * @return the routed data nodes
 */
private Collection<DataNode> getDataNodes(final ShardingRule shardingRule, final TableRule tableRule) {
    ShardingStrategy databaseStrategy = createShardingStrategy(shardingRule.getDatabaseShardingStrategyConfiguration(tableRule),
            shardingRule.getShardingAlgorithms(), shardingRule.getDefaultShardingColumn());
    ShardingStrategy tableStrategy = createShardingStrategy(shardingRule.getTableShardingStrategyConfiguration(tableRule),
            shardingRule.getShardingAlgorithms(), shardingRule.getDefaultShardingColumn());
    if (isRoutingByHint(shardingRule, tableRule)) {
        return routeByHint(tableRule, databaseStrategy, tableStrategy);
    }
    return isRoutingByShardingConditions(shardingRule, tableRule)
            ? routeByShardingConditions(shardingRule, tableRule, databaseStrategy, tableStrategy)
            : routeByMixedConditions(shardingRule, tableRule, databaseStrategy, tableStrategy);
}
/**
 * Routes by sharding conditions; with no conditions at all, routes with empty sharding values for both
 * the database and the table strategy.
 *
 * @param shardingRule sharding rule
 * @param tableRule table rule being routed
 * @param databaseShardingStrategy database sharding strategy
 * @param tableShardingStrategy table sharding strategy
 * @return the routed data nodes
 */
private Collection<DataNode> routeByShardingConditions(final ShardingRule shardingRule, final TableRule tableRule,
                                                       final ShardingStrategy databaseShardingStrategy, final ShardingStrategy tableShardingStrategy) {
    if (shardingConditions.getConditions().isEmpty()) {
        return route0(tableRule, databaseShardingStrategy, Collections.emptyList(), tableShardingStrategy, Collections.emptyList());
    }
    return routeByShardingConditionsWithCondition(shardingRule, tableRule, databaseShardingStrategy, tableShardingStrategy);
}
/**
 * Routes once per sharding condition and concatenates the results. Each per-condition result is also
 * recorded in {@code originalDataNodes}.
 *
 * @param shardingRule sharding rule
 * @param tableRule table rule being routed
 * @param databaseShardingStrategy database sharding strategy
 * @param tableShardingStrategy table sharding strategy
 * @return all data nodes routed across the conditions
 */
private Collection<DataNode> routeByShardingConditionsWithCondition(final ShardingRule shardingRule, final TableRule tableRule,
                                                                    final ShardingStrategy databaseShardingStrategy, final ShardingStrategy tableShardingStrategy) {
    Collection<DataNode> result = new LinkedList<>();
    for (ShardingCondition condition : shardingConditions.getConditions()) {
        List<ShardingConditionValue> databaseValues = getShardingValuesFromShardingConditions(shardingRule, databaseShardingStrategy.getShardingColumns(), condition);
        List<ShardingConditionValue> tableValues = getShardingValuesFromShardingConditions(shardingRule, tableShardingStrategy.getShardingColumns(), condition);
        Collection<DataNode> routedNodes = route0(tableRule, databaseShardingStrategy, databaseValues, tableShardingStrategy, tableValues);
        result.addAll(routedNodes);
        originalDataNodes.add(routedNodes);
    }
    return result;
}
/**
 * Routes to data sources first, then routes tables within each routed data source.
 *
 * @param tableRule table rule being routed
 * @param databaseShardingStrategy database sharding strategy
 * @param databaseShardingValues sharding values for the database strategy
 * @param tableShardingStrategy table sharding strategy
 * @param tableShardingValues sharding values for the table strategy
 * @return the routed data nodes of all routed data sources
 */
private Collection<DataNode> route0(final TableRule tableRule,
                                    final ShardingStrategy databaseShardingStrategy, final List<ShardingConditionValue> databaseShardingValues,
                                    final ShardingStrategy tableShardingStrategy, final List<ShardingConditionValue> tableShardingValues) {
    Collection<String> dataSources = routeDataSources(tableRule, databaseShardingStrategy, databaseShardingValues);
    Collection<DataNode> nodes = new LinkedList<>();
    dataSources.forEach(dataSource -> nodes.addAll(routeTables(tableRule, dataSource, tableShardingStrategy, tableShardingValues)));
    return nodes;
}
ShardingStandardRoutingEngine#routeDataSources
/**
 * Picks the data sources to route to.
 *
 * <p>Without database sharding values, every actual data source of the table rule is returned. Otherwise
 * the database strategy decides, and its result must be non-empty and a subset of the configured
 * data sources.
 *
 * @param tableRule table rule being routed
 * @param databaseShardingStrategy database sharding strategy
 * @param databaseShardingValues sharding values for the database strategy
 * @return names of the routed data sources
 */
private Collection<String> routeDataSources(final TableRule tableRule, final ShardingStrategy databaseShardingStrategy, final List<ShardingConditionValue> databaseShardingValues) {
    if (!databaseShardingValues.isEmpty()) {
        Collection<String> routed = databaseShardingStrategy.doSharding(tableRule.getActualDatasourceNames(), databaseShardingValues, tableRule.getDataSourceDataNode(), properties);
        Preconditions.checkState(!routed.isEmpty(), "No database route info");
        Preconditions.checkState(tableRule.getActualDatasourceNames().containsAll(routed),
                "Some routed data sources do not belong to configured data sources. routed data sources: `%s`, configured data sources: `%s`", routed, tableRule.getActualDatasourceNames());
        return routed;
    }
    return tableRule.getActualDatasourceNames();
}
StandardShardingStrategy#doSharding()
/**
 * Shards using the first sharding condition value only: list values go through the precise path,
 * anything else is treated as a range value. The targets are returned in a case-insensitively
 * ordered set.
 *
 * @param availableTargetNames candidate target names
 * @param shardingConditionValues sharding condition values; only the first element is read
 * @param dataNodeInfo data node info passed on to the algorithm
 * @param props not used by this method body
 * @return the sharded target names
 */
@Override
public Collection<String> doSharding(final Collection<String> availableTargetNames, final Collection<ShardingConditionValue> shardingConditionValues,
                                     final DataNodeInfo dataNodeInfo, final ConfigurationProperties props) {
    ShardingConditionValue firstValue = shardingConditionValues.iterator().next();
    Collection<String> shardingResult;
    if (firstValue instanceof ListShardingConditionValue) {
        shardingResult = doSharding(availableTargetNames, (ListShardingConditionValue) firstValue, dataNodeInfo);
    } else {
        shardingResult = doSharding(availableTargetNames, (RangeShardingConditionValue) firstValue, dataNodeInfo);
    }
    Collection<String> result = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
    result.addAll(shardingResult);
    return result;
}
/**
 * Applies the sharding algorithm to each value of a list condition.
 *
 * <p>A {@code null} target is skipped; a target that is not among the available names is an error.
 *
 * @param availableTargetNames candidate target names
 * @param shardingValue list condition whose values are sharded one by one
 * @param dataNodeInfo data node info passed on to the algorithm
 * @return the targets chosen for the values, in value order
 * @throws ShardingSphereException if the algorithm returns a target that is not available
 */
private Collection<String> doSharding(final Collection<String> availableTargetNames, final ListShardingConditionValue<?> shardingValue, final DataNodeInfo dataNodeInfo) {
    Collection<String> targets = new LinkedList<>();
    for (Comparable<?> value : shardingValue.getValues()) {
        String target = shardingAlgorithm.doSharding(availableTargetNames,
                new PreciseShardingValue(shardingValue.getTableName(), shardingValue.getColumnName(), dataNodeInfo, value));
        if (null == target) {
            continue;
        }
        if (!availableTargetNames.contains(target)) {
            throw new ShardingSphereException(String.format("Route table %s does not exist, available actual table: %s", target, availableTargetNames));
        }
        targets.add(target);
    }
    return targets;
}
ModShardingAlgorithm#doSharding()
,取模算法:用分片值对sharding-count取模,再用得到的结果作为后缀去匹配目标库名或表名。
/**
 * Takes the (cut) sharding value modulo the sharding count and looks for the available target whose
 * suffix matches the result.
 *
 * @param availableTargetNames candidate target names
 * @param shardingValue precise sharding value
 * @return the matched target name, or {@code null} when none matches
 */
@Override
public String doSharding(final Collection<String> availableTargetNames, final PreciseShardingValue<Comparable<?>> shardingValue) {
    String remainder = cutShardingValue(shardingValue.getValue()).mod(new BigInteger(String.valueOf(shardingCount))).toString();
    String suffix = getShardingResultSuffix(remainder);
    return findMatchedTargetName(availableTargetNames, suffix, shardingValue.getDataNodeInfo()).orElse(null);
}
重写
KernelProcessor#rewrite
,调用SQLRewriteEntry
重写Sql。
/**
 * Rewrites the logical SQL for the given route context through a new {@code SQLRewriteEntry}.
 *
 * @param logicSQL logical SQL supplying the text, parameters and statement context
 * @param database database the SQL targets
 * @param props configuration properties
 * @param routeContext routing result the rewrite is based on
 * @return the rewrite result
 */
private SQLRewriteResult rewrite(LogicSQL logicSQL, ShardingSphereDatabase database, ConfigurationProperties props, RouteContext routeContext) {
    return new SQLRewriteEntry(database, props)
            .rewrite(logicSQL.getSql(), logicSQL.getParameters(), logicSQL.getSqlStatementContext(), routeContext);
}
SQLRewriteEntry#rewrite
,调用RouteSQLRewriteEngine
重写Sql。
/**
 * Rewrites the SQL, using the generic engine when there are no route units and the route-aware engine
 * otherwise. A default {@code SQLTranslatorRule} is used when the database has none configured.
 *
 * @param sql logical SQL text
 * @param parameters SQL parameters
 * @param sqlStatementContext statement context of the SQL
 * @param routeContext routing result
 * @return the rewrite result
 */
public SQLRewriteResult rewrite(final String sql, final List<Object> parameters, final SQLStatementContext<?> sqlStatementContext, final RouteContext routeContext) {
    SQLRewriteContext sqlRewriteContext = createSQLRewriteContext(sql, parameters, sqlStatementContext, routeContext);
    SQLTranslatorRule translatorRule = database.getRuleMetaData().findSingleRule(SQLTranslatorRule.class).orElseGet(() -> new SQLTranslatorRule(new SQLTranslatorRuleConfiguration()));
    DatabaseType protocolType = database.getProtocolType();
    DatabaseType storageType = database.getResource().getDatabaseType();
    if (routeContext.getRouteUnits().isEmpty()) {
        return new GenericSQLRewriteEngine(translatorRule, protocolType, storageType).rewrite(sqlRewriteContext);
    }
    return new RouteSQLRewriteEngine(translatorRule, protocolType, storageType).rewrite(sqlRewriteContext, routeContext);
}
RouteSQLRewriteEngine#rewrite
,重写sql。
/**
 * Rewrites the SQL for every route unit, processing the route units group by group.
 *
 * <p>A group that needs an aggregate rewrite yields a single rewrite unit keyed by the group's first
 * route unit; any other group yields one rewrite unit per route unit.
 *
 * @param sqlRewriteContext rewrite context
 * @param routeContext routing result
 * @return the translated rewrite result
 */
public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext) {
    Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1);
    for (Collection<RouteUnit> group : aggregateRouteUnitGroups(routeContext.getRouteUnits()).values()) {
        if (!isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), group)) {
            addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, group);
            continue;
        }
        sqlRewriteUnits.put(group.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, group));
    }
    return new RouteSQLRewriteResult(translate(sqlRewriteContext.getSqlStatementContext().getSqlStatement(), sqlRewriteUnits));
}
/**
 * Adds one rewrite unit per route unit: the SQL built by {@code RouteSQLBuilder} for that unit plus
 * the parameters resolved for it.
 *
 * @param sqlRewriteUnits map receiving the rewrite units
 * @param sqlRewriteContext rewrite context
 * @param routeContext routing result
 * @param routeUnits route units to rewrite
 */
private void addSQLRewriteUnits(final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, final SQLRewriteContext sqlRewriteContext,
                                final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
    for (RouteUnit routeUnit : routeUnits) {
        String rewrittenSQL = new RouteSQLBuilder(sqlRewriteContext, routeUnit).toSQL();
        sqlRewriteUnits.put(routeUnit, new SQLRewriteUnit(rewrittenSQL, getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, routeUnit)));
    }
}
AbstractSQLBuilder#toSQL
,重新生成SQL。
/**
 * Rebuilds the SQL text from the original SQL and its tokens.
 *
 * <p>With no tokens the original SQL is returned as is. Otherwise the tokens are sorted, the text before
 * the first token is copied, and then each token's text is appended followed by its conjunction text.
 *
 * @return the rebuilt SQL
 */
@Override
public final String toSQL() {
    if (context.getSqlTokens().isEmpty()) {
        return context.getSql();
    }
    Collections.sort(context.getSqlTokens());
    StringBuilder sqlBuilder = new StringBuilder();
    sqlBuilder.append(context.getSql(), 0, context.getSqlTokens().get(0).getStartIndex());
    for (SQLToken token : context.getSqlTokens()) {
        if (token instanceof ComposableSQLToken) {
            sqlBuilder.append(getComposableSQLTokenText((ComposableSQLToken) token));
        } else {
            sqlBuilder.append(getSQLTokenText(token));
        }
        sqlBuilder.append(getConjunctionText(token));
    }
    return sqlBuilder.toString();
}
归并
ShardingSpherePreparedStatement#getResultSet()
,归并查询到的结果集。
/**
 * Returns the current result set, creating it on first use.
 *
 * <p>Traffic-matched and federated executions return their executor's result set. For a SELECT statement
 * context or a DAL statement, the underlying result sets are merged into a {@code ShardingSphereResultSet}
 * which is remembered for later calls. In every other case the current (possibly null) value is returned.
 *
 * @return the result set, or {@code null} when none applies
 * @throws SQLException if obtaining or merging the result sets fails
 */
@Override
public ResultSet getResultSet() throws SQLException {
    if (null != currentResultSet) {
        return currentResultSet;
    }
    if (trafficContext.isMatchTraffic()) {
        return executor.getTrafficExecutor().getResultSet();
    }
    if (executionContext.getRouteContext().isFederated()) {
        return executor.getFederationExecutor().getResultSet();
    }
    if (!(executionContext.getSqlStatementContext() instanceof SelectStatementContext)
            && !(executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement)) {
        return currentResultSet;
    }
    List<ResultSet> resultSets = getResultSets();
    MergedResult mergedResult = mergeQuery(getQueryResults(resultSets));
    currentResultSet = new ShardingSphereResultSet(resultSets, mergedResult, this, executionContext);
    return currentResultSet;
}
MergeEngine#merge
,执行merge。
/**
 * Merges the query results and decorates the outcome.
 *
 * <p>When a merge was performed, its result is decorated. Otherwise the first query result is decorated
 * directly, and if that yields nothing the first query result is wrapped transparently.
 *
 * @param queryResults query results to merge
 * @param sqlStatementContext statement context of the executed SQL
 * @return the merged result
 * @throws SQLException if merging or decorating fails
 */
public MergedResult merge(final List<QueryResult> queryResults, final SQLStatementContext<?> sqlStatementContext) throws SQLException {
    Optional<MergedResult> mergedResult = executeMerge(queryResults, sqlStatementContext);
    Optional<MergedResult> decorated;
    if (mergedResult.isPresent()) {
        decorated = Optional.of(decorate(mergedResult.get(), sqlStatementContext));
    } else {
        decorated = decorate(queryResults.get(0), sqlStatementContext);
    }
    return decorated.orElseGet(() -> new TransparentMergedResult(queryResults.get(0)));
}
/**
 * Merges with the first engine that is a {@code ResultMergerEngine}; later engines are not consulted.
 *
 * @param queryResults query results to merge
 * @param sqlStatementContext statement context of the executed SQL
 * @return the merged result, or empty when no merger engine is registered
 * @throws SQLException if the merger fails
 */
@SuppressWarnings({"unchecked", "rawtypes"})
private Optional<MergedResult> executeMerge(final List<QueryResult> queryResults, final SQLStatementContext<?> sqlStatementContext) throws SQLException {
    for (Entry<ShardingSphereRule, ResultProcessEngine> entry : engines.entrySet()) {
        if (!(entry.getValue() instanceof ResultMergerEngine)) {
            continue;
        }
        ResultMergerEngine mergerEngine = (ResultMergerEngine) entry.getValue();
        ResultMerger resultMerger = mergerEngine.newInstance(
                database.getName(), database.getResource().getDatabaseType(), entry.getKey(), props, sqlStatementContext);
        return Optional.of(resultMerger.merge(queryResults, sqlStatementContext, database));
    }
    return Optional.empty();
}
ShardingDQLResultMerger#merge
,根据group by、distinct、order by等关键字做不同的归并处理
/**
 * Merges DQL query results.
 *
 * <p>A single query result that needs no aggregate rewrite is streamed as is. Otherwise the column label
 * index map of the first result is registered on the select context, the merged result is built, and
 * the outcome is decorated.
 *
 * @param queryResults query results to merge
 * @param sqlStatementContext statement context; cast to {@code SelectStatementContext} on the merge path
 * @param database database the SQL targets
 * @return the merged result
 * @throws SQLException if reading the results fails
 */
public MergedResult merge(final List<QueryResult> queryResults, final SQLStatementContext<?> sqlStatementContext, final ShardingSphereDatabase database) throws SQLException {
    boolean passThrough = 1 == queryResults.size() && !isNeedAggregateRewrite(sqlStatementContext);
    if (passThrough) {
        return new IteratorStreamMergedResult(queryResults);
    }
    Map<String, Integer> labelIndexes = getColumnLabelIndexMap(queryResults.get(0));
    SelectStatementContext selectContext = (SelectStatementContext) sqlStatementContext;
    selectContext.setIndexes(labelIndexes);
    MergedResult built = build(queryResults, selectContext, labelIndexes, database);
    return decorate(queryResults, selectContext, built);
}
/**
 * Chooses the merged-result implementation for a SELECT.
 *
 * <p>Checked in order: GROUP BY processing, DISTINCT row processing (handled as GROUP BY after
 * {@code setGroupByForDistinctRow}), ORDER BY processing, and finally plain iteration.
 *
 * @param queryResults query results to merge
 * @param selectStatementContext select statement context
 * @param columnLabelIndexMap column label to index map
 * @param database database used to resolve the schema
 * @return the merged result
 * @throws SQLException if building the merged result fails
 */
private MergedResult build(final List<QueryResult> queryResults, final SelectStatementContext selectStatementContext,
                           final Map<String, Integer> columnLabelIndexMap, final ShardingSphereDatabase database) throws SQLException {
    String defaultSchemaName = DatabaseTypeEngine.getDefaultSchemaName(selectStatementContext.getDatabaseType(), database.getName());
    // Use the schema named by the statement when present, otherwise the default schema.
    ShardingSphereSchema schema = selectStatementContext.getTablesContext().getSchemaName()
            .map(schemaName -> database.getSchemas().get(schemaName)).orElseGet(() -> database.getSchemas().get(defaultSchemaName));
    if (isNeedProcessGroupBy(selectStatementContext)) {
        return getGroupByMergedResult(queryResults, selectStatementContext, columnLabelIndexMap, schema);
    }
    if (isNeedProcessDistinctRow(selectStatementContext)) {
        setGroupByForDistinctRow(selectStatementContext);
        return getGroupByMergedResult(queryResults, selectStatementContext, columnLabelIndexMap, schema);
    }
    return isNeedProcessOrderBy(selectStatementContext)
            ? new OrderByStreamMergedResult(queryResults, selectStatementContext, schema)
            : new IteratorStreamMergedResult(queryResults);
}