温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

怎么用MyBatis进行数据权限验证

发布时间:2021-11-10 11:42:59 来源:亿速云 阅读:143 作者:iii 栏目:MySQL数据库

本篇内容主要讲解“怎么用MyBatis进行数据权限验证”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“怎么用MyBatis进行数据权限验证”吧!

首先先创建表

CREATE TABLE `dataprivilegeconfig` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT,
  `project` varchar(32) DEFAULT NULL comment '项目名称',
  `module` varchar(32) NOT NULL comment '模块名称',
  `tableName` varchar(32) NOT NULL comment '表名',
  `statement` varchar(512) NOT NULL comment '配置的SQL片段',
  PRIMARY KEY (`id`)
) ;

怎么用MyBatis进行数据权限验证

使用一个自定义annotation来实现不同模块,拼接不同的SQL文本

package com.bj58.mis.datapriv.plugin.mybatis;
import java.lang.annotation.*;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.TYPE;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
@Documented
@Inherited
@Retention(RUNTIME)
@Target({ TYPE, METHOD })
public @interface DataPrivilege{
    String module() default "all";
}

上文的SQL解析代码

SQLDataPrivilege类

package com.bj58.mis.datapriv.core;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOperator;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.ast.statement.SQLTableSource;
import com.alibaba.druid.sql.ast.statement.SQLUnionQuery;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.util.JdbcUtils;
/**
 * Hello world!
 *
 */
public class SQLDataPrivilege {
   public static void main(String[] args) {
      
   }
   
   
   //单例.该对象用于给已经存在的SQL增加数据权限
   private static SQLDataPrivilege INSTANCE = new SQLDataPrivilege();
   public static SQLDataPrivilege getInstance() {
      return INSTANCE;
   }
   //从数据库中获取配置信息
   private SQLDataPrivilege() {
      try {
         Class.forName("com.mysql.jdbc.Driver");
         Connection con = DriverManager.getConnection("jdbc:mysql://127.0.0.1:3306/test", "root", "laohuali@58");
         String sql="select project,module,tableName,group_concat(statement separator ' and ') statement ";
         sql=sql+" from DataPrivilegeConfig where Project='测试' ";
         sql=sql+" group by project,module,tableName";
         PreparedStatement ps = con.prepareStatement(sql);
         ResultSet rs = ps.executeQuery();
         while (rs.next()) {
            Privilege p = new Privilege();
            p.setProject(rs.getString("project"));
            p.setModule(rs.getString("module"));
            p.setTableName(rs.getString("tableName"));
            p.setStatement(rs.getString("statement"));
            privList.add(p);
         }
         rs.close();
         ps.close();
         con.close();
      } catch (ClassNotFoundException e) {
         e.printStackTrace();
      } catch (SQLException e) {
         e.printStackTrace();
      }
   }
   //保存本项目的数据权限配置信息
   private List<Privilege> privList = new ArrayList<Privilege>();
   //在SQL上拼接数据权限
   public String addPrivilege(final String module,final String sql, Map<String, String> varMap) {
      // SQLParserUtils.createSQLStatementParser可以将sql装载到Parser里面
      SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, JdbcUtils.MYSQL);
      // parseStatementList的返回值SQLStatement本身就是druid里面的语法树对象
      List<SQLStatement> stmtList = parser.parseStatementList();
      SQLStatement stmt = stmtList.get(0);
      //如果不是查询,则返回
      if (!(stmt instanceof SQLSelectStatement)) {
         return sql;
      }
      SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
      // 拿到SQLSelect 通过在这里打断点看对象我们可以看出这是一个树的结构
      SQLSelect sqlselect = selectStmt.getSelect();
      SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
       
      parseSubQuery(module,query.getSelectList(), varMap);
      parseTable(module,query, varMap);
      System.out.println(sqlselect.toString());
      return sqlselect.toString();
   }
   //给子查询增加数据权限
   private void parseSubQuery(final String module,final List<SQLSelectItem> fieldList, final Map<String, String> varMap) {
      for (SQLSelectItem item : fieldList) {
         if (item.getExpr() instanceof SQLQueryExpr) {
            SQLQueryExpr expr = (SQLQueryExpr) item.getExpr();
            parseTable(module,expr.getSubQuery().getQueryBlock(), varMap);
         }
      }
   }
   //递归处理嵌套表
   private void parseTable(final String module,final SQLSelectQueryBlock query, final Map<String, String> varMap) {
      if (query == null) {
         return;
      }
      SQLTableSource tableSource = query.getFrom();
      if (tableSource instanceof SQLExprTableSource) {
         //如果是普通的表,则在where中增加数据权限
         SQLExprTableSource table = ((SQLExprTableSource) tableSource);
         String tableName = table.getName().getSimpleName();
         String aliasName = table.getAlias();
         SQLExpr sqlExpr = createSQLExpr(module,tableName, aliasName, varMap);
         createWhereSQLExpr(query, varMap, sqlExpr);
      } else if (tableSource instanceof SQLSubqueryTableSource) {
         //如果是嵌套表,则递归到内层
         SQLSubqueryTableSource table = ((SQLSubqueryTableSource) tableSource);
         parseTable(module,table.getSelect().getQueryBlock(), varMap);
      } else if (tableSource instanceof SQLJoinTableSource) {
         //如果是两个表关联.则在on条件中增加数据权限。并且在左右表中分别判断是否是union all的情况
         SQLJoinTableSource table = ((SQLJoinTableSource) tableSource);
         SQLTableSource left = table.getLeft();
         SQLTableSource right = table.getRight();
         SQLExpr onExpr = table.getCondition();
         if (left instanceof SQLSubqueryTableSource) {
            SQLSubqueryTableSource leftTable = ((SQLSubqueryTableSource) left);
            parseUnion(module,leftTable.getSelect().getQuery(), varMap);
            parseTable(module,leftTable.getSelect().getQueryBlock(), varMap);
         } else if (left instanceof SQLExprTableSource) {
            SQLExprTableSource joinTable = ((SQLExprTableSource) left);
            onExpr = createOnExpr(module,joinTable, onExpr, varMap);
         }
         if (right instanceof SQLSubqueryTableSource) {
            SQLSubqueryTableSource rightTable = ((SQLSubqueryTableSource) right);
            parseUnion(module,rightTable.getSelect().getQuery(), varMap);
            parseTable(module,rightTable.getSelect().getQueryBlock(), varMap);
         } else if (right instanceof SQLExprTableSource) {
            SQLExprTableSource joinTable = ((SQLExprTableSource) right);
            onExpr = createOnExpr(module,joinTable, onExpr, varMap);
         }
         table.setCondition(onExpr);
      }
   }
   //如果是union all的情况,则通过递归进入内层
   private boolean parseUnion(final String module,final SQLSelectQuery query, final Map<String, String> varMap) {
      if (query instanceof SQLUnionQuery) {
         SQLUnionQuery unionQuery = (SQLUnionQuery) query;
         if (unionQuery.getLeft() instanceof SQLUnionQuery) {
            parseUnion(module,unionQuery.getLeft(), varMap);
         } else if (unionQuery.getLeft() instanceof SQLSelectQueryBlock) {
            SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) unionQuery.getLeft();
            parseTable(module,queryBlock, varMap);
         }
         if (unionQuery.getRight() instanceof SQLUnionQuery) {
            parseUnion(module,unionQuery.getRight(), varMap);
         } else if (unionQuery.getRight() instanceof SQLSelectQueryBlock) {
            SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) unionQuery.getRight();
            parseTable(module,queryBlock, varMap);
         }
         return true;
      }
      return false;
   }
   //在连接的on条件中拼接权限
   private SQLExpr createOnExpr(final String module,SQLExprTableSource joinTable, SQLExpr onExpr, final Map<String, String> varMap) {
      String tableName = joinTable.getName().getSimpleName();
      String aliasName = joinTable.getAlias();
      SQLExpr sqlExpr = createSQLExpr(module,tableName, aliasName, varMap);
      if (sqlExpr != null) {
         SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(onExpr, SQLBinaryOperator.BooleanAnd, sqlExpr);
         onExpr = newWhereExpr;
      }
      return onExpr;
   }
   //根据配置获取拼接好的权限SQL
   private SQLExpr createSQLExpr(String module,String tableName, String aliasName, final Map<String, String> varMap) {
      StringBuffer constraintsBuffer = new StringBuffer("");
      for (Privilege p : privList) {
         if (tableName.equals(p.getTableName()) && module.equals(p.getModule())) {
            constraintsBuffer.append(p.toString(aliasName, varMap));
         }
      }
      if ("".equals(constraintsBuffer.toString())) {
         return null;
      }
      SQLExprParser constraintsParser = SQLParserUtils
            .createExprParser(constraintsBuffer.toString(), JdbcUtils.MYSQL);
      SQLExpr constraintsExpr = constraintsParser.expr();
      return constraintsExpr;
   }
   //拼接where中的权限信息
   private void createWhereSQLExpr(final SQLSelectQueryBlock query, final Map<String, String> varMap,
         SQLExpr sqlExpr) {
      if (sqlExpr == null) {
         return;
      }
      SQLExpr whereExpr = query.getWhere();
      // 修改where表达式
      if (whereExpr == null) {
         query.setWhere(sqlExpr);
      } else {
         SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, sqlExpr);
         query.setWhere(newWhereExpr);
      }
   }
}

Privilege类

package com.bj58.mis.datapriv.core;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
public class Privilege {
   private String project = null;
   private String module = null;
   private String tableName = null;
   private String statement = null;
   private Map<String,String> varDef = new HashMap<String,String>();
   private Pattern pattern = Pattern.compile("\\{.*?\\}");
   public String getProject() {
      return project;
   }
   public void setProject(String project) {
      if (this.project == null) {
         this.project = project;
      }
   }
   public String getModule() {
      return module;
   }
   public void setModule(String module) {
      if (this.module == null) {
         this.module = module;
      }
   }
   public String getTableName() {
      return tableName;
   }
   public void setTableName(String tableName) {
      if (this.tableName == null) {
         this.tableName = tableName;
      }
   }
   public String getStatement() {
      return statement;
   }
   public void setStatement(String statement) {
      if (this.statement == null) {
         this.statement = statement;
         Matcher m = pattern.matcher(this.statement);
         while (m.find()) {
            String var = m.group().replaceAll("(\\{|\\})", "").trim();
            this.varDef.put(var, "\\{" + var + "\\}");
         }
      }
   }
   public String toString(String aliasName, Map<String, String> varMap) {
      if (aliasName == null || "".equals(aliasName)) {
         aliasName = tableName;
      }
      String sqlString = this.statement.replaceAll("#tab#", aliasName);
 
      for (Entry<String,String> entry: varDef.entrySet()) {
         if (varMap.containsKey(entry.getKey())) {
            sqlString = sqlString.replaceAll(entry.getValue(), varMap.get(entry.getKey()));
         } else {
            throw new RuntimeException("缺少必要信息");
         }
      }
      return sqlString;
   }
}

增加一个MyBatis拦截器实现拼接SQL的功能

package com.bj58.mis.datapriv.plugin.mybatis;
import com.bj58.mis.datapriv.core.SQLDataPrivilege;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Plugin;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class})
})
public class MapperInterceptor implements Interceptor {
    private Properties properties;
    private Map<String,String> moduleMapping=new ConcurrentHashMap<String,String>();
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args[1];
        final BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        MappedStatement newMs = copyFromMappedStatement(mappedStatement, boundSql,parameter);
        System.out.println(newMs.getBoundSql(parameter).getSql());
        long start = System.currentTimeMillis();
        List returnValue = (List) invocation.proceed();
        long end = System.currentTimeMillis();
        return returnValue;
    }
    private String concatSQL(String mapperId, String sql, Object parameter) {
        String module=moduleMapping.get(mapperId);
        if(module==null){
            initModule(mapperId);
            module=moduleMapping.get(mapperId);
        }
        if("".equals(module)){
            return sql;
        }
        Map<String,String> newParameterMap=new HashMap<String,String>();
        for(Map.Entry<String,Object> entry :  ((Map<String,Object>)parameter).entrySet()){
            if(entry.getValue() instanceof ArrayList){
                StringBuilder sb=new StringBuilder(128);
                sb.append(" ( ");
                for(Object obj:(ArrayList)entry.getValue()) {
                    if(obj instanceof  String) {
                        sb.append("'");
                        sb.append(obj);
                        sb.append("',");
                    }else {
                        sb.append(obj);
                        sb.append(",");
                    }
                }
                sb.deleteCharAt(sb.length()-1);
                sb.append(" ) ");
                newParameterMap.put(entry.getKey(),sb.toString());
            }else{
                newParameterMap.put(entry.getKey(), String.valueOf( entry.getValue()));
            }
        }
        SQLDataPrivilege s =SQLDataPrivilege.getInstance();
        return s.addPrivilege(module,sql,newParameterMap);
    }
    private void initModule(String mapperId){
        String clazzName = mapperId.substring(0, mapperId.lastIndexOf("."));
        try {
            Class clazz = Class.forName(clazzName);
            DataPrivilege clazzDataPrivilege= (DataPrivilege) clazz.getAnnotation(DataPrivilege.class);
            for(Method  method:clazz.getMethods()){
                String key=clazzName+"."+method.getName();
                DataPrivilege methodDataPrivilege=method.getAnnotation(DataPrivilege.class);
                if(methodDataPrivilege!=null){
                    moduleMapping.put(key,methodDataPrivilege.module());
                }else if(clazzDataPrivilege!=null){
                    moduleMapping.put(key,clazzDataPrivilege.module());
                }else{
                    moduleMapping.put(key,"");
                }
            }
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
    }
    public static class BoundSqlSqlSource implements SqlSource {
        private BoundSql boundSql;
        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
    private MappedStatement copyFromMappedStatement(MappedStatement ms, BoundSql boundSql, Object parameter) {
         String sql = concatSQL(ms.getId(), boundSql.getSql(),parameter);
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), sql,boundSql.getParameterMappings(), boundSql.getParameterObject());
        MetaObject boundSqlObject = SystemMetaObject.forObject(boundSql);
        MetaObject newBoundSqlObject = SystemMetaObject.forObject(newBoundSql);
        newBoundSqlObject.setValue("metaParameters",boundSqlObject.getValue("metaParameters"));
        try {
            Field additionalParametersField=BoundSql.class.getDeclaredField("additionalParameters");
            additionalParametersField.setAccessible(true);
            Map<String, Object> boundSqlAdditionalParametersField= (Map<String, Object>) additionalParametersField.get(boundSql);
            Map<String, Object> newBoundSqlObjectSqlAdditionalParametersField= (Map<String, Object>) additionalParametersField.get(newBoundSql);
            for(Map.Entry<String,Object> entry:boundSqlAdditionalParametersField.entrySet()){
                newBoundSqlObjectSqlAdditionalParametersField.put(entry.getKey(),entry.getValue());
            }
            Field sqlSource=MappedStatement.class.getDeclaredField("sqlSource");
            sqlSource.setAccessible(true);
            sqlSource.set(ms,new BoundSqlSqlSource(newBoundSql));
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return ms;
    }
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
    @Override
    public void setProperties(Properties properties0) {
        this.properties = properties0;
    }
}

使用的时候,先配置数据库的SQL片段

然后配置MyBatis拦截器插件

@SpringBootApplication
@EnableSwagger2
public class StatisticsApplication {
   public static void main(String[] args) {
      SpringApplication.run(StatisticsApplication.class, args);
   }
   @Bean(name = "sqlSessionFactory")
   public SqlSessionFactory sqlSessionFactory( DataSource dataSource) throws Exception {
      SqlSessionFactoryBean factory = new SqlSessionFactoryBean();
      factory.setDataSource(dataSource);
      factory.setPlugins(new Interceptor[]{mapperInterceptor()});
      ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
      factory.setMapperLocations(resolver.getResources("classpath*:/mapper/*.mapper.xml"));
      return factory.getObject();
   }
   @Bean
   public MapperInterceptor mapperInterceptor() {
      MapperInterceptor mapperInterceptor=new MapperInterceptor();
      return mapperInterceptor;
   }
}

最后在Mapper接口上增加Annotation

怎么用MyBatis进行数据权限验证

到此,相信大家对“怎么用MyBatis进行数据权限验证”有了更深的了解,不妨来实际操作一番吧!这里是亿速云网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI