• ratelimit+redis+lua对接口限流


    背景:为防止接口QPS太大而造成系统运行卡顿的现象,在这儿以ratelimit+redis+lua对系统接口做了个限流。当时也考虑过使用其他的限流方法,比如微服务生态中使用的sentinel中间件,但是这个如果要实现持久化要进行特殊的配置,比如使用nacos进行持久化,需要修改sntinel源码,相比较而言单纯为了限流儿集成两个中间件会显得比较臃肿,所以到最后还是使用了retelimit+redis+lua这个方案,本身redis系统中就会使用,存储token、部门信息等一些读取次数多的数据。

    一、主要逻辑实现:

    1. 首先确定的是要采用切面的方式,后期如果相对某一个接口进行限流可以直接采用注解的方式。
    2. 其二在redis存储的key的名称要以方法名+ip的方式,这样可以更好的实现思路1指出的问题。
    3. 使用lua脚本直接传到redis中操作,这样可以减少网络开销以及复用,并且可以保证是原子操作。
    4. 第四点就是lua脚本的编写啦,redis 以有序队列进行存储,每一个key值都带有当前得分为当前时间戳的元素,每次新增的的时候都会将过时的元素进行清理,并进行判断是否达到限流条件。

    二、代码实现:

    代码结构:

    限流注解接口类

    package com.heyu.ratelimit.annotation;
    
    import org.aspectj.lang.annotation.Aspect;
    import org.springframework.core.annotation.AliasFor;
    import org.springframework.core.annotation.AnnotationUtils;
    
    import java.lang.annotation.*;
    import java.util.concurrent.TimeUnit;
    
    /**
     * <p>
     * 限流注解,添加了 {@link AliasFor} 必须通过 {@link AnnotationUtils} 获取,才会生效
     * </p>
     *
     * @author: 程鹏
     * @date: 2021-02-24 14:45
     * @Description: 限流切面
     */
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    public @interface RateLimiter {
        long DEFAULT_REQUEST = 5;
    
    
        /**
         * max 最大请求数
         */
        @AliasFor("max") long value() default DEFAULT_REQUEST;
    
        /**
         * max 最大请求数
         */
        @AliasFor("value") long max() default DEFAULT_REQUEST;
    
        /**
         * 限流key
         */
        String key() default "";
    
        /**
         * 超时时长,默认1分钟
         */
        long timeout() default 1;
    
        /**
         * 超时时间单位,默认 分钟
         */
        TimeUnit timeUnit() default TimeUnit.MINUTES;
    }
    

    切面操作类

    package com.heyu.ratelimit.aspect;
    
    import cn.hutool.core.util.StrUtil;
    import com.heyu.ratelimit.annotation.RateLimiter;
    import com.heyu.ratelimit.util.IpUtil;
    import lombok.RequiredArgsConstructor;
    import lombok.extern.slf4j.Slf4j;
    import org.aspectj.lang.ProceedingJoinPoint;
    import org.aspectj.lang.annotation.Around;
    import org.aspectj.lang.annotation.Aspect;
    import org.aspectj.lang.annotation.Pointcut;
    import org.aspectj.lang.reflect.MethodSignature;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.core.annotation.AnnotationUtils;
    import org.springframework.data.redis.core.StringRedisTemplate;
    import org.springframework.data.redis.core.script.RedisScript;
    import org.springframework.stereotype.Component;
    
    import java.lang.reflect.Method;
    import java.time.Instant;
    import java.util.Collections;
    import java.util.concurrent.TimeUnit;
    
    /**
     * @author: 程鹏
     * @date: 2021-02-24 14:13
     * @Description: 限流切面
     */
    @Slf4j
    @Aspect
    @Component
    @RequiredArgsConstructor(onConstructor_ = @Autowired)
    public class RateLimiterAspect {
        private final static String SEPARATOR = ":";
        private final static String REDIS_LIMIT_KEY_PREFIX = "limit:";
        private final StringRedisTemplate stringRedisTemplate;
        private final RedisScript<Long> limitRedisScript;
    
        @Pointcut("@annotation(com.heyu.ratelimit.annotation.RateLimiter)")
        public void rateLimit() {
    
        }
    
        @Around("rateLimit()")
        public Object pointcut(ProceedingJoinPoint point) throws Throwable {
            MethodSignature signature = (MethodSignature) point.getSignature();
    
            Method method = signature.getMethod();
            // 通过 AnnotationUtils.findAnnotation 获取 RateLimiter 注解
            RateLimiter rateLimiter = AnnotationUtils.findAnnotation(method, RateLimiter.class);
            if (rateLimiter != null) {
                String key = rateLimiter.key();
                // 默认用类名+方法名做限流的 key 前缀
                if (StrUtil.isBlank(key)) {
                    key = method.getDeclaringClass().getName() + StrUtil.DOT + method.getName();
                }
                // 最终限流的 key 为 前缀 + IP地址
                // TODO: 此时需要考虑局域网多用户访问的情况,因此 key 后续需要加上方法参数更加合理
                key = key + SEPARATOR + IpUtil.getIpAddr();
    
                long max = rateLimiter.max();
                long timeout = rateLimiter.timeout();
                TimeUnit timeUnit = rateLimiter.timeUnit();
                boolean limited = shouldLimited(key, max, timeout, timeUnit);
                if (limited) {
                    throw new RuntimeException("手速太快了,慢点儿吧~");
                }
            }
    
            return point.proceed();
        }
    
        private boolean shouldLimited(String key, long max, long timeout, TimeUnit timeUnit) {
            // 最终的 key 格式为:
            // limit:自定义key:IP
            // limit:类名.方法名:IP
            key = REDIS_LIMIT_KEY_PREFIX + key;
            // 统一使用单位毫秒
            long ttl = timeUnit.toMillis(timeout);
            // 当前时间毫秒数
            long now = Instant.now().toEpochMilli();
            long expired = now - ttl;
            // 注意这里必须转为 String,否则会报错 java.lang.Long cannot be cast to java.lang.String
            Long executeTimes = stringRedisTemplate.execute(limitRedisScript, Collections.singletonList(key), now + "", ttl + "", expired + "", max + "");
            if (executeTimes != null) {
                if (executeTimes == 0) {
                    log.error("【{}】在单位时间 {} 毫秒内已达到访问上限,当前接口上限 {}", key, ttl, max);
                    return true;
                } else {
                    log.info("【{}】在单位时间 {} 毫秒内访问 {} 次", key, ttl, executeTimes);
                    return false;
                }
            }
            return false;
        }
    }
    

    redis配置类

    package com.heyu.ratelimit.config;
    
    import org.springframework.context.annotation.Bean;
    import org.springframework.core.io.ClassPathResource;
    import org.springframework.data.redis.core.script.DefaultRedisScript;
    import org.springframework.data.redis.core.script.RedisScript;
    import org.springframework.scripting.support.ResourceScriptSource;
    import org.springframework.stereotype.Component;
    
    /**
     * @author: 程鹏
     * @date: 2021-02-26 14:35
     * @Description:
     */
    @Component
    public class RedisConfig {
        @Bean
        @SuppressWarnings("unchecked")
        public RedisScript<Long> limitRedisScript() {
            DefaultRedisScript redisScript = new DefaultRedisScript<>();
            redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("redis/limit.lua")));
            redisScript.setResultType(Long.class);
            return redisScript;
        }
    }
    

    Ip解析类

    package com.heyu.ratelimit.util;
    
    import cn.hutool.core.util.StrUtil;
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.web.context.request.RequestContextHolder;
    import org.springframework.web.context.request.ServletRequestAttributes;
    
    import javax.servlet.http.HttpServletRequest;
    
    /**
     * @author: 程鹏
     * @date: 2021-02-26 14:28
     * @Description:
     */
    @Slf4j
    public class IpUtil {
        private final static String UNKNOWN = "unknown";
        private final static int MAX_LENGTH = 15;
    
        /**
         * 获取IP地址
         * 使用Nginx等反向代理软件, 则不能通过request.getRemoteAddr()获取IP地址
         * 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址,X-Forwarded-For中第一个非unknown的有效IP字符串,则为真实IP地址
         */
        public static String getIpAddr() {
            HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
            String ip = null;
            try {
                ip = request.getHeader("x-forwarded-for");
                if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                    ip = request.getHeader("Proxy-Client-IP");
                }
                if (StrUtil.isEmpty(ip) || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) {
                    ip = request.getHeader("WL-Proxy-Client-IP");
                }
                if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                    ip = request.getHeader("HTTP_CLIENT_IP");
                }
                if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                    ip = request.getHeader("HTTP_X_FORWARDED_FOR");
                }
                if (StrUtil.isEmpty(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
                    ip = request.getRemoteAddr();
                }
            } catch (Exception e) {
                log.error("IPUtils ERROR ", e);
            }
            // 使用代理,则获取第一个IP地址
            if (!StrUtil.isEmpty(ip) && ip.length() > MAX_LENGTH) {
                if (ip.indexOf(StrUtil.COMMA) > 0) {
                    ip = ip.substring(0, ip.indexOf(StrUtil.COMMA));
                }
            }
            log.error("访客ip:"+ip);
            return ip;
        }
    }
    

    lua脚本

    -- 下标从 1 开始
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local ttl = tonumber(ARGV[2])
    local expired = tonumber(ARGV[3])
    -- 最大访问量
    local max = tonumber(ARGV[4])
    
    -- 清除过期的数据
    -- 移除指定分数区间内的所有元素,expired 即已经过期的 score
    -- 根据当前时间毫秒数 - 超时毫秒数,得到过期时间 expired
    redis.call('zremrangebyscore', key, 0, expired)
    
    -- 获取 zset 中的当前元素个数
    local current = tonumber(redis.call('zcard', key))
    local next = current + 1
    
    if next > max then
      -- 达到限流大小 返回 0
      return 0;
    else
      -- 往 zset 中添加一个值、得分均为当前时间戳的元素,[value,score]
      redis.call("zadd", key, now, now)
      -- 每次访问均重新设置 zset 的过期时间,单位毫秒
      redis.call("pexpire", key, ttl)
      return next
    end
    

    controller层测试

  • 相关阅读:
    python 3 dict函数 神奇的参数规则
    python 3 黑色魔法元类初探
    私有变量为何传给了子类?
    [转]django-registration quickstart
    DoesNotExist at /account/
    DoesNotExist at /admin/
    setting.py
    Python excel 奇怪的通信规则
    Python 一个奇特的引用设定
    Chrome 内存和CPU消耗量双料冠军
  • 原文地址:https://www.cnblogs.com/pengcool/p/15611014.html
Copyright © 2020-2023  润新知