• Developing a MyBatis Multi-Tenant Isolation Plugin


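    With SaaS now mainstream, there are still plenty of systems that serve multiple tenants from a single database. Typically the tenants share one set of tables and their data is isolated at the SQL level: each table carries a tenant identifier column, and every query appends a filter such as TenantId=xxx. This is a cheap and simple way to provide multi-tenancy, but the filter has to be added to every SQL statement that is executed; if one is missed, data from different tenants gets mixed.

    This isolation should be applied automatically, so in this post I use MyBatis's plugin mechanism to build a multi-tenant SQL isolation plugin.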

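    I. Design requirements

    1. First, we need a way to identify which tables require tenant isolation and to determine the name of the tenant column.

    2. Next, we intercept the prepare method in MyBatis's execution flow, rewrite the SQL to include the tenant condition, and substitute the rewritten SQL for the original.

    3. Finally, we need a way to add the condition to every identified table at every nesting level, since any CRUD statement can contain subqueries, and to do so without losing the original where conditions.

    II. Design approach

    For requirement 1, we can define a condition-field decision maker that decides whether a given table needs the tenant filter, for example an interface named ITableFieldConditionDecision: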

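    /**
     * Table field condition decision maker.
     * Decides whether a filter on a given field must be added for a given table.
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/23 15:49
     **/
    public interface ITableFieldConditionDecision {
    
        /**
         * Whether the condition field allows a null value
         * @return true if a null value may be used in the condition
         */
        boolean isAllowNullValue();
        /**
         * Decides whether a filter on the given field must be added for the given table
         *
         * @param tableName   table name
         * @param fieldName   field name
         * @return true if the filter must be added
         */
        boolean adjudge(String tableName, String fieldName);
    
    }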
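
    As an illustration of how such a decision maker might be implemented, here is a minimal, self-contained sketch. The class names TableDecisionSketch and PatternOrSetDecision are hypothetical and are not part of the original plugin, and the interface is repeated only so that the sketch compiles on its own. It assumes a table qualifies when its name either matches a Java regular expression or appears in a comma-separated list.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.regex.Pattern;

public class TableDecisionSketch {

    /** Same contract as the interface above, repeated so the sketch is self-contained. */
    interface ITableFieldConditionDecision {
        boolean isAllowNullValue();
        boolean adjudge(String tableName, String fieldName);
    }

    /** Hypothetical decision maker: regex match OR membership in an explicit name set. */
    static final class PatternOrSetDecision implements ITableFieldConditionDecision {
        private final Pattern tablePattern;               // null when no regex is configured
        private final Set<String> tableSet = new HashSet<>();

        PatternOrSetDecision(String tableRegex, String commaSeparatedNames) {
            boolean hasRegex = tableRegex != null && !tableRegex.trim().isEmpty();
            this.tablePattern = hasRegex ? Pattern.compile(tableRegex.trim()) : null;
            if (commaSeparatedNames != null) {
                for (String name : commaSeparatedNames.split(",")) {
                    if (!name.trim().isEmpty()) {
                        tableSet.add(name.trim());
                    }
                }
            }
        }

        @Override
        public boolean isAllowNullValue() {
            return false;
        }

        @Override
        public boolean adjudge(String tableName, String fieldName) {
            if (tableName == null) return false;
            // matches() requires the whole table name to match the regex
            if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
            return tableSet.contains(tableName);
        }
    }

    public static void main(String[] args) {
        ITableFieldConditionDecision decision = new PatternOrSetDecision("uam_.*", "sys_log, sys_file");
        System.out.println(decision.adjudge("uam_user", "domain"));   // true  (regex)
        System.out.println(decision.adjudge("sys_log", "domain"));    // true  (name set)
        System.out.println(decision.adjudge("dict_city", "domain"));  // false
    }
}
```

    Note that the pattern is a regular expression, not a glob: uam_.* matches every name starting with uam_, whereas uam_* would only match "uam" followed by zero or more underscores.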

    Then supply the required parameters where the plugin is registered to initialize the decision maker:

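    <!-- Multi-tenant isolation plugin -->
                    <bean class="com.smartdata360.smartfx.dao.plugin.MultiTenantPlugin">
                        <property name="properties">
                            <value>
                                <!-- Current database dialect -->
                                dialect=postgresql
                                <!-- Name of the tenant isolation column -->
                                tenantIdField=domain
                                <!-- Java regular expression for the names of tables to isolate (setProperties reads the key tableRegex) -->
                                tableRegex=uam_.*
                                <!-- Names of tables to isolate, comma-separated (setProperties reads the key tableNames) -->
                                tableNames=uam_user,uam_role
                            </value>
                        </property>
                    </bean>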

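    For requirement 2, we write a MyBatis interceptor, MultiTenantPlugin. It extracts the SQL statement that is about to be prepared, rewrites it, and puts the result back, so that what MyBatis finally executes is the rewritten SQL.

    /**
     * Multi-tenant data isolation plugin
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 11:58
     **/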
    @Intercepts({
            @Signature(type = StatementHandler.class,
                    method = "prepare",
                    args = {Connection.class})})
    public class MultiTenantPlugin extends BasePlugin
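
    The interceptor needs to know which tenant the current call belongs to. The full listing later reads it from MultiTenantContent.getCurrentTenantId(), but that class is not shown in this post. The following is a minimal sketch of how such a holder might work, assuming one tenant id per thread stored in a ThreadLocal; the names TenantContextSketch, TenantHolder, and callAs are hypothetical.

```java
import java.util.function.Supplier;

public class TenantContextSketch {

    /** Hypothetical stand-in for MultiTenantContent: holds one tenant id per thread. */
    static final class TenantHolder {
        private static final ThreadLocal<String> CURRENT = new ThreadLocal<>();

        private TenantHolder() {
        }

        /** Returns the tenant id bound to the current thread, or null if none is bound. */
        static String getCurrentTenantId() {
            return CURRENT.get();
        }

        /** Runs the work with the given tenant bound, then restores the previous state. */
        static <T> T callAs(String tenantId, Supplier<T> work) {
            String previous = CURRENT.get();
            CURRENT.set(tenantId);
            try {
                return work.get();
            } finally {
                // Always clean up so pooled threads do not leak a tenant into the next request
                if (previous == null) {
                    CURRENT.remove();
                } else {
                    CURRENT.set(previous);
                }
            }
        }
    }

    public static void main(String[] args) {
        String seen = TenantHolder.callAs("yay", TenantHolder::getCurrentTenantId);
        System.out.println(seen);                               // yay
        System.out.println(TenantHolder.getCurrentTenantId());  // null
    }
}
```

    In a web application the tenant would typically be bound once per request, for example in a filter, so that every statement MyBatis prepares on that thread sees the same tenant id.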

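    For requirement 3, I use the SQL parser module of Alibaba's Druid to parse the SQL and append the condition. The process is roughly as follows:

    (1) Parse the SQL into an AST, in which nearly every part of the statement has a corresponding object.

    (2) Traverse the AST to reach each select, query, and SQLExpr node, extract the table names and aliases, and ask the decision maker whether the tenant condition is required. If it is, extend the existing condition with the tenant filter; otherwise leave the node unchanged.

    (3) Convert the modified AST back into a SQL statement.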
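
    Step (2) is essentially a recursive walk over the FROM clause. The sketch below illustrates that recursion on a toy model, not on Druid's real AST classes: Table, Join, SubQuery, and Query are hypothetical simplified stand-ins, conditions are plain strings, and every table is assumed to need the filter (the real helper asks the decision maker first).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FromTreeWalkSketch {

    /** Toy FROM-clause node; Druid's real table-source classes are much richer. */
    interface Source {
    }

    static final class Table implements Source {
        final String name;
        final String alias;   // may be null
        Table(String name, String alias) { this.name = name; this.alias = alias; }
    }

    static final class Join implements Source {
        final Source left;
        final Source right;
        Join(Source left, Source right) { this.left = left; this.right = right; }
    }

    static final class SubQuery implements Source {
        final Query query;
        SubQuery(Query query) { this.query = query; }
    }

    /** A query block: one FROM tree plus WHERE conditions that are AND-ed together. */
    static final class Query {
        final Source from;
        final List<String> where = new ArrayList<>();
        Query(Source from, String... conditions) {
            this.from = from;
            where.addAll(Arrays.asList(conditions));
        }
    }

    /** Appends "alias.field = 'tenantId'" to the query block that owns each table. */
    static void addTenantCondition(Query query, Source from, String field, String tenantId) {
        if (from instanceof Table) {
            Table table = (Table) from;
            String column = (table.alias == null) ? field : table.alias + "." + field;
            // Toy rendering only; real code should build an AST literal or bind a parameter
            query.where.add(column + " = '" + tenantId + "'");
        } else if (from instanceof Join) {
            // Both sides of a join belong to the same query block
            Join join = (Join) from;
            addTenantCondition(query, join.left, field, tenantId);
            addTenantCondition(query, join.right, field, tenantId);
        } else if (from instanceof SubQuery) {
            // A subquery is its own block: the condition goes into the inner WHERE
            Query inner = ((SubQuery) from).query;
            addTenantCondition(inner, inner.from, field, tenantId);
        }
    }

    public static void main(String[] args) {
        Query inner = new Query(new Join(new Table("user_group", "g"), new Table("user_role", "r")));
        Query outer = new Query(new Join(new Table("user", "u"), new SubQuery(inner)), "u.name = '123'");
        addTenantCondition(outer, outer.from, "domain", "yay");
        System.out.println(outer.where);  // [u.name = '123', u.domain = 'yay']
        System.out.println(inner.where);  // [g.domain = 'yay', r.domain = 'yay']
    }
}
```

    The original where condition of each block is kept and the tenant filter is added alongside it, which mirrors the third design requirement.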

    [Figure: execution result (image not preserved)]

    III. Reference code

    import com.alibaba.druid.sql.SQLUtils;
    import com.alibaba.druid.sql.ast.SQLStatement;
    import com.smartdata360.smartfx.dao.extension.MultiTenantContent;
    import com.smartdata360.smartfx.dao.sqlparser.ITableFieldConditionDecision;
    import com.smartdata360.smartfx.dao.sqlparser.SqlConditionHelper;
    import org.apache.commons.lang3.StringUtils;
    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.reflection.MetaObject;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.sql.Connection;
    import java.util.*;
    import java.util.regex.Pattern;
    
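    /**
     * Multi-tenant data isolation plugin
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 11:58
     **/
    // Note: since MyBatis 3.4.0, StatementHandler.prepare takes (Connection, Integer);
    // on those versions use args = {Connection.class, Integer.class}
    @Intercepts({
            @Signature(type = StatementHandler.class,
                    method = "prepare",
                    args = {Connection.class})})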
    public class MultiTenantPlugin extends BasePlugin {
    
        private final Logger logger = LoggerFactory.getLogger(MultiTenantPlugin.class);
    
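        /**
         * Dialect of the current database
         */
        private String dialect;
        /**
         * Name of the tenant column
         */
        private String tenantIdField;
    
        /**
         * Regular expression for the names of tables that carry the tenant column
         */
        private Pattern tablePattern;
    
        /**
         * Explicit set of names of tables that carry the tenant column
         */
        private Set<String> tableSet;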
    
        private SqlConditionHelper conditionHelper;
    
    
        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            String tenantId = MultiTenantContent.getCurrentTenantId();
            // Do nothing when no tenant id is set
            if (StringUtils.isBlank(tenantId)) {
                return invocation.proceed();
            }
            StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
            BoundSql boundSql = statementHandler.getBoundSql();
            String newSql = addTenantCondition(boundSql.getSql(), tenantId);
            MetaObject boundSqlMeta = getMetaObject(boundSql);
            // Write the rewritten SQL back into boundSql
            boundSqlMeta.setValue("sql", newSql);
    
            return invocation.proceed();
        }
    
        @Override
        public void setProperties(Properties properties) {
            dialect = properties.getProperty("dialect");
            if (StringUtils.isBlank(dialect))
                throw new IllegalArgumentException("MultiTenantPlugin requires the dialect property");
            tenantIdField = properties.getProperty("tenantIdField");
            if (StringUtils.isBlank(tenantIdField))
                throw new IllegalArgumentException("MultiTenantPlugin requires the tenantIdField property");
    
            String tableRegex = properties.getProperty("tableRegex");
            if (!StringUtils.isBlank(tableRegex))
                tablePattern = Pattern.compile(tableRegex.trim());
    
            String tableNames = properties.getProperty("tableNames");
            if (!StringUtils.isBlank(tableNames)) {
                // Names are comma-separated; whitespace around them is dropped as well
                tableSet = new HashSet<String>(Arrays.asList(StringUtils.split(tableNames, ", \t\r\n")));
            }
            // At least one of the two must be configured
            if (tablePattern == null && tableSet == null)
                throw new IllegalArgumentException("MultiTenantPlugin requires at least one of tableRegex and tableNames");
    
            // Decision maker for the tenant condition field
            ITableFieldConditionDecision conditionDecision = new ITableFieldConditionDecision() {
                @Override
                public boolean isAllowNullValue() {
                    return false;
                }
                @Override
                public boolean adjudge(String tableName, String fieldName) {
                    // Match against the compiled pattern, not the raw regex string
                    if (tablePattern != null && tablePattern.matcher(tableName).matches()) return true;
                    if (tableSet != null && tableSet.contains(tableName)) return true;
                    return false;
                }
            };
            conditionHelper = new SqlConditionHelper(conditionDecision);
        }
    
    
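        /**
         * Adds the tenant id filter to the where clause of a SQL statement
         *
         * @param sql      the SQL statement to add the filter to
         * @param tenantId the current tenant id
         * @return the SQL statement with the filter added
         */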
        private String addTenantCondition(String sql, String tenantId) {
            if (StringUtils.isBlank(sql) || StringUtils.isBlank(tenantIdField)) return sql;
            List<SQLStatement> statementList = SQLUtils.parseStatements(sql, dialect);
            if (statementList == null || statementList.size() == 0) return sql;
    
            SQLStatement sqlStatement = statementList.get(0);
            conditionHelper.addStatementCondition(sqlStatement, tenantIdField, tenantId);
            return SQLUtils.toSQLString(statementList, dialect);
        }
    
    }

    import com.alibaba.druid.sql.SQLUtils;
    import com.alibaba.druid.sql.ast.SQLExpr;
    import com.alibaba.druid.sql.ast.SQLStatement;
    import com.alibaba.druid.sql.ast.expr.*;
    import com.alibaba.druid.sql.ast.statement.*;
    import com.alibaba.druid.util.JdbcConstants;
    import org.apache.commons.lang3.NotImplementedException;
    import org.apache.commons.lang3.StringUtils;
    
    import java.util.List;
    
    /**
     * Helper class for manipulating the where conditions of SQL statements
     *
     * @author liushuishang@gmail.com
     * @date 2017/12/21 15:05
     **/
    public class SqlConditionHelper {
    
        private ITableFieldConditionDecision conditionDecision;
    
        public SqlConditionHelper(ITableFieldConditionDecision conditionDecision) {
            this.conditionDecision = conditionDecision;
        }
    
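        /**
         * Adds the given where condition to a SQL statement
         *
         * @param sqlStatement the parsed statement to modify
         * @param fieldName    name of the condition field
         * @param fieldValue   value the field must equal
         */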
        public void addStatementCondition(SQLStatement sqlStatement, String fieldName, String fieldValue) {
            if (sqlStatement instanceof SQLSelectStatement) {
                SQLSelectQueryBlock queryObject = (SQLSelectQueryBlock) ((SQLSelectStatement) sqlStatement).getSelect().getQuery();
                addSelectStatementCondition(queryObject, queryObject.getFrom(), fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLUpdateStatement) {
                SQLUpdateStatement updateStatement = (SQLUpdateStatement) sqlStatement;
                addUpdateStatementCondition(updateStatement, fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLDeleteStatement) {
                SQLDeleteStatement deleteStatement = (SQLDeleteStatement) sqlStatement;
                addDeleteStatementCondition(deleteStatement, fieldName, fieldValue);
            } else if (sqlStatement instanceof SQLInsertStatement) {
                SQLInsertStatement insertStatement = (SQLInsertStatement) sqlStatement;
                addInsertStatementCondition(insertStatement, fieldName, fieldValue);
            }
        }
    
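        /**
         * Adds the where condition to the select part of an insert ... select statement
         *
         * @param insertStatement the insert statement
         * @param fieldName       name of the condition field
         * @param fieldValue      value the field must equal
         */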
        private void addInsertStatementCondition(SQLInsertStatement insertStatement, String fieldName, String fieldValue) {
            if (insertStatement != null) {
                SQLInsertInto sqlInsertInto = insertStatement;
                SQLSelect sqlSelect = sqlInsertInto.getQuery();
                if (sqlSelect != null) {
                    SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) sqlSelect.getQuery();
                    addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
                }
            }
        }
    
    
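        /**
         * Adds the where condition to a delete statement
         *
         * @param deleteStatement the delete statement
         * @param fieldName       name of the condition field
         * @param fieldValue      value the field must equal
         */
        private void addDeleteStatementCondition(SQLDeleteStatement deleteStatement, String fieldName, String fieldValue) {
            SQLExpr where = deleteStatement.getWhere();
            // Add the condition to any subqueries inside the where clause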
            addSQLExprCondition(where, fieldName, fieldValue);
    
            SQLExpr newCondition = newEqualityCondition(deleteStatement.getTableName().getSimpleName(),
                    deleteStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
            deleteStatement.setWhere(newCondition);
    
        }
    
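        /**
         * Adds the given filter to the subqueries found inside a where expression
         *
         * @param where      the original where condition
         * @param fieldName  name of the condition field
         * @param fieldValue value the field must equal
         */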
        private void addSQLExprCondition(SQLExpr where, String fieldName, String fieldValue) {
            if (where instanceof SQLInSubQueryExpr) {
                SQLInSubQueryExpr inWhere = (SQLInSubQueryExpr) where;
                SQLSelect subSelectObject = inWhere.getSubQuery();
                SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
                addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
            } else if (where instanceof SQLBinaryOpExpr) {
                SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) where;
                SQLExpr left = opExpr.getLeft();
                SQLExpr right = opExpr.getRight();
                addSQLExprCondition(left, fieldName, fieldValue);
                addSQLExprCondition(right, fieldName, fieldValue);
            } else if (where instanceof SQLQueryExpr) {
                SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) (((SQLQueryExpr) where).getSubQuery()).getQuery();
                addSelectStatementCondition(selectQueryBlock, selectQueryBlock.getFrom(), fieldName, fieldValue);
            }
        }
    
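        /**
         * Adds the where condition to an update statement
         *
         * @param updateStatement the update statement
         * @param fieldName       name of the condition field
         * @param fieldValue      value the field must equal
         */
        private void addUpdateStatementCondition(SQLUpdateStatement updateStatement, String fieldName, String fieldValue) {
            SQLExpr where = updateStatement.getWhere();
            // Add the condition to any subqueries inside the where clause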
            addSQLExprCondition(where, fieldName, fieldValue);
            SQLExpr newCondition = newEqualityCondition(updateStatement.getTableName().getSimpleName(),
                    updateStatement.getTableSource().getAlias(), fieldName, fieldValue, where);
            updateStatement.setWhere(newCondition);
        }
    
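        /**
         * Adds a where condition to a query block for every table in its from clause
         *
         * @param queryObject the query block that owns the where clause
         * @param from        the table source to inspect (table, join, or subquery)
         * @param fieldName   name of the condition field
         * @param fieldValue  value the field must equal
         */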
        private void addSelectStatementCondition(SQLSelectQueryBlock queryObject, SQLTableSource from, String fieldName, String fieldValue) {
            if (StringUtils.isBlank(fieldName) || from == null || queryObject == null) return;
    
            SQLExpr originCondition = queryObject.getWhere();
            if (from instanceof SQLExprTableSource) {
                String tableName = ((SQLIdentifierExpr) ((SQLExprTableSource) from).getExpr()).getName();
                String alias = from.getAlias();
                SQLExpr newCondition = newEqualityCondition(tableName, alias, fieldName, fieldValue, originCondition);
                queryObject.setWhere(newCondition);
            } else if (from instanceof SQLJoinTableSource) {
                SQLJoinTableSource joinObject = (SQLJoinTableSource) from;
                SQLTableSource left = joinObject.getLeft();
                SQLTableSource right = joinObject.getRight();
    
                addSelectStatementCondition(queryObject, left, fieldName, fieldValue);
                addSelectStatementCondition(queryObject, right, fieldName, fieldValue);
    
            } else if (from instanceof SQLSubqueryTableSource) {
                SQLSelect subSelectObject = ((SQLSubqueryTableSource) from).getSelect();
                SQLSelectQueryBlock subQueryObject = (SQLSelectQueryBlock) subSelectObject.getQuery();
                addSelectStatementCondition(subQueryObject, subQueryObject.getFrom(), fieldName, fieldValue);
            } else {
                throw new NotImplementedException("Unhandled table source type");
            }
        }
    
        /**
         * Builds a new condition from the original one
         *
         * @param tableName       table name
         * @param tableAlias      table alias
         * @param fieldName       name of the condition field
         * @param fieldValue      value the field must equal
         * @param originCondition the original condition, may be null
         * @return the original condition AND-ed with the new equality, or the original if no filter applies
         */
        private SQLExpr newEqualityCondition(String tableName, String tableAlias, String fieldName, String fieldValue, SQLExpr originCondition) {
            // The decision maker says this table does not need the condition
            if (!conditionDecision.adjudge(tableName, fieldName)) return originCondition;
            // The value is null and null values are not allowed
            if (fieldValue == null && !conditionDecision.isAllowNullValue()) return originCondition;
    
            // Qualify the field with the table alias when there is one
            String qualifiedName = StringUtils.isBlank(tableAlias) ? fieldName : tableAlias + "." + fieldName;
            SQLExpr condition = new SQLBinaryOpExpr(new SQLIdentifierExpr(qualifiedName), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
            return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, condition, false, originCondition);
        }
    
    
        public static void main(String[] args) {
    //        String sql = "select * from user s  ";
    //        String sql = "select * from user s where s.name='333'";
    //        String sql = "select * from (select * from tab t where id = 2 and name = 'wenshao') s where s.name='333'";
    //        String sql="select u.*,g.name from user u join user_group g on u.groupId=g.groupId where u.name='123'";
    
    //        String sql = "update user set name=? where id =(select id from user s)";
    //        String sql = "delete from user where id = ( select id from user s )";
    
    //        String sql = "insert into user (id,name) select g.id,g.name from user_group g where id=1";
    
            String sql = "select u.*,g.name from user u join (select * from user_group g  join user_role r on g.role_code=r.code  ) g on u.groupId=g.groupId where u.name='123'";
            List<SQLStatement> statementList = SQLUtils.parseStatements(sql, JdbcConstants.POSTGRESQL);
            SQLStatement sqlStatement = statementList.get(0);
            // Define the decision maker
            SqlConditionHelper helper = new SqlConditionHelper(new ITableFieldConditionDecision() {
                @Override
                public boolean adjudge(String tableName, String fieldName) {
                    return true;
                }
    
                @Override
                public boolean isAllowNullValue() {
                    return false;
                }
            });
            // Add the tenant condition: domain is the field name, yay is the filter value
            helper.addStatementCondition(sqlStatement, "domain", "yay");
            System.out.println("Original SQL: " + sql);
            System.out.println("Rewritten SQL: " + SQLUtils.toSQLString(statementList, JdbcConstants.POSTGRESQL));
        }
    
    
    }

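    Because of time and environment constraints, this is only a basic version and may not be thoroughly tested. Corrections and suggestions are welcome.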

  • Original article: https://www.cnblogs.com/yuananyun/p/8093853.html