package com.test.interceptor; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.util.List; import java.util.Map; import java.util.Properties; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.executor.parameter.ParameterHandler; import org.apache.ibatis.executor.statement.RoutingStatementHandler; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ParameterMapping; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.SystemMetaObject; import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; import org.apache.ibatis.session.ResultHandler; import org.apache.ibatis.session.RowBounds; import com.mysql.jdbc.PreparedStatement; import com.test.util.Page; @Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class }), @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) }) public class StatementHandleInterceptor implements Interceptor { public static final String MYSQL = "mysql"; protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>(); public Object intercept(Invocation invocation) throws Throwable { if (invocation.getTarget() instanceof StatementHandler){ Page<?> page = pageThreadLocal.get(); if(page==null){ return invocation.proceed(); } RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation .getTarget(); StatementHandler delegate = ReflectUtil.getFieldValue( statementHandler, "delegate"); BoundSql boundSql = delegate.getBoundSql(); Connection connection = 
(Connection) invocation.getArgs()[0]; if(page.getTotalPage()>-1){ System.out.println("总页数:"+page.getTotalPage()); }else{ Object obj = boundSql.getParameterObject(); MetaObject metaStatementHandler = SystemMetaObject.forObject(statementHandler); MappedStatement mappedStatement=(MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement"); queryTotalRecord(page, obj, mappedStatement, connection); } String sql = boundSql.getSql(); String pageSql = buildPageSql(page,sql); System.out.println("分页时,生成pageSql:"+pageSql); ReflectUtil.setFieldValue((Object)boundSql, "sql",pageSql); return invocation.proceed(); }else{ Page<?> page = findPageObject(invocation.getArgs()[1]); if(page==null){ System.out.println("没有page参数对象,不是分页查询"); return invocation.proceed(); }else{ System.out.println("检测到page对象!使用分页查询"); } pageThreadLocal.set(page); try{ return invocation.proceed(); //可setpage Results /*Object resultObj = invocation.proceed(); if(resultObj instanceof List){ page.setResults((List)resultObj); } return resultObj;*/ }finally{ pageThreadLocal.remove(); } } } private String buildPageSql(Page page,String sql) { // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。 int offset = (page.getPageNo() - 1) * page.getPageSize(); return new StringBuilder(sql).append(" limit ").append(offset) .append(",").append(page.getPageSize()).toString(); } /** * 判定是否需要分页拦截 * @param object * @return */ private Page<?> findPageObject(Object object) { if(object instanceof Page<?>){ return (Page<?>) object; }else if(object instanceof Map){ for(Object o:((Map<?,?>) object).values()){ if(o instanceof Page<?>){ return (Page<?>) o; } } } return null; } /** * 查询总记录数 * @param page * @param obj * @param mappedStatement * @param connection * @throws SQLException */ private void queryTotalRecord(Page<?> page, Object obj, MappedStatement mappedStatement, Connection connection) throws SQLException { BoundSql boundSql = mappedStatement.getBoundSql(page); String sql = boundSql.getSql(); String countSql = 
this.buildCountSql(sql); System.out.println("分页时,生成countSql:"+countSql); List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(),countSql,parameterMappings,obj); ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, obj, countBoundSql); PreparedStatement pstmt = null; ResultSet rs = null; try{ pstmt = (PreparedStatement) connection.prepareStatement(countSql); parameterHandler.setParameters(pstmt); rs = pstmt.executeQuery(); if(rs.next()){ long totalRecord = rs.getLong(1); page.setTotalRecord(totalRecord); } }finally{ if(rs!=null){ rs.close(); } if(pstmt!=null){ pstmt.close(); } } } /** * 构造查询总记录数sql * @param sql * @return */ private String buildCountSql(String sql) { int index = sql.toLowerCase().indexOf("from"); return "select count(*)"+sql.substring(index); } public Object plugin(Object target) { return Plugin.wrap(target, this); } public void setProperties(Properties properties) { } }
调用:
结果: