这是一个仅支持 MySQL 的简单 MyBatis 分页拦截器小应用。
原理
利用 MyBatis 拦截器在 SQL 执行之前取出原始 SQL，先查询总记录数，再拼接 MySQL 分页语法（LIMIT），最后把新 SQL 赋值回去。
- 利用 ThreadLocal 在线程内传递页码和每页大小参数，减少对原有代码的改动（注意：分页参数在调用 `PageHelper.getPageInfo` 之后才会被清除）
- 利用反射把修改后的 SQL 赋值回 BoundSql
拦截器源码
/**
 * MyBatis plugin that adds MySQL {@code limit offset,size} pagination to statements prepared
 * while a page request is active on the current thread (see {@code PageHelper}).
 *
 * <p>When paging is active it first runs {@code select count(*)} over the original SQL and
 * stores the result via {@code PageHelper.setTotal}. It then rewrites the bound SQL with a
 * {@code limit} clause before letting the statement be prepared.
 */
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class})})
public class MyPageInterceptor implements Interceptor {
    private static final Logger logger = LoggerFactory.getLogger(MyPageInterceptor.class);

    /**
     * Rewrites the statement's SQL for pagination when {@code PageHelper.isPage()} is true;
     * otherwise proceeds unchanged.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare(Connection, Integer)} call
     * @return whatever the intercepted {@code prepare} call returns
     * @throws Throwable anything thrown by the count query, the reflection write, or {@code proceed}
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Cast to the interface, not RoutingStatementHandler: if another plugin also wraps the
        // StatementHandler, the target is a proxy and a concrete-class cast would fail.
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = handler.getBoundSql();
        Object param = boundSql.getParameterObject();
        // Original SQL as produced by the mapper.
        String sql = boundSql.getSql();
        logger.info("拦截到 sql:{} 参数:{}", sql, param);
        if (!PageHelper.isPage()) {
            // No page request on this thread: run the statement as-is.
            return invocation.proceed();
        }
        // A trailing ';' would break both the count sub-query and the appended LIMIT clause.
        String baseSql = stripTrailingSemicolons(sql);
        int count = getCount(invocation, baseSql);
        PageHelper.setTotal(count);
        int pageNo = PageHelper.getPageNo();
        int pageSize = PageHelper.getPageSize();
        // MySQL rejects a negative LIMIT offset, so pages below 1 are treated as page 1.
        // Compute in long so large page numbers cannot overflow int.
        long offset = (long) (Math.max(pageNo, 1) - 1) * pageSize;
        String pagedSql = baseSql + " limit " + offset + "," + pageSize;
        // Write the new SQL back into the BoundSql that MyBatis will prepare.
        ReflectUtil.setValueByFieldName(boundSql, "sql", pagedSql);
        logger.info("分页后 sql:{} 参数:{}", boundSql.getSql(), param);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    @Override
    public void setProperties(Properties properties) {
        logger.warn(properties.toString());
    }

    /**
     * Counts the rows the given SQL would return, binding the same parameters as the
     * intercepted statement.
     *
     * @param invocation the intercepted call; its first argument is the MyBatis-owned connection
     * @param sql        the original (un-paged) SQL
     * @return the row count, or 0 if the count query yields no row
     * @throws SQLException if preparing or executing the count query fails
     */
    private Integer getCount(Invocation invocation, String sql) throws SQLException {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        // The connection belongs to MyBatis and must not be closed here; only our own
        // statement and result set are closed.
        Connection connection = (Connection) invocation.getArgs()[0];
        String countSql = "select count(*) from (" + sql + ") count";
        try (PreparedStatement ps = connection.prepareStatement(countSql)) {
            handler.getParameterHandler().setParameters(ps);
            try (ResultSet rs = ps.executeQuery()) {
                return rs.next() ? rs.getInt(1) : 0;
            }
        }
    }

    /** Removes trailing whitespace and semicolons so the SQL can be wrapped or extended. */
    private static String stripTrailingSemicolons(String sql) {
        int end = sql.length();
        while (end > 0 && (sql.charAt(end - 1) == ';' || Character.isWhitespace(sql.charAt(end - 1)))) {
            end--;
        }
        return sql.substring(0, end);
    }
}
PageHelper源码
/**
 * Thread-local holder for the current thread's page request.
 *
 * <p>Typical flow: call {@link #startPage} before a query, run the query (the page interceptor
 * reads the request), then call {@link #getPageInfo} to collect the result and clear the
 * request. If the query fails before {@code getPageInfo} is reached, call {@link #clearPage}
 * (e.g. in a {@code finally}). Otherwise the stale request stays on the thread and pages the
 * next query run on it.
 */
public class PageHelper {
    /** Page request bound to the current thread; {@code null} when no paging is active. */
    static final ThreadLocal<PageInfo> pageInfo = new ThreadLocal<>();

    /**
     * Starts paging for subsequent queries on this thread.
     *
     * @param pageNo   page number to fetch
     * @param pageSize number of rows per page
     */
    public static void startPage(Integer pageNo, Integer pageSize) {
        PageInfo p = new PageInfo();
        p.setPageNo(pageNo);
        p.setPageSize(pageSize);
        pageInfo.set(p);
    }

    /**
     * Attaches {@code data} to this thread's page request, returns it, and clears the request.
     *
     * @param data the query result for the current page
     * @return the page request, including the total set by the interceptor
     * @throws ServiceException if no page request is active on this thread
     */
    public static PageInfo getPageInfo(Object data) throws ServiceException {
        PageInfo p = pageInfo.get();
        if (p == null) {
            throw new ServiceException("此线程不存在分页查询。");
        }
        p.setData(data);
        // remove() rather than set(null): drops the thread's map entry entirely.
        pageInfo.remove();
        return p;
    }

    /** Discards any pending page request on this thread; safe to call when none is active. */
    public static void clearPage() {
        pageInfo.remove();
    }

    /** Records the total row count on the active page request (must be active). */
    public static void setTotal(Integer count) {
        pageInfo.get().setTotal(count);
    }

    /** Returns {@code true} if a page request is active on this thread. */
    public static boolean isPage() {
        return pageInfo.get() != null;
    }

    /** Returns the requested page number of the active page request (must be active). */
    public static Integer getPageNo() {
        return pageInfo.get().getPageNo();
    }

    /** Returns the requested page size of the active page request (must be active). */
    public static Integer getPageSize() {
        return pageInfo.get().getPageSize();
    }
}
反射工具源码
import java.lang.reflect.Field;
/**
 * Small reflection helpers for reading and writing (possibly private) instance fields.
 * Field lookup walks the target's class hierarchy, stopping before {@code Object}.
 */
public class ReflectUtil {

    /**
     * Reads the named field from {@code target}, bypassing access checks.
     *
     * @param target object to read from
     * @param field  name of a field declared on the target's class or any superclass
     * @return the field's value, or {@code null} if no such field exists
     * @throws IllegalAccessException if the field cannot be read
     */
    public static Object getFieldValue(Object target, String field) throws IllegalAccessException {
        Field f = getFieldByFieldName(target, field);
        if (f == null) {
            return null;
        }
        f.setAccessible(true);
        return f.get(target);
    }

    /**
     * Writes {@code value} into the named field of {@code obj}, bypassing access checks.
     *
     * @param obj       object to modify
     * @param fieldName name of a field declared on the object's class or any superclass
     * @param value     new value (must be assignable to the field's type)
     * @throws NoSuchFieldException     if no field with that name exists in the hierarchy
     * @throws IllegalArgumentException if {@code value} is not assignable to the field
     * @throws IllegalAccessException   if the field cannot be written
     */
    public static void setValueByFieldName(Object obj, String fieldName, Object value) throws SecurityException, NoSuchFieldException,
            IllegalArgumentException, IllegalAccessException {
        Field field = getFieldByFieldName(obj, fieldName);
        if (field == null) {
            throw new NoSuchFieldException(fieldName + " in " + obj.getClass().getName());
        }
        // The Field instance is local to this call (freshly looked up above), so there is no
        // need to restore its accessible flag afterwards.
        field.setAccessible(true);
        field.set(obj, value);
    }

    /**
     * Finds a field by name on {@code obj}'s class or its superclasses (excluding {@code Object}).
     *
     * @param obj       object whose class hierarchy is searched
     * @param fieldName field name to look for
     * @return the nearest declared field with that name, or {@code null} if none exists
     */
    public static Field getFieldByFieldName(Object obj, String fieldName) {
        // Walk up the hierarchy; the previous version re-read obj.getClass() on every
        // iteration and therefore looped forever when the field lived in a superclass.
        for (Class<?> c = obj.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
            try {
                return c.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared on this class; continue with the superclass.
            }
        }
        return null;
    }
}