• 使用AOP统一验签和校参


    一、需求背景

      对外提供服务的接口需要统一做验签和参数合法性校验。每个接口的加签算法相同,不同的是参数的不为空的要求不同。

      要求,在controller层外做校验,校验不通过直接返回,不进入controller层。

    二、需求实现前代码

     在这之前已经对每个请求做了AOP拦截,对每个请求植入了线程号。以及统计每个接口的执行耗时,打印每个接口的返回结果,捕获接口的未检查异常并打印和封装返回结果。

     如:

    /**
     * 为每一个的HTTP请求添加线程号
     *
     * @author yangyongjie
     * @date 2019/9/2
     * @desc
     */
    @Order(1)
    @Aspect
    @Component
    public class LogAspect {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(LogAspect.class);
    
        @Pointcut(value = "@annotation(org.springframework.web.bind.annotation.RequestMapping)")
        private void webPointcut() {
            // doNothing
        }
    
        /**
         * 为所有的HTTP请求添加线程号
         *
         * @param joinPoint
         * @throws Throwable
         */
        @Around(value = "webPointcut()")
        public Object around(ProceedingJoinPoint joinPoint) {
            // 执行开始的时间
            Long beginTime = System.currentTimeMillis();
            // 方法执行前加上线程号,并将线程号放到线程本地变量中
            MDCUtil.init();
            // 获取切点的方法名
            String methodName = joinPoint.getSignature().getName();
            // 执行拦截的方法
            Object result = null;
            try {
                result = joinPoint.proceed();
            } catch (Throwable throwable) {
                LOGGER.error("{}方法执行异常:" + throwable.getMessage(), methodName, throwable);
                LogUtil.sendErrorLogMail("系统异常", throwable);
                result = new CommonResult(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
            } finally {
                LOGGER.info("{}方法返回结果:{}", methodName, JacksonJsonUtil.toString(result));
                Long endTime = System.currentTimeMillis();
                LOGGER.info("{}方法耗时{}毫秒", methodName, endTime - beginTime);
                // 方法执行结束移除线程号,并移除线程本地变量,防止内存泄漏
                MDCUtil.remove();
            }
            return result;
        }
    }

    @Order(1) :为多个AOP切面排序,数字越小,先执行谁。

    MDCUtil:

    /**
     * 日志相关工具类
     *
     * @author yangyongjie
     * @date 2019/9/17
     * @desc
     */
    public class MDCUtil {
        private MDCUtil() {
        }
    
        private static final String STR_THREAD_ID = "threadId";
    
        /**
         * 初始化日志参数并保存在线程副本中
         */
        public static void init() {
            String uuid = UUID.randomUUID().toString().replaceAll("-", "");
            MDC.put(STR_THREAD_ID, uuid);
            ThreadContext.currentThreadContext().setThreadId(uuid);
        }
    
        /**
         * 初始化日志参数
         */
        public static void initWithOutContext() {
            String uuid = UUID.randomUUID().toString().replaceAll("-", "");
            MDC.put(STR_THREAD_ID, uuid);
        }
    
        /**
         * 移除线程号和线程副本
         */
        public static void remove() {
            MDC.remove(STR_THREAD_ID);
            ThreadContext.remove();
        }
    
        /**
         * 移除线程号
         */
        public static void removeWithOutContext() {
            MDC.remove(STR_THREAD_ID);
        }
    }

    线程上下文ThreadContext:

    /**
     * 线程上下文,一个线程内所需的上下文变量参数,使用ThreadLocal保存副本
     *
     * @author yangyongjie
     * @date 2019/9/12
     * @desc
     */
    public class ThreadContext {
        /**
         * 每个线程的私有变量,每个线程都有独立的变量副本,所以使用private static final修饰,因为都需要复制进入本地线程
         */
        private static final ThreadLocal<ThreadContext> THREAD_LOCAL = new ThreadLocal<ThreadContext>() {
            @Override
            protected ThreadContext initialValue() {
                return new ThreadContext();
            }
        };
    
        public static ThreadContext currentThreadContext() {
            /*ThreadContext threadContext = THREAD_LOCAL.get();
            if (threadContext == null) {
                THREAD_LOCAL.set(new ThreadContext());
                threadContext = THREAD_LOCAL.get();
            }
            return threadContext;*/
            return THREAD_LOCAL.get();
        }
    
        public static void remove() {
            THREAD_LOCAL.remove();
        }
    
        /**
         * 线程号
         */
        private String threadId;
    
        /**
         * 请求参数
         */
        private Object requestParam;
    
        public String getThreadId() {
            return threadId;
        }
    
        public void setThreadId(String threadId) {
            this.threadId = threadId;
        }
    
        public Object getRequestParam() {
            return requestParam;
        }
    
        public void setRequestParam(Object requestParam) {
            this.requestParam = requestParam;
        }
    
        @Override
        public String toString() {
            return JacksonJsonUtil.toString(this);
        }
    }

     公共返回结果类:

    /**
     * 用于返回给调用方执行结果的公共结果类
     * 自定义返回结果继承此类即可
     *
     * @author yangyongjie
     * @date 2019/9/25
     * @desc
     */
    public class CommonResult {
        /**
         * 返回码,0000表示成功,其余都是失败,9998表示入参不符合要求,9999表示系统异常
         */
        private String code = "0000";
        /**
         * 返回信息
         */
        private String msg = "success";
    
        public CommonResult() {
        }
    
    
        public CommonResult(String code, String msg) {
            this.code = code;
            this.msg = msg;
        }
    
        /**
         * 失败情况
         */
        public void fail(String code, String msg) {
            this.code = code;
            this.msg = msg;
        }
    
        /**
         * 判断是否成功
         */
        @JsonIgnore
        public boolean isSuccess() {
            return StringUtils.equals("0000", code);
        }
    
        public String getCode() {
            return code;
        }
    
        public void setCode(String code) {
            this.code = code;
        }
    
        public String getMsg() {
            return msg;
        }
    
        public void setMsg(String msg) {
            this.msg = msg;
        }
    }

    三、需求具体实现

      1、现在需要再增加一个切面,对需要做验签和参数校验的接口拦截并校验

        1)自定义注解,作用在controller层的方法上,标识此接口需要验签和验参,其有两个属性,一个是方法返回类型,一个是接收参数的实体类。

        方法返回类型用来切面校验不通过封装返回数据,接收参数的实体类对需要验不为空的方法标志了注解,需在切面中进行校验。

    /**
     * 对外请求参数校验注解
     *
     * @author yangyongjie
     * @date 2019/11/5
     * @desc
     */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface Check {
    
        /**
         * 方法的返回值类型,继承了CommonResult
         */
        Class<? extends CommonResult> value();
    
        /**
         * 校验的目标实体类
         */
        Class<?> paramBean();
    
    }

    如接收参数的实体类定义:

    public class AuthTokenRequest extends BaseRequest {
    
        /**
         * 值为authorization_code
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String grant_type;
    }
    
    public class BaseRequest {
        /**
         * 签名
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String sign;
    
        /**
         * 分配的接入id
         */
        @ParamVerify(nullable = CheckEnum.NOTNULL)
        private String partnerId;
    }

    属性校验注解:

    /**
     * 字段校验注解,目前只进行非空校验,可扩展
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.FIELD)
    public @interface ParamVerify {
        /**
         * 是否允许为空
         */
        CheckEnum nullable() default CheckEnum.NULL;
    }

    验签验参切面:

    /**
     * 对外同步接口参数校验切面
     *
     * @author yangyongjie
     * @date 2019/11/5
     * @desc
     */
    @Order(2)
    @Aspect
    @Component
    public class CheckAspect {
    
        private static final Logger LOGGER = LoggerFactory.getLogger(CheckAspect.class);
    
        /**
         * 验签公钥
         */
        @Value("${fx.publicKey}")
        private String fxPublicKey;
    
        @Autowired
        private OutgoingPartnerInfoDao outgoingPartnerInfoDao;
    
        @Pointcut("@annotation(com.xiaomi.mitv.outgoing.common.annotation.Check)")
        private void webPointcut() {
            // donothing
        }
    
        @Around(value = "webPointcut()")
        public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
            // 获取被增强的方法的相关信息
            MethodSignature ms = (MethodSignature) joinPoint.getSignature();
            // 获取被增强的方法
            Method pointcutMethod = ms.getMethod();
            String methodName = pointcutMethod.getName();
            // 对于对外接口,统一进行参数校验
            CommonResult commonResult = null;
            // 判断方法上有没有@Check注解
            if (pointcutMethod.isAnnotationPresent(Check.class)) {
                // 获取到拦截方法的HttpServletRequest
                // 获取当前方法执行的上下文的request
                HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
                // 获取body请求参数
                String bodyString = HttpUtil.getRequestBody(request);
    //            Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
                Map<String, Object> originMap = HttpUtil.fromJsonToObject(bodyString, Map.class);
                // 将请求参数放到线程本地拷贝中
                ThreadContext.currentThreadContext().setRequestParam(originMap);
    
                // 得到方法上的Check注解
                Check check = pointcutMethod.getAnnotation(Check.class);
                // 获取切点方法的返回类型
                Class<?> returnType = check.value();
                // 创建对象
                commonResult = (CommonResult) returnType.newInstance();
                // 获取参数签名
                String sign = request.getParameter("sign");
                LOGGER.info("{}-sign={}", methodName, sign);
    
                // 参数校验
                Class<?> beanType = check.paramBean();
                originMap.put("sign", sign);
                if(!HttpUtil.paramCheck(originMap, beanType)){
                    commonResult.fail(ResponseEnum.ERROR_PARAM.getCode(), ResponseEnum.ERROR_PARAM.getMsg());
                    return commonResult;
                }
    
                String partnerId = String.valueOf(originMap.get("partnerId"));
                if (!StringUtil.areNotEmpty(partnerId, sign)) {
                    commonResult.fail(ResponseEnum.ERROR_PARAM_NULL.getCode(), ResponseEnum.ERROR_PARAM_NULL.getMsg());
                    return commonResult;
                }
                // 校验partnerId的有效性
                if (!checkPartnerId(partnerId)) {
                    commonResult.fail(ResponseEnum.ERROR_APP_INVALID.getCode(), ResponseEnum.ERROR_APP_INVALID.getMsg());
                    return commonResult;
                }
                // 组装加签串
                String paramBody = HttpUtil.getAssembleParam(originMap);
                // 验签
                boolean pass;
                try {
                    pass = RSAUtil.rsa256CheckContent(paramBody, sign, fxPublicKey);
                } catch (BssException e) {
                    LogUtil.LogAndMail("验签异常", e);
                    commonResult.fail(ResponseEnum.ERROR_SYSTEM.getCode(), ResponseEnum.ERROR_SYSTEM.getMsg());
                    return commonResult;
                }
                if (!pass) {
                    commonResult.fail(ResponseEnum.ERROR_CHECK_SIGN_FAIL.getCode(), ResponseEnum.ERROR_CHECK_SIGN_FAIL.getMsg());
                    return commonResult;
                }
            }
            // 执行增强方法
            Object result = joinPoint.proceed();
            return result;
        }
    
        /**
         * 校验partnerId的有效性,先查缓存,缓存中没有的话再查询数据库,使用互斥锁
         *
         * @param partnerId
         * @return
         */
        private boolean checkPartnerId(String partnerId) {
            // 先查询缓存,值为1表示存在且有效,值为0表示存在但无效,值为null表示不存在
            String val = RedisUtil.get(CommonConstants.PARTNER_ID + partnerId);
            if (StringUtils.isEmpty(val)) {
                // 缓存中不存在,先拿到互斥锁,再查询数据库,并放进缓存中
                // 获取互斥锁
                String mutexKey = CommonConstants.NX_PARTNER_ID + partnerId;
                boolean flag = RedisUtil.setex(mutexKey, CommonConstants.STR_ONE, 60);
                // 拿到锁
                if (flag) {
                    // 查询数据库
                    OutgoingPartnerInfoDto partnerInfoDto = outgoingPartnerInfoDao.getByPartnerId(partnerId);
                    if (partnerInfoDto != null && StringUtils.equals(CommonConstants.STR_ONE, partnerInfoDto.getStatus())) {
                        // partnerId 存在且有效
                        RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ONE);
                        // 删除锁
                        RedisUtil.del(mutexKey);
                        return true;
                    } else {
                        // partnerId 不存在或无效
                        RedisUtil.set(CommonConstants.PARTNER_ID + partnerId, CommonConstants.STR_ZERO);
                        return false;
                    }
                } else {
                    //休息50毫秒后重试
                    try {
                        Thread.sleep(50);
                    } catch (InterruptedException e) {
                        LOGGER.error("获取partnerId互斥锁异常" + e.getMessage(), e);
                    }
                    return checkPartnerId(partnerId);
                }
                // val 不为空
            } else {
                return StringUtils.equals(CommonConstants.STR_ONE, val);
            }
        }
    
    }

    HttpUtil工具类:

    public class HttpUtil {
    
        private HttpUtil() {
        }
    
        private static final Logger LOGGER = LoggerFactory.getLogger(HttpUtil.class);
    
        /**
         * 获取request中的body信息 JSON格式
         *
         * @param request
         * @return
         */
        public static String getRequestBody(HttpServletRequest request) {
            BufferedReader br = null;
            StringBuilder bodyDataBuilder = new StringBuilder();
            try {
                br = request.getReader();
                String str;
                while ((str = br.readLine()) != null) {
                    bodyDataBuilder.append(str);
                }
                br.close();
            } catch (IOException e) {
                LOGGER.error(e.getMessage(), e);
            } finally {
                if (null != br) {
                    try {
                        br.close();
                    } catch (IOException e) {
                        LOGGER.error(e.getMessage(), e);
                    }
                }
            }
            String bodyString = bodyDataBuilder.toString();
            LOGGER.info("bodyString={}", bodyString);
            return bodyString;
        }
    
        /**
         * 获取request中的body信息,并组装好按“参数=参数值”的格式
         *
         * @param request
         * @return
         */
        public static String getAssembleRequestBody(HttpServletRequest request) {
            String bodyString = getRequestBody(request);
            Map<String, Object> originMap = JacksonJsonUtil.toObject(bodyString, Map.class);
            Map<String, Object> sortedParams = getSortedMap(originMap);
            String assembleBody = getSignContent(sortedParams);
            return assembleBody;
        }
    
        /**
         * 根据requestBody中的原始map获取解析后并组装的参数字符串,根据&符拼接
         *
         * @param originMap
         * @return
         */
        public static String getAssembleParam(Map<String, Object> originMap) {
            return getSignContent(getSortedMap(originMap));
        }
    
    
        /**
         * 将body转成按key首字母排好序
         *
         * @return
         */
        public static Map<String, Object> getSortedMap(Map<String, Object> originMap) {
            Map<String, Object> sortedParams = new TreeMap<String, Object>();
            if (originMap != null && originMap.size() > 0) {
                sortedParams.putAll(originMap);
            }
            return sortedParams;
        }
    
        /**
         * 将排序好的map的key和value拼接成字符串
         *
         * @param sortedParams
         * @return
         */
        public static String getSignContent(Map<String, Object> sortedParams) {
            StringBuffer content = new StringBuffer();
            List<String> keys = new ArrayList<String>(sortedParams.keySet());
            Collections.sort(keys);
            int index = 0;
            for (int i = 0; i < keys.size(); i++) {
                String key = keys.get(i);
                Object value = sortedParams.get(key);
                if (StringUtils.isNotEmpty(key) && value != null) {
                    content.append((index == 0 ? "" : "&") + key + "=" + value);
                    index++;
                }
            }
            return content.toString();
        }
    
        /**
         * Json转实体对象
         *
         * @param jsonStr
         * @param clazz 目标生成实体对象
         * @return
         */
        public static <T> T fromJsonToObject(String jsonStr, Class clazz) {
            T results = null;
            try {
                results = (T) JacksonJsonUtil.toObject(jsonStr, clazz);
            } catch (Exception e) {
            }
            return results;
        }
    
        /**
         * 对请求参数进行校验,目前只进行非空校验
         *
         * @param srcData body数据
         * @param tarClass 校验规则
         * @return 校验成功返回true
         */
        public static <T> boolean paramCheck(Map<String, Object> srcData, Class<T> tarClass){
            try {
                Field[] fields = tarClass.getDeclaredFields();
                for(Field field : fields){
                    ParamVerify verify = field.getAnnotation(ParamVerify.class);
                    if(verify != null){
                        //非空校验,后续若需增加校验类型,应抽离
                        if(verify.nullable() == CheckEnum.NOTNULL){
                            String fn = field.getName();
                            Object val = srcData.get(fn);
                            if(val == null || "".equals(val.toString())){
                                return false;
                            }
                        }
                    }
                }
            }catch (Exception ex){
                LOGGER.info("Param verify error");
                return false;
            }
            return true;
        }
    
    }

    日志工具类:

     /**
         * 打印日志并发送错误邮件
         *
         * @param msg
         * @param t
         */
        public static void LogAndMail(String msg, Throwable t) {
            // 获取调用此工具类的该方法 的调用方信息
            // 查询当前线程的堆栈信息
            StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
            // 按照规则,此方法的上一级调用类为
            StackTraceElement ste = stackTrace[2];
            String className = ste.getClassName();
            String methodName = ste.getMethodName();
            LOGGER.error("{}#{},{}," + t.getMessage(), className, methodName, msg, t);
            // 异步发送邮件
            String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg;
            executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
        }
    
    
        /**
         * 只发送错误邮件不打印日志
         *
         * @param msg
         */
        public static void sendErrorLogMail(String msg, Throwable t) {
            // 异步发送邮件
            String ms = "[" + ThreadContext.currentThreadContext().getThreadId() + "]" + msg + assembleStackTrace(t);
            executor.execute(() -> SendMailUtil.sendErrorMail(ms, t, 3));
        }
    
        /**
         * 组装异常堆栈
         *
         * @param t
         * @return
         */
        public static String assembleStackTrace(Throwable t) {
            StringWriter sw = new StringWriter();
            PrintWriter ps = new PrintWriter(sw);
            t.printStackTrace(ps);
            return sw.toString();
        }

    有关两个切面的执行顺序问题,请参考:https://www.cnblogs.com/yangyongjie/p/11800862.html

    END

  • 相关阅读:
    Struts2SpringHibernate整合示例,一个HelloWorld版的在线书店(项目源码+详尽注释+单元测试)
    Java实现蓝桥杯勇者斗恶龙
    Java实现 LeetCode 226 翻转二叉树
    Java实现 LeetCode 226 翻转二叉树
    Java实现 LeetCode 226 翻转二叉树
    Java实现 LeetCode 225 用队列实现栈
    Java实现 LeetCode 225 用队列实现栈
    Java实现 LeetCode 225 用队列实现栈
    Java实现 LeetCode 224 基本计算器
    Java实现 LeetCode 224 基本计算器
  • 原文地址:https://www.cnblogs.com/yangyongjie/p/12535938.html
Copyright © 2020-2023  润新知