服务器之家:专注于服务器技术及软件下载分享
分类导航

PHP教程|ASP.NET教程|Java教程|ASP教程|编程技术|正则表达式|C/C++|IOS|C#|Swift|Android|VB|R语言|JavaScript|易语言|vb.net|

服务器之家 - 编程语言 - Java教程 - springboot配置多数据源后mybatis拦截器失效的解决

springboot配置多数据源后mybatis拦截器失效的解决

2022-01-13 00:43那些年搬过的砖 Java教程

这篇文章主要介绍了springboot配置多数据源后mybatis拦截器失效的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

配置文件是通过springcloudconfig远程分布式配置。采用阿里Druid数据源。并支持一主多从的读写分离。分页组件通过拦截器拦截带有page后缀的方法名,动态的设置total总数。

1. 解析配置文件初始化数据源

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@Configuration
public class DataSourceConfiguration {
    /**
     * 数据源类型
     */
    @Value("${spring.datasource.type}")
    private Class<? extends DataSource> dataSourceType;
    /**
     * 主数据源配置
     *
     * @return
     */
    @Bean(name = "masterDataSource", destroyMethod = "close")
    @Primary
    @ConfigurationProperties(prefix = "spring.datasource")
    public DataSource masterDataSource() {
        DataSource source = DataSourceBuilder.create().type(dataSourceType).build();
        return source;
    }
    /**
     * 从数据源配置
     *
     * @return
     */
    @Bean(name = "slaveDataSource0")
    @ConfigurationProperties(prefix = "spring.slave0")
    public DataSource slaveDataSource0() {
        DataSource source = DataSourceBuilder.create().type(dataSourceType).build();
        return source;
    }
    /**
     * 从数据源集合
     *
     * @return
     */
    @Bean(name = "slaveDataSources")
    public List<DataSource> slaveDataSources() {
        List<DataSource> slaveDataSources = new ArrayList();
        slaveDataSources.add(slaveDataSource0());
        return slaveDataSources;
    }
}

2. 定义数据源枚举类型

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
 * Kinds of data source known to the routing layer. Each constant carries a
 * {@code type} string and a {@code name} string; both are mutable through
 * the setters below.
 */
public enum DataSourceType {
    master("master", "master"),
    slave("slave", "slave");

    private String type;
    private String name;

    /**
     * @param typeValue initial value of the type string
     * @param nameValue initial value of the name string
     */
    DataSourceType(String typeValue, String nameValue) {
        this.name = nameValue;
        this.type = typeValue;
    }

    /** @return the current type string of this constant */
    public String getType() {
        return this.type;
    }

    /** @return the current name string of this constant */
    public String getName() {
        return this.name;
    }

    /** @param type new type string for this constant */
    public void setType(String type) {
        this.type = type;
    }

    /** @param name new name string for this constant */
    public void setName(String name) {
        this.name = name;
    }
}

3. ThreadLocal保存数据源类型

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class DataSourceContextHolder {
    private static final ThreadLocal<String> local = new ThreadLocal<String>();
    public static ThreadLocal<String> getLocal() {
        return local;
    }
    public static void slave() {
        local.set(DataSourceType.slave.getType());
    }
    public static void master() {
        local.set(DataSourceType.master.getType());
    }
    public static String getJdbcType() {
        return local.get();
    }
    public static void clearDataSource(){
        local.remove();
    }
}

4. 自定义sqlSessionProxy

并将数据源填充到DataSourceRoute

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@Configuration
@ConditionalOnClass({EnableTransactionManagement.class})
@Import({DataSourceConfiguration.class})
public class DataSourceSqlSessionFactory {
    private Logger logger = Logger.getLogger(DataSourceSqlSessionFactory.class);
    @Value("${spring.datasource.type}")
    private Class<? extends DataSource> dataSourceType;
    @Value("${mybatis.mapper-locations}")
    private String mapperLocations;
    @Value("${mybatis.type-aliases-package}")
    private String aliasesPackage;
    @Value("${slave.datasource.number}")
    private int dataSourceNumber;
    @Resource(name = "masterDataSource")
    private DataSource masterDataSource;
    @Resource(name = "slaveDataSources")
    private List<DataSource> slaveDataSources;
    @Bean
    @ConditionalOnMissingBean
    public SqlSessionFactory sqlSessionFactory() throws Exception {
        logger.info("======================= init sqlSessionFactory");
        SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();
        sqlSessionFactoryBean.setDataSource(roundRobinDataSourceProxy());
        PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        sqlSessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations));
        sqlSessionFactoryBean.setTypeAliasesPackage(aliasesPackage);
        sqlSessionFactoryBean.getObject().getConfiguration().setMapUnderscoreToCamelCase(true);
        return sqlSessionFactoryBean.getObject();
    }
    @Bean(name = "roundRobinDataSourceProxy")
    public AbstractRoutingDataSource roundRobinDataSourceProxy() {
        logger.info("======================= init robinDataSourceProxy");
        DataSourceRoute proxy = new DataSourceRoute(dataSourceNumber);
        Map<Object, Object> targetDataSources = new HashMap();
        targetDataSources.put(DataSourceType.master.getType(), masterDataSource);
        if(null != slaveDataSources) {
            for(int i=0; i<slaveDataSources.size(); i++){
                targetDataSources.put(i, slaveDataSources.get(i));
            }
        }
        proxy.setDefaultTargetDataSource(masterDataSource);
        proxy.setTargetDataSources(targetDataSources);
        return proxy;
    }
}

5. 自定义路由

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public class DataSourceRoute extends AbstractRoutingDataSource {
    private Logger logger = Logger.getLogger(DataSourceRoute.class);
    private final int dataSourceNumber;
    
    public DataSourceRoute(int dataSourceNumber) {
        this.dataSourceNumber = dataSourceNumber;
    }
    @Override
    protected Object determineCurrentLookupKey() {
        String typeKey = DataSourceContextHolder.getJdbcType();
        logger.info("==================== swtich dataSource:" + typeKey);
        if (typeKey.equals(DataSourceType.master.getType())) {
            return DataSourceType.master.getType();
        }else{
            //从数据源随机分配
            Random random = new Random();
            int slaveDsIndex = random.nextInt(dataSourceNumber);
            return slaveDsIndex;
        }
    }
}

6. 定义切面,dao层定义切面

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@Aspect
@Component
public class DataSourceAop {
    private Logger logger = Logger.getLogger(DataSourceAop.class);
    @Before("execution(* com.dbq.iot.mapper..*.get*(..)) || execution(* com.dbq.iot.mapper..*.isExist*(..)) " +
            "|| execution(* com.dbq.iot.mapper..*.select*(..)) || execution(* com.dbq.iot.mapper..*.count*(..)) " +
            "|| execution(* com.dbq.iot.mapper..*.list*(..)) || execution(* com.dbq.iot.mapper..*.query*(..))" +
            "|| execution(* com.dbq.iot.mapper..*.find*(..))|| execution(* com.dbq.iot.mapper..*.search*(..))")
    public void setSlaveDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.slave();
        logger.info("=========slave, method:" + joinPoint.getSignature().getName());
    }
    @Before("execution(* com.dbq.iot.mapper..*.add*(..)) || execution(* com.dbq.iot.mapper..*.del*(..))" +
            "||execution(* com.dbq.iot.mapper..*.upDate*(..)) || execution(* com.dbq.iot.mapper..*.insert*(..))" +
            "||execution(* com.dbq.iot.mapper..*.create*(..)) || execution(* com.dbq.iot.mapper..*.update*(..))" +
            "||execution(* com.dbq.iot.mapper..*.delete*(..)) || execution(* com.dbq.iot.mapper..*.remove*(..))" +
            "||execution(* com.dbq.iot.mapper..*.save*(..)) || execution(* com.dbq.iot.mapper..*.relieve*(..))" +
            "|| execution(* com.dbq.iot.mapper..*.edit*(..))")
    public void setMasterDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.master();
        logger.info("=========master, method:" + joinPoint.getSignature().getName());
    }
}

7. 最后在写库增加事务管理

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@Configuration
@Import({DataSourceConfiguration.class})
public class DataSouceTranscation extends DataSourceTransactionManagerAutoConfiguration {
    private Logger logger = Logger.getLogger(DataSouceTranscation.class);
    @Resource(name = "masterDataSource")
    private DataSource masterDataSource;
    /**
     * 配置事务管理器
     *
     * @return
     */
    @Bean(name = "transactionManager")
    public DataSourceTransactionManager transactionManagers() {
        logger.info("===================== init transactionManager");
        return new DataSourceTransactionManager(masterDataSource);
    }
}

8. 在配置文件中增加数据源配置

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
spring.datasource.name=writedb
spring.datasource.url=jdbc:mysql://192.168.0.1/master?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.datasource.username=root
spring.datasource.password=1234
spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
spring.datasource.driver-class-name=com.mysql.jdbc.Driver
spring.datasource.filters=stat
spring.datasource.initialSize=20
spring.datasource.minIdle=20
spring.datasource.maxActive=200
spring.datasource.maxWait=60000
#从库的数量
slave.datasource.number=1
spring.slave0.name=readdb
spring.slave0.url=jdbc:mysql://192.168.0.2/slave?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.slave0.username=root
spring.slave0.password=1234
spring.slave0.type=com.alibaba.druid.pool.DruidDataSource
spring.slave0.driver-class-name=com.mysql.jdbc.Driver
spring.slave0.filters=stat
spring.slave0.initialSize=20
spring.slave0.minIdle=20
spring.slave0.maxActive=200
spring.slave0.maxWait=60000

这样就实现了在springcloud框架下的读写分离,并且支持多个从库的负载均衡(简单的通过随机分配,也有网友通过算法实现平均分配,具体做法是通过一个线程安全的自增长Integer类型,取余实现。个人觉得没大必要。如果有大神有更好的方法可以一起探讨。)

MyBatis分页配置可通过dao层的拦截器对特定方法进行拦截,拦截后添加自己的逻辑代码,比如计算total等,具体代码如下(参考了网友的代码,主要是通过@Intercepts注解):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class PageInterceptor implements Interceptor {
    private static final Log logger = LogFactory.getLog(PageInterceptor.class);
    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
    private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
    private static String defaultDialect = "mysql"; // 数据库类型(默认为mysql)
    private static String defaultPageSqlId = ".*Page$"; // 需要拦截的ID(正则匹配)
    private String dialect = ""; // 数据库类型(默认为mysql)
    private String pageSqlId = ""; // 需要拦截的ID(正则匹配)
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        MetaObject metaStatementHandler = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);
        // 分离代理对象链(由于目标类可能被多个拦截器拦截,从而形成多次代理,通过下面的两次循环可以分离出最原始的的目标类)
        while (metaStatementHandler.hasGetter("h")) {
            Object object = metaStatementHandler.getValue("h");
            metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);
        }
        // 分离最后一个代理对象的目标类
        while (metaStatementHandler.hasGetter("target")) {
            Object object = metaStatementHandler.getValue("target");
            metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);
        }
        Configuration configuration = (Configuration) metaStatementHandler.getValue("delegate.configuration");
        if (null == dialect || "".equals(dialect)) {
            logger.warn("Property dialect is not setted,use default 'mysql' ");
            dialect = defaultDialect;
        }
        if (null == pageSqlId || "".equals(pageSqlId)) {
            logger.warn("Property pageSqlId is not setted,use default '.*Page$' ");
            pageSqlId = defaultPageSqlId;
        }
        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
        // 只重写需要分页的sql语句。通过MappedStatement的ID匹配,默认重写以Page结尾的MappedStatement的sql
        if (mappedStatement.getId().matches(pageSqlId)) {
            BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
            Object parameterObject = boundSql.getParameterObject();
            if (parameterObject == null) {
                throw new NullPointerException("parameterObject is null!");
            } else {
                PageParameter page = (PageParameter) metaStatementHandler
                        .getValue("delegate.boundSql.parameterObject.page");
                String sql = boundSql.getSql();
                // 重写sql
                String pageSql = buildPageSql(sql, page);
                metaStatementHandler.setValue("delegate.boundSql.sql", pageSql);
                metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
                metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
                Connection connection = (Connection) invocation.getArgs()[0];
                // 重设分页参数里的总页数等
                setPageParameter(sql, connection, mappedStatement, boundSql, page);
            }
        }
        // 将执行权交给下一个拦截器
        return invocation.proceed();
    }
    /**
     * @param sql
     * @param connection
     * @param mappedStatement
     * @param boundSql
     * @param page
     */
    private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,
                                  BoundSql boundSql, PageParameter page) {
        // 记录总记录数
        String countSql = "select count(0) from (" + sql + ") as total";
        PreparedStatement countStmt = null;
        ResultSet rs = null;
        try {
            countStmt = connection.prepareStatement(countSql);
            BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql,
                    boundSql.getParameterMappings(), boundSql.getParameterObject());
            Field metaParamsField = ReflectUtil.getFieldByFieldName(boundSql, "metaParameters");
            if (metaParamsField != null) {
                try {
                    MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters");
                    ReflectUtil.setValueByFieldName(countBS, "metaParameters", mo);
                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                        | IllegalAccessException e) {
                    // TODO Auto-generated catch block
                     logger.error("Ignore this exception", e);
                }
            }
            Field additionalField = ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters");
            if (additionalField != null) {
                try {
                    Map<String, Object> map = (Map<String, Object>) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters");
                    ReflectUtil.setValueByFieldName(countBS, "additionalParameters", map);
                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
                        | IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    logger.error("Ignore this exception", e);
                }
            }
            setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject());
            rs = countStmt.executeQuery();
            int totalCount = 0;
            if (rs.next()) {
                totalCount = rs.getInt(1);
            }
            page.setTotalCount(totalCount);
            int totalPage = totalCount / page.getPageSize() + ((totalCount % page.getPageSize() == 0) ? 0 : 1);
            page.setTotalPage(totalPage);
        } catch (SQLException e) {
            logger.error("Ignore this exception", e);
        } finally {
            try {
                if (rs != null){
                    rs.close();
                }
            } catch (SQLException e) {
                logger.error("Ignore this exception", e);
            }
            try {
                if (countStmt != null){
                    countStmt.close();
                }
            } catch (SQLException e) {
                logger.error("Ignore this exception", e);
            }
        }
    }
    /**
     * 对SQL参数(?)设值
     *
     * @param ps
     * @param mappedStatement
     * @param boundSql
     * @param parameterObject
     * @throws SQLException
     */
    private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
                               Object parameterObject) throws SQLException {
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
        parameterHandler.setParameters(ps);
    }
    /**
     * 根据数据库类型,生成特定的分页sql
     *
     * @param sql
     * @param page
     * @return
     */
    private String buildPageSql(String sql, PageParameter page) {
        if (page != null) {
            StringBuilder pageSql = new StringBuilder();
            pageSql = buildPageSqlForMysql(sql,page);
            return pageSql.toString();
        } else {
            return sql;
        }
    }
    /**
     * mysql的分页语句
     *
     * @param sql
     * @param page
     * @return String
     */
    public StringBuilder buildPageSqlForMysql(String sql, PageParameter page) {
        StringBuilder pageSql = new StringBuilder(100);
        String beginrow = String.valueOf((page.getCurrentPage() - 1) * page.getPageSize());
        pageSql.append(sql);
        pageSql.append(" limit " + beginrow + "," + page.getPageSize());
        return pageSql;
    }
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }
    @Override
    public void setProperties(Properties properties) {
    }
}

这里碰到一个比较有趣的问题,就是sql如果是foreach参数,在拦截后无法注入。需要加入以下代码才可以(有的资料上只提到重置metaParameters)。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Excerpt from PageInterceptor.setPageParameter: boundSql is the original BoundSql,
// countBS is the BoundSql created for the count query.
// Copy the metaParameters field from boundSql to countBS via reflection.
Field metaParamsField = ReflectUtil.getFieldByFieldName(boundSql, "metaParameters");
if (metaParamsField != null) {
    try {
        MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters");
        ReflectUtil.setValueByFieldName(countBS, "metaParameters", mo);
    } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
            | IllegalAccessException e) {
        // Reflection failure is logged and the copy is skipped.
         logger.error("Ignore this exception", e);
    }
}
// Copy the additionalParameters field as well; per the surrounding article this is
// the part needed for foreach-generated parameters to bind in the count query.
Field additionalField = ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters");
if (additionalField != null) {
    try {
        Map<String, Object> map = (Map<String, Object>) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters");
        ReflectUtil.setValueByFieldName(countBS, "additionalParameters", map);
    } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
            | IllegalAccessException e) {
        // Reflection failure is logged and the copy is skipped.
        logger.error("Ignore this exception", e);
    }
}

读写分离倒是写好了,但是发现增加了mysql一主多从的读写分离后,此分页拦截器直接失效。

最后分析原因是因为,我们在做主从分离时,自定义了SqlSessionFactory,导致此拦截器没有注入。

在上面第4步中,DataSourceSqlSessionFactory中注入拦截器即可,具体代码如下

通过注解引入拦截器类:

?
1
// Class-level annotation for DataSourceSqlSessionFactory: imports PageInterceptor in addition to DataSourceConfiguration.
@Import({DataSourceConfiguration.class,PageInterceptor.class})

注入拦截器

?
1
2
// Field of DataSourceSqlSessionFactory holding the injected page interceptor.
@Autowired
    private PageInterceptor pageInterceptor;

SqlSessionFactoryBean中设置拦截器

?
1
// Register the page interceptor as a plugin on the factory bean.
sqlSessionFactoryBean.setPlugins(new Interceptor[]{pageInterceptor});

这里碰到一个坑,就是设置plugins时必须在sqlSessionFactoryBean.getObject()之前。

SqlSessionFactory在生成的时候就会获取plugins,并设置到Configuration中,如果在之后设置则不会注入。

可跟踪源码看到:

?
1
// Call whose implementation is traced in the snippets below.
sqlSessionFactoryBean.getObject()
?
1
2
3
4
5
6
/**
 * Returns the session factory, building it first via {@code afterPropertiesSet()}
 * if it has not been created yet.
 */
public SqlSessionFactory getObject() throws Exception {
    // Lazy initialization: the factory is built on the first call only.
    if (this.sqlSessionFactory == null) {
      afterPropertiesSet();
    }
    return this.sqlSessionFactory;
}
?
1
2
3
4
5
6
7
/**
 * Validates the required properties and then builds the session factory.
 */
public void afterPropertiesSet() throws Exception {
    // dataSource and sqlSessionFactoryBuilder must both be set.
    notNull(dataSource, "Property 'dataSource' is required");
    notNull(sqlSessionFactoryBuilder, "Property 'sqlSessionFactoryBuilder' is required");
    // configuration and configLocation must not both be set at the same time.
    state((configuration == null && configLocation == null) || !(configuration != null && configLocation != null),
              "Property 'configuration' and 'configLocation' can not specified with together");
    // The factory is built here, from the properties as they are at this moment.
    this.sqlSessionFactory = buildSqlSessionFactory();
  }

buildSqlSessionFactory()

?
1
2
3
4
5
6
7
8
// Excerpt from buildSqlSessionFactory(): every plugin set on the factory bean
// is added to the configuration as an interceptor.
if (!isEmpty(this.plugins)) {
      for (Interceptor plugin : this.plugins) {
        configuration.addInterceptor(plugin);
        if (LOGGER.isDebugEnabled()) {
          LOGGER.debug("Registered plugin: '" + plugin + "'");
        }
      }
    }

最后贴上正确的配置代码(DataSourceSqlSessionFactory代码片段)

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
@Bean
@ConditionalOnMissingBean
public SqlSessionFactory sqlSessionFactory() throws Exception {
        logger.info("======================= init sqlSessionFactory");
        SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();
        sqlSessionFactoryBean.setPlugins(new Interceptor[]{pageInterceptor});
        sqlSessionFactoryBean.setDataSource(roundRobinDataSourceProxy());
        PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        sqlSessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations));
        sqlSessionFactoryBean.setTypeAliasesPackage(aliasesPackage);
        SqlSessionFactory sqlSessionFactory = sqlSessionFactoryBean.getObject();
        sqlSessionFactory.getConfiguration().setMapUnderscoreToCamelCase(true);
        return sqlSessionFactory;
}

以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://www.jianshu.com/p/36fb90db79ce

延伸 · 阅读

精彩推荐