• mybatis 拦截sql修改


    需求:在不停机的前提下完成数据库迁移,采用数据库双写方案,即先将数据同时写入新、老两个库(也可以在数据库层面做主从复制,但运维和 DBA 无法配合,因此未采用);
    实现:在原 mapper 接口上加入新注解,启动时扫描带该注解的 mapper;通过 MyBatis 拦截器拦截其写操作并获取完整 SQL,经 Pulsar 发送消息,由消费端执行该 SQL 写入新数据库,完成同步。注意:消息在语句执行后立即发送,而不是在事务提交后发送,若事务最终回滚,新库仍会执行该写操作。
    
    mybatis拦截器
    package com.zhaopin.zhiq.doublewrite;
    
    import com.alibaba.fastjson.JSONObject;
    import com.zhaopin.platzqaserver.pulsar.PulsarProducer;
    import com.zhaopin.platzqaserver.utils.LogUtil;
    import com.zhaopin.platzqaserver.utils.SpringUtils;
    import com.zhaopin.zhiq.doublewrite.constant.DoubleWriteConstant;
    import org.apache.ibatis.binding.MapperMethod;
    import org.apache.ibatis.executor.Executor;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.SqlCommandType;
    import org.apache.ibatis.plugin.*;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Properties;
    import java.util.Set;
    import java.util.stream.Collectors;
    
    //数据库双写拦截器
    @Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
    public class DoubleWriteInterceptor implements Interceptor {
    
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
    
            Object result = invocation.proceed();
            if ((int) result == 0) {
                return result;
            }
            String sql = "";
            try {
                Object[] args = invocation.getArgs();
                MappedStatement ms = (MappedStatement) args[0];
              	//获取请求参数和调用的mapper方法
                if (hasMapperName(ms)) {
                    Object parameter = args[1];
                    sql = DoubleWriteUtil.getCompleteSql(ms.getConfiguration(), ms.getBoundSql(parameter));
                    JSONObject message = new JSONObject();
                    if (SqlCommandType.INSERT.equals(ms.getSqlCommandType())) {
                        //批量insert时的sql处理
                        if (parameter instanceof MapperMethod.ParamMap) {
                            MapperMethod.ParamMap parameterMap = (MapperMethod.ParamMap) parameter;
                            Object params = parameterMap.get("param1");
                            if (params instanceof ArrayList) {
                                ArrayList paramsList = (ArrayList) params;
                                List<Long> ids = (List<Long>) paramsList.stream().map(param -> Long.parseLong(String.valueOf(DoubleWriteUtil.getFieldValue(param, "id")))).collect(Collectors.toList());
                                sql = DoubleWriteUtil.createDoubleWriteBatchSql(sql, ids);
                            }
                        } else {
                            Object id = DoubleWriteUtil.getFieldValue(parameter, "id");
                            sql = DoubleWriteUtil.createDoubleWriteSimpleSql(sql, id);
                        }
                    }
                    message.put("sql", sql);
                    message.put("sqlCommandType", ms.getSqlCommandType());
                    PulsarProducer.send(DoubleWriteConstant.DB_DOUBLE_WRITE_TOPIC, message.toJSONString());
                }
            } catch (Exception e) {
                LogUtil.error("failed to double write error , sql " + sql, e);
            }
            return result;
        }
    		//校验拦截的方法是否是需要被双写的sql
        private boolean hasMapperName(MappedStatement ms) {
            DoubleWritePrepareMapper doubleWritePrepareMapper = SpringUtils.getBean("doubleWritePrepareMapper", DoubleWritePrepareMapper.class);
            Set<String> mapperNames = doubleWritePrepareMapper.getMapperNames();
            String mapperName = ms.getId();
            mapperName = mapperName.substring(0, mapperName.lastIndexOf("."));
            return mapperNames.contains(mapperName);
        }
    
        @Override
        public Object plugin(Object target) {
            return Plugin.wrap(target, this);
        }
    
        @Override
        public void setProperties(Properties properties) {
    
        }
    
    }
    
    
    获取原始sql
    package com.zhaopin.zhiq.doublewrite;
    
    import com.zhaopin.platzqaserver.utils.LogUtil;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.ParameterMapping;
    import org.apache.ibatis.reflection.MetaObject;
    import org.apache.ibatis.session.Configuration;
    import org.apache.ibatis.type.TypeHandlerRegistry;
    
    import java.lang.reflect.Field;
    import java.lang.reflect.Method;
    import java.text.DateFormat;
    import java.text.SimpleDateFormat;
    import java.util.ArrayList;
    import java.util.Date;
    import java.util.List;
    import java.util.Locale;
    import java.util.regex.Matcher;
    import java.util.regex.Pattern;
    
    /**
     * Utilities for the double-write feature.
     *
     * <p>They render a MyBatis {@link BoundSql} into a literal SQL statement. They also rewrite
     * INSERT statements so the primary-key value assigned on the primary database is written
     * explicitly to the secondary database.
     */
    public final class DoubleWriteUtil {

        /** Finds the VALUES keyword and the "(" opening its first tuple, ignoring case. */
        private static final Pattern VALUES_PATTERN = Pattern.compile("(?i)\\bvalues\\s*\\(");

        /** Format used to render {@link Date} parameters as SQL timestamp literals. */
        private static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss";

        private DoubleWriteUtil() {
        }

        /**
         * Adds an {@code id} column and the given id value to a single-row INSERT statement.
         *
         * <p>The first "(" of the statement is assumed to open the column list.
         *
         * @param oldSql literal {@code INSERT INTO t (cols) VALUES (...)} statement
         * @param id     primary-key value to insert, rendered with {@code String.valueOf}
         * @return the rewritten statement, or {@code ""} when no {@code VALUES (} clause is found
         */
        public static String createDoubleWriteSimpleSql(String oldSql, Object id) {
            String sql = oldSql.replaceFirst("\\(", "(id, ");
            Matcher matcher = VALUES_PATTERN.matcher(sql);
            if (!matcher.find()) {
                return "";
            }
            // Only the first VALUES tuple is touched, never text inside later literals.
            int insertAt = matcher.end();
            return sql.substring(0, insertAt) + id + ", " + sql.substring(insertAt);
        }

        /**
         * Adds an {@code id} column to a multi-row INSERT statement and prepends {@code ids.get(i)}
         * to the i-th value tuple.
         *
         * <p>Tuples are located by tracking parenthesis depth and single-quoted literals, so values
         * containing "(", ")" or "," are handled. Once the VALUES list ends (a closing ")" not
         * followed by ","), trailing clauses such as {@code ON CONFLICT (...)} are left untouched.
         *
         * @param oldSql literal {@code INSERT INTO t (cols) VALUES (...), (...)} statement
         * @param ids    id per tuple, in tuple order
         * @return the rewritten statement
         * @throws IllegalArgumentException if the statement has no {@code VALUES (} clause
         */
        public static String createDoubleWriteBatchSql(String oldSql, List<Long> ids) {
            String sql = oldSql.replaceFirst("\\(", "(id, ");
            Matcher matcher = VALUES_PATTERN.matcher(sql);
            if (!matcher.find()) {
                throw new IllegalArgumentException("no VALUES clause found in sql: " + oldSql);
            }
            int firstTuple = matcher.end() - 1;
            StringBuilder out = new StringBuilder(sql.length() + ids.size() * 12);
            out.append(sql, 0, firstTuple);
            int nextId = 0;
            int depth = 0;
            boolean inLiteral = false;
            boolean inValuesList = true;
            for (int i = firstTuple; i < sql.length(); i++) {
                char c = sql.charAt(i);
                out.append(c);
                if (inLiteral) {
                    // An escaped quote ('') leaves and re-enters the literal, which stays correct.
                    inLiteral = c != '\'';
                    continue;
                }
                if (c == '\'') {
                    inLiteral = true;
                } else if (c == '(') {
                    if (depth == 0 && inValuesList && nextId < ids.size()) {
                        out.append(ids.get(nextId++)).append(", ");
                    }
                    depth++;
                } else if (c == ')') {
                    depth--;
                    if (depth == 0 && inValuesList) {
                        inValuesList = nextNonBlankIs(sql, i + 1, ',');
                    }
                }
            }
            return out.toString();
        }

        /**
         * Builds the complete SQL by inlining the parameter values into the '?' placeholders.
         *
         * <p>Whitespace runs are collapsed to single spaces. Each value is resolved in the order
         * used by MyBatis' DefaultParameterHandler:
         * <ol>
         *   <li>additional parameters (foreach/bind)</li>
         *   <li>a null parameter object</li>
         *   <li>a scalar parameter object</li>
         *   <li>a property of the parameter object</li>
         * </ol>
         * Placeholders inside single-quoted literals are not replaced.
         *
         * @param configuration MyBatis configuration (type handlers, MetaObject factory)
         * @param boundSql      the statement and its parameters
         * @return the SQL statement with literal values
         */
        public static String getCompleteSql(Configuration configuration, BoundSql boundSql) {
            // Collapse spaces, line breaks and tabs.
            String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
            List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
            if (parameterMappings == null || parameterMappings.isEmpty()) {
                return sql;
            }
            Object parameterObject = boundSql.getParameterObject();
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            boolean scalarParameter = parameterObject != null
                    && typeHandlerRegistry.hasTypeHandler(parameterObject.getClass());
            MetaObject metaObject = (parameterObject == null || scalarParameter)
                    ? null : configuration.newMetaObject(parameterObject);
            List<String> literals = new ArrayList<>(parameterMappings.size());
            for (ParameterMapping parameterMapping : parameterMappings) {
                String propertyName = parameterMapping.getProperty();
                Object value;
                if (boundSql.hasAdditionalParameter(propertyName)) {
                    value = boundSql.getAdditionalParameter(propertyName);
                } else if (parameterObject == null) {
                    value = null;
                } else if (scalarParameter) {
                    // A scalar parameter feeds every placeholder, e.g. "a = #{v} or b = #{v}".
                    value = parameterObject;
                } else if (metaObject.hasGetter(propertyName)) {
                    value = metaObject.getValue(propertyName);
                } else {
                    // Keep placeholders aligned even if a property cannot be resolved.
                    value = null;
                }
                literals.add(getParameterValue(value));
            }
            return fillPlaceholders(sql, literals);
        }

        /** Replaces each '?' outside single-quoted literals with the next literal, in order. */
        private static String fillPlaceholders(String sql, List<String> literals) {
            StringBuilder out = new StringBuilder(sql.length() + literals.size() * 8);
            int next = 0;
            boolean inLiteral = false;
            for (int i = 0; i < sql.length(); i++) {
                char c = sql.charAt(i);
                if (c == '\'') {
                    inLiteral = !inLiteral;
                } else if (c == '?' && !inLiteral && next < literals.size()) {
                    // Appending (not re-scanning) means '?', '$' or '\' inside values are safe.
                    out.append(literals.get(next++));
                    continue;
                }
                out.append(c);
            }
            return out.toString();
        }

        /**
         * Renders a parameter value as a SQL literal.
         *
         * <ul>
         *   <li>null becomes NULL.</li>
         *   <li>Numbers and booleans are written as-is.</li>
         *   <li>Dates are formatted as {@value #DATE_PATTERN} in the JVM default time zone.</li>
         *   <li>Enums use {@code name()}. NOTE(review): this matches MyBatis' default
         *       EnumTypeHandler; confirm no ordinal handler is registered.</li>
         *   <li>Everything else uses {@code toString()}.</li>
         * </ul>
         * Non-numeric values are single-quoted with embedded quotes doubled.
         */
        private static String getParameterValue(Object obj) {
            if (obj == null) {
                return "NULL";
            }
            if (obj instanceof Number || obj instanceof Boolean) {
                return obj.toString();
            }
            if (obj instanceof Date) {
                // A new formatter per call: SimpleDateFormat is not thread-safe.
                DateFormat formatter = new SimpleDateFormat(DATE_PATTERN, Locale.ROOT);
                return quote(formatter.format((Date) obj));
            }
            if (obj instanceof Enum) {
                return quote(((Enum<?>) obj).name());
            }
            return quote(obj.toString());
        }

        /** Wraps a value in single quotes, doubling embedded quotes (standard SQL escaping). */
        private static String quote(String value) {
            return "'" + value.replace("'", "''") + "'";
        }

        /**
         * Reads a property through its public getter ("get" + capitalized name).
         *
         * <p>The field may be declared in the parameter's class or any superclass.
         *
         * @param parameter object to read from; null yields null
         * @param fieldName name of the field/property
         * @return the getter's result, or null if the field or getter is missing or the call fails
         *         (failures are logged)
         */
        public static Object getFieldValue(Object parameter, String fieldName) {
            if (parameter == null || fieldName == null || fieldName.isEmpty()) {
                return null;
            }
            try {
                Class<?> declaringClass = findDeclaringClass(parameter.getClass(), fieldName);
                if (declaringClass != null) {
                    Method getter = declaringClass.getMethod("get" + captureName(fieldName));
                    return getter.invoke(parameter);
                }
            } catch (Exception e) {
                LogUtil.error("failed to get", e);
            }
            return null;
        }

        /** Walks up the class hierarchy (the field may be inherited) to find the declaring class. */
        private static Class<?> findDeclaringClass(Class<?> clazz, String fieldName) {
            for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
                for (Field field : current.getDeclaredFields()) {
                    if (fieldName.equals(field.getName())) {
                        return current;
                    }
                }
            }
            return null;
        }

        /** Upper-cases the first character when it is an ASCII lower-case letter. */
        private static String captureName(String str) {
            char[] cs = str.toCharArray();
            if (cs.length > 0 && cs[0] >= 'a' && cs[0] <= 'z') {
                cs[0] = (char) (cs[0] - ('a' - 'A'));
            }
            return String.valueOf(cs);
        }

    }
    
    
    获取被双写注解修饰的Repository
    package com.zhaopin.zhiq.doublewrite;
    
    import com.zhaopin.platzqaserver.utils.LogUtil;
    import com.zhaopin.zhiq.annotation.DoubleWrite;
    import org.mybatis.spring.mapper.MapperFactoryBean;
    import org.springframework.beans.BeansException;
    import org.springframework.beans.factory.config.BeanPostProcessor;
    import org.springframework.stereotype.Component;
    
    import java.lang.annotation.Annotation;
    import java.util.*;
    
    /**
     * 获取被双写注解修饰的Repository
     */
    @Component
    public class DoubleWritePrepareMapper implements BeanPostProcessor {
    
        /**
         * key: Repository 接口的Name
         * value: Repository 接口的Class对象
         */
        private Map<String, Class<?>> mappers;
    
        @Override
        public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
            return bean;
        }
    
        @Override
        public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
            if (mappers == null) {
                mappers = new HashMap<>(32);
            }
            try {
                MapperFactoryBean mapper;
                if (bean instanceof MapperFactoryBean) {
                    mapper = (MapperFactoryBean) bean;
                    Class mapperInterface = mapper.getMapperInterface();
                    Annotation annotation = mapperInterface.getDeclaredAnnotation(DoubleWrite.class);
                    if (annotation != null) {
                        mappers.put(mapperInterface.getName(), mapperInterface);
                    }
                }
            } catch (Exception e) {
                LogUtil.error("failed to initialize double read mappers : ", e);
            }
            return bean;
        }
    
        public Set<String> getMapperNames() {
            if (this.mappers == null) {
                return null;
            }
            return this.mappers.keySet();
        }
    }
    
    原有执行sql
    /**
     * Repository for the {@code zhiq_identity_favor} table.
     *
     * <p>{@code @ZhiqUserDBRepository} binds it to the original data source. {@code @DoubleWrite}
     * marks it for double-writing, so its write statements are replicated to the new database.
     */
    @ZhiqUserDBRepository
    @DoubleWrite
    public interface IdentityFavorRepository {

        /**
         * Inserts a favor relation; a row conflicting on (uiid, favored_uiid) is skipped.
         * The generated {@code id} column is written back to {@code identityFavor.id}.
         *
         * @param identityFavor the relation to insert
         * @return the affected-row count as mapped to boolean by MyBatis
         */
        @Options(keyColumn = "id", keyProperty = "id", useGeneratedKeys = true)
        @Insert("insert into zhiq_identity_favor "
                + "(uiid, uid, favored_uiid, favored_uid) values "
                + "(#{uiid}, #{uid}, #{favoredUiid}, #{favoredUid})"
                + "ON CONFLICT (uiid, favored_uiid) DO NOTHING")
        boolean insert(IdentityFavor identityFavor);
    }
    
  • 相关阅读:
    移动端rem屏幕设置
    封装ajax库,post请求
    获取浏览器url参数
    身份证验证
    jq封装插件
    页面分享功能,分享好友、朋友圈判断,用share_type做标记 这里用的是jweixin-1.3.2.js
    @RequestParam和@RequestBody区别
    400报错
    IDEA中用Maven构建SSM项目环境入门
    Eclipse搭建Spring开发环境
  • 原文地址:https://www.cnblogs.com/SimonHu1993/p/15884044.html
Copyright © 2020-2023  润新知