• 基于 Spring Data Redis 实现一个分布式锁工具类


    基础依赖

     <!-- "2.x.x.RELEASE" is a placeholder: replace it with a concrete 2.x release,
          or drop <version> if a parent POM/BOM already manages the Spring Boot version. -->
     <dependency>
           <groupId>org.springframework.boot</groupId>
           <artifactId>spring-boot-starter-data-redis</artifactId>
           <version>2.x.x.RELEASE</version>
     </dependency>

    核心类 DRedisLock

    package com.idanchuang.component.redis.util.task;
    
    import com.idanchuang.component.core.util.SpringUtil;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.data.redis.core.StringRedisTemplate;
    import org.springframework.data.redis.core.script.DefaultRedisScript;
    
    import java.net.InetAddress;
    import java.net.UnknownHostException;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.UUID;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicLong;
    import java.util.concurrent.locks.Condition;
    import java.util.concurrent.locks.Lock;
    
    /**
     * Redis-based distributed lock implementing {@link Lock}, re-entrant per host + thread.
     * <p>
     * Each instance stores the value {@code host:threadId:serial} under the key
     * {@code LOCK_PREFIX + lockName} with a lease of {@code timeout} milliseconds. If the key is
     * already held by a value that starts with this instance's {@code host:threadId:}, the instance
     * re-enters instead of waiting. Re-entry extends the key's expiry. {@link #unlock()} on the
     * re-entrant instance leaves the key in place, so the outer holder still owns it.
     * <p>
     * The host/thread identity is captured in the constructor. An instance is therefore expected
     * to be created, locked and unlocked on the same thread. Instances keep mutable state and are
     * not meant to be shared between threads.
     *
     * @author yjy
     * @date 2019/11/27 11:07
     **/
    public class DRedisLock implements Lock {

        private static final Logger log = LoggerFactory.getLogger(DRedisLock.class);

        /** Default lock lease time in milliseconds (the key expires after this). */
        public final static long DEFAULT_TIMEOUT = 30000L;
        /** Prefix prepended to every lock name to form the Redis key. */
        public final static String LOCK_PREFIX = "D_LOCK_";
        /** Default time in milliseconds that {@link #lock()} waits for the lock. */
        public final static long DEFAULT_TRY_LOCK_TIMEOUT = 10000L;
        /** Default pause in milliseconds between acquisition attempts while waiting (10 ms). */
        public final static long DEFAULT_LOOP_INTERVAL = 10L;

        /** Sequence used to make every lock value unique within this JVM. */
        private static final AtomicLong SERIAL_NUM = new AtomicLong(0);
        /** Largest sequence value; the sequence wraps around after it. */
        private static final long MAX_SERIAL = 99999999L;
        /** Local host address, or a random UUID when the address cannot be resolved. */
        private static final String CURRENT_HOST = resolveCurrentHost();

        /**
         * SET-if-absent with a millisecond lease: KEYS[1] = key, ARGV[1] = value, ARGV[2] = lease ms.
         * Returns true only when this call created the key.
         */
        private static final DefaultRedisScript<Boolean> SET_IF_ABSENT_SCRIPT = newBooleanScript(
                "if redis.call('EXISTS', KEYS[1]) == 1 then return false end "
                        + "redis.call('SET', KEYS[1], ARGV[1]) "
                        + "redis.call('PEXPIRE', KEYS[1], ARGV[2]) "
                        + "return true");

        /**
         * Re-entry check: KEYS[1] = key, ARGV[1] = owner prefix ("host:threadId:"), ARGV[2] = lease ms.
         * The prefix is compared literally (string.sub, not a Lua pattern), so "." in an IP
         * address is not a wildcard. The trailing ":" keeps thread 1 from matching thread 12.
         * On a match, the lease is added to the remaining expiry.
         */
        private static final DefaultRedisScript<Boolean> REENTRANT_SCRIPT = newBooleanScript(
                "local prefix, ttl = ARGV[1], tonumber(ARGV[2]) "
                        + "local cur = redis.call('GET', KEYS[1]) "
                        + "if cur and string.sub(cur, 1, string.len(prefix)) == prefix then "
                        + "local curTtl = redis.call('PTTL', KEYS[1]) "
                        + "if curTtl < 0 then curTtl = 0 end "
                        + "redis.call('PEXPIRE', KEYS[1], curTtl + ttl) "
                        + "return true end "
                        + "return false");

        /** Deletes KEYS[1] only if it still holds ARGV[1] (this instance's lock value). */
        private static final DefaultRedisScript<Boolean> UNLOCK_SCRIPT = newBooleanScript(
                "if redis.call('get', KEYS[1]) == ARGV[1] then redis.call('del', KEYS[1]) return true else return false end");

        /** StringRedisTemplate */
        private final StringRedisTemplate redisTemplate;
        /** Redis key of the lock. */
        private final String lockKey;
        /** Lock lease time in milliseconds. */
        private final long timeout;
        /** Pause between acquisition attempts while waiting, in milliseconds. */
        private final long loopInterval;
        /** host + ":" + id of the thread that created this instance. */
        private final String hostThreadId;
        /** Value written to Redis when this instance acquires the lock. */
        private final String lockValue;
        /** Whether this instance acquired the lock by re-entering an existing hold. */
        private boolean reentrant = false;
        /** Whether this instance currently believes it holds the lock. */
        private boolean locked = false;

        public DRedisLock(String lockName) {
            this(lockName, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL);
        }

        public DRedisLock(String lockName, long timeout) {
            this(lockName, timeout, DEFAULT_LOOP_INTERVAL);
        }

        /**
         * @param lockName     lock name (the Redis key is {@code LOCK_PREFIX + lockName})
         * @param timeout      lease time in milliseconds
         * @param loopInterval pause in milliseconds between acquisition attempts
         * @throws IllegalArgumentException if {@code lockName} is null
         */
        public DRedisLock(String lockName, long timeout, long loopInterval) {
            if (lockName == null) {
                throw new IllegalArgumentException("lockName must assigned");
            }
            this.redisTemplate = SpringUtil.getBean(StringRedisTemplate.class);
            this.lockKey = LOCK_PREFIX + lockName;
            this.timeout = timeout;
            this.loopInterval = loopInterval;
            this.hostThreadId = CURRENT_HOST + ":" + Thread.currentThread().getId();
            this.lockValue = this.hostThreadId + ":" + getNextSerial();
        }

        /**
         * Acquires the lock, waiting at most {@link #DEFAULT_TRY_LOCK_TIMEOUT} milliseconds.
         *
         * @throws RuntimeException if the wait times out or is interrupted (the interrupt flag is restored)
         */
        @Override
        public void lock() {
            try {
                if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS)) {
                    throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey);
                }
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers before converting to an unchecked exception.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }

        /**
         * Acquires the lock, waiting at most {@link #DEFAULT_TRY_LOCK_TIMEOUT} milliseconds.
         * The wait can be interrupted.
         *
         * @throws InterruptedException if the thread is interrupted while waiting
         * @throws RuntimeException     if the wait times out
         */
        @Override
        public void lockInterruptibly() throws InterruptedException {
            if (!tryLock(DEFAULT_TRY_LOCK_TIMEOUT, TimeUnit.MILLISECONDS, true)) {
                throw new RuntimeException("try lock timeout, lockKey: " + this.lockKey);
            }
        }

        /**
         * Makes a single, non-blocking acquisition attempt. It first tries to create the key;
         * if that fails, it tries to re-enter a hold owned by this host + thread.
         *
         * @return true if the lock was acquired or re-entered
         */
        @Override
        public boolean tryLock() {
            try {
                Boolean acquired = setIfAbsent(this.lockKey, this.lockValue, leaseMillis());
                if (Boolean.TRUE.equals(acquired)) {
                    locked = true;
                    log.debug("Lock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                    return true;
                }
                // Key exists: re-enter only if the stored value starts with "host:threadId:".
                Boolean reentered = redisTemplate.execute(REENTRANT_SCRIPT,
                        Collections.singletonList(this.lockKey),
                        this.hostThreadId + ":", String.valueOf(leaseMillis()));
                if (Boolean.TRUE.equals(reentered)) {
                    this.reentrant = true;
                    locked = true;
                    log.debug("Lock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                    return true;
                }
            } catch (Exception e) {
                log.error("tryLock error, do unlock, lockKey: {}, lockValue: {}", this.lockKey, lockValue, e);
                unlock();
            }
            return false;
        }

        /**
         * Atomic set-if-absent implemented as a Lua script. The original note says the template's
         * own setIfAbsent returned null when the application also used Redisson.
         *
         * @param key           key
         * @param value         value
         * @param timeoutMillis lease in milliseconds (must be positive)
         * @return whether this call created the key
         */
        private Boolean setIfAbsent(String key, String value, long timeoutMillis) {
            return redisTemplate.execute(SET_IF_ABSENT_SCRIPT, Collections.singletonList(key),
                    value, String.valueOf(timeoutMillis));
        }

        /**
         * Repeatedly attempts to acquire the lock until it succeeds or {@code time} elapses.
         *
         * @param time maximum time to wait
         * @param unit unit of {@code time}
         * @return whether the lock was acquired
         * @throws InterruptedException if interrupted while sleeping between attempts
         */
        @Override
        public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
            return tryLock(time, unit, false);
        }

        /**
         * Repeatedly attempts to acquire the lock until it succeeds or {@code time} elapses.
         * At least one attempt is always made.
         *
         * @param time          maximum time to wait
         * @param unit          unit of {@code time}
         * @param interruptibly whether to check the interrupt flag before every attempt
         * @return whether the lock was acquired
         * @throws InterruptedException if interrupted while waiting
         */
        private boolean tryLock(long time, TimeUnit unit, boolean interruptibly) throws InterruptedException {
            // toMillis converts FROM unit TO milliseconds (unit.convert(time, MILLISECONDS) went the other way).
            long waitMillis = unit.toMillis(time);
            // nanoTime is monotonic, and measuring elapsed time avoids deadline overflow for huge waits.
            long start = System.nanoTime();
            while (true) {
                if (interruptibly && Thread.interrupted()) {
                    throw new InterruptedException("tryLock interrupted");
                }
                if (tryLock()) {
                    return true;
                }
                long remaining = waitMillis - TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
                if (remaining <= 0) {
                    return false;
                }
                Thread.sleep(Math.max(0L, Math.min(loopInterval, remaining)));
            }
        }

        /**
         * Releases the lock. A re-entrant instance leaves the key for the outer holder.
         * Otherwise the key is deleted only if it still holds this instance's value.
         * Failures are logged, not thrown.
         */
        @Override
        public void unlock() {
            try {
                if (!locked) {
                    return;
                }
                if (this.reentrant) {
                    locked = false;
                    log.debug("Unlock reentrant success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                    return;
                }
                Boolean released = this.redisTemplate.execute(UNLOCK_SCRIPT,
                        Collections.singletonList(this.lockKey), this.lockValue);
                if (Boolean.TRUE.equals(released)) {
                    locked = false;
                    log.debug("Unlock success, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
                    return;
                }
            } catch (Exception e) {
                log.error("Unlock error", e);
            }
            log.warn("Unlock failed, lockKey: {}, lockValue: {}", this.lockKey, this.lockValue);
        }

        @Override
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }

        /** Lease in milliseconds, floored at 1 so PEXPIRE never receives 0 (which would delete the key). */
        private long leaseMillis() {
            return Math.max(this.timeout, 1L);
        }

        /** Builds a Lua script whose reply is mapped to {@link Boolean}. */
        private static DefaultRedisScript<Boolean> newBooleanScript(String text) {
            DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
            script.setResultType(Boolean.class);
            script.setScriptText(text);
            return script;
        }

        /** Returns the local host address, falling back to a random UUID when it cannot be resolved. */
        private static String resolveCurrentHost() {
            try {
                return InetAddress.getLocalHost().getHostAddress();
            } catch (UnknownHostException e) {
                String fallback = UUID.randomUUID().toString();
                log.warn("DRedisLock > can not get local host, use uuid: {}", fallback);
                return fallback;
            }
        }

        /**
         * @return the next sequence value, wrapping past {@link #MAX_SERIAL}
         */
        private static synchronized long getNextSerial() {
            long serial = SERIAL_NUM.incrementAndGet();
            if (serial > MAX_SERIAL) {
                serial = serial - MAX_SERIAL;
                SERIAL_NUM.set(serial);
            }
            return serial;
        }

        public static AtomicLong getSerialNum() {
            return SERIAL_NUM;
        }

        public static long getMaxSerial() {
            return MAX_SERIAL;
        }

        public static String getCurrentHost() {
            return CURRENT_HOST;
        }

        public String getLockKey() {
            return lockKey;
        }

        public long getTimeout() {
            return timeout;
        }

        public long getLoopInterval() {
            return loopInterval;
        }

        public String getHostThreadId() {
            return hostThreadId;
        }

        public String getLockValue() {
            return lockValue;
        }

        public boolean isReentrant() {
            return reentrant;
        }

        @Override
        public String toString() {
            return "DRedisLock{" +
                    "lockKey='" + lockKey + '\'' +
                    ", timeout=" + timeout +
                    ", loopInterval=" + loopInterval +
                    ", hostThreadId='" + hostThreadId + '\'' +
                    ", lockValue='" + lockValue + '\'' +
                    ", reentrant=" + reentrant +
                    '}';
        }
    }

    工具类 SpringUtil

    package com.idanchuang.component.core.util;
    
    import org.springframework.beans.BeansException;
    import org.springframework.context.ApplicationContext;
    import org.springframework.context.ApplicationContextAware;
    import org.springframework.stereotype.Component;
    
    /**
     * Static access point to the Spring {@link ApplicationContext}.
     * <p>
     * Spring hands the context over through {@link #setApplicationContext(ApplicationContext)}
     * when this component is created. Any lookup made before that fails with
     * {@link IllegalStateException}.
     */
    @Component
    public class SpringUtil implements ApplicationContextAware {

        private static ApplicationContext applicationContext;

        /** Returns the unique bean assignable to {@code tClass}. */
        public static <T> T getBean(Class<T> tClass) {
            return readyContext().getBean(tClass);
        }

        /** Returns the bean called {@code beanName}, cast (unchecked) to the caller's expected type. */
        @SuppressWarnings("unchecked")
        public static <T> T getBean(String beanName) {
            Object bean = readyContext().getBean(beanName);
            return (T) bean;
        }

        /** Returns the bean called {@code beanName}, which must be of {@code requiredType}. */
        public static <T> T getBean(String beanName, Class<T> requiredType) {
            return readyContext().getBean(beanName, requiredType);
        }

        /** Returns the stored context, failing fast if Spring has not injected it yet. */
        private static ApplicationContext readyContext() {
            ApplicationContext ctx = SpringUtil.applicationContext;
            if (ctx != null) {
                return ctx;
            }
            throw new IllegalStateException("SpringUtil applicationContext is unready");
        }

        @Override
        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            SpringUtil.applicationContext = applicationContext;
        }
    }

    至此, 我们已经可以开始使用分布式锁的功能啦, 如下

    DRedisLock lock = new DRedisLock("testa");
    // Acquire before entering try: if lock() throws, there is nothing to release.
    lock.lock();
    try {
        // The read-modify-write is split across two statements; without the lock,
        // concurrent callers could overwrite each other's increment.
        int b = a;
        a = b + 1;
        System.out.println(a);
    } finally {
        lock.unlock();
    }

    我觉得这样使用起来也太麻烦了, 还要自己实例化lock对象来加锁和释放锁, 如果忘记释放的话问题就很大, 所以我又封装了一个 DRedisLocks 类

    package com.idanchuang.component.redis.util;
    
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.util.concurrent.*;
    
    import static com.idanchuang.component.redis.util.DRedisLock.*;
    
    /**
     * Static helpers that execute a block of code while holding a {@link DRedisLock}.
     * <p>
     * The lock is acquired before the block runs and released in a {@code finally} clause, so
     * callers cannot forget to unlock. If the lock cannot be acquired within the wait time, a
     * {@link RuntimeException} is thrown and the block does not run.
     *
     * @author yjy
     * @date 2019/11/27 11:07
     **/
    public class DRedisLocks {

        private static final Logger log = LoggerFactory.getLogger(DRedisLocks.class);

        /**
         * Runs {@code runnable} under the named lock with the default wait, lease and retry settings.
         *
         * @param lockName lock name
         * @param runnable code to run while holding the lock
         */
        public static void runWithLock(String lockName, Runnable runnable) {
            runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
        }

        /**
         * Runs {@code callable} under the named lock with the default wait, lease and retry settings.
         *
         * @param lockName lock name
         * @param callable code to run while holding the lock
         * @param <V>      result type
         * @return the value returned by {@code callable}
         */
        public static <V> V runWithLock(String lockName, Callable<V> callable) {
            return runWithLock(lockName, DEFAULT_TRY_LOCK_TIMEOUT, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
        }

        /**
         * Runs {@code runnable} under the named lock.
         *
         * @param lockName   lock name
         * @param tryTimeout maximum time in milliseconds to wait for the lock
         * @param runnable   code to run while holding the lock
         */
        public static void runWithLock(String lockName, long tryTimeout, Runnable runnable) {
            runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, runnable);
        }

        /**
         * Runs {@code callable} under the named lock.
         *
         * @param lockName   lock name
         * @param tryTimeout maximum time in milliseconds to wait for the lock
         * @param callable   code to run while holding the lock
         * @param <V>        result type
         * @return the value returned by {@code callable}
         */
        public static <V> V runWithLock(String lockName, long tryTimeout, Callable<V> callable) {
            return runWithLock(lockName, tryTimeout, DEFAULT_TIMEOUT, DEFAULT_LOOP_INTERVAL, callable);
        }

        /**
         * Runs {@code runnable} under the named lock.
         *
         * @param lockName    lock name
         * @param tryTimeout  maximum time in milliseconds to wait for the lock
         * @param lockTimeout lease time in milliseconds after which the lock key expires
         * @param runnable    code to run while holding the lock
         */
        public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, Runnable runnable) {
            runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, runnable);
        }

        /**
         * Runs {@code callable} under the named lock.
         *
         * @param lockName    lock name
         * @param tryTimeout  maximum time in milliseconds to wait for the lock
         * @param lockTimeout lease time in milliseconds after which the lock key expires
         * @param callable    code to run while holding the lock
         * @param <V>         result type
         * @return the value returned by {@code callable}
         */
        public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, Callable<V> callable) {
            return runWithLock(lockName, tryTimeout, lockTimeout, DEFAULT_LOOP_INTERVAL, callable);
        }

        /**
         * Runs {@code runnable} under the named lock.
         *
         * @param lockName     lock name
         * @param tryTimeout   maximum time in milliseconds to wait for the lock
         * @param lockTimeout  lease time in milliseconds after which the lock key expires
         * @param loopInterval pause in milliseconds between acquisition attempts
         * @param runnable     code to run while holding the lock
         */
        public static void runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval, Runnable runnable) {
            Callable<Void> callable = () -> {
                runnable.run();
                return null;
            };
            runWithLock(lockName, tryTimeout, lockTimeout, loopInterval, callable);
        }

        /**
         * Runs {@code callable} while holding the named lock.
         *
         * @param lockName     lock name
         * @param tryTimeout   maximum time in milliseconds to wait for the lock
         * @param lockTimeout  lease time in milliseconds after which the lock key expires
         * @param loopInterval pause in milliseconds between acquisition attempts
         * @param callable     code to run while holding the lock
         * @param <V>          result type
         * @return the value returned by {@code callable}
         * @throws RuntimeException if the lock cannot be acquired, the wait is interrupted (the
         *                          interrupt flag is restored), or {@code callable} throws. Checked
         *                          exceptions are wrapped; unchecked ones are rethrown as-is.
         */
        public static <V> V runWithLock(String lockName, long tryTimeout, long lockTimeout, long loopInterval, Callable<V> callable) {
            DRedisLock lock = new DRedisLock(lockName, lockTimeout, loopInterval);
            log.debug("Init DRedisLock > {}", lock);
            boolean acquired = false;
            try {
                acquired = lock.tryLock(tryTimeout, TimeUnit.MILLISECONDS);
                if (!acquired) {
                    throw new RuntimeException("Get redisLock failed, lockName: " + lockName);
                }
                log.debug("Lock successful, lockName: {}", lockName);
                return callable.call();
            } catch (InterruptedException e) {
                // Keep the interrupt visible to the caller before converting to an unchecked exception.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                throw new RuntimeException(e);
            } finally {
                // DRedisLock.unlock() returns early when this instance never acquired the lock.
                lock.unlock();
                if (acquired) {
                    log.debug("Unlock successful, lockName: {}", lockName);
                }
            }
        }

    }

    现在我们就可以通过下面这种方式来使用分布式锁了, 而且不用自己手动加锁释放锁, 轻松了不少

    // runWithLock acquires "testa" before the lambda runs and releases it in a finally block
    // after the lambda returns or throws. `a` must be a field: a lambda cannot assign a local variable.
    DRedisLocks.runWithLock("testa", () -> {
        int b = a;
        a = b + 1;
        System.out.println(a);
    });

    那么针对整个方法的同步锁, 这样使用还是不够优雅, 能不能做到一个注解就实现分布式锁的能力, 答案当然是可以的, 我又新建了几个类

    RedisLock 注解类

    package com.idanchuang.component.redis.annotation;
    
    import java.lang.annotation.*;
    
    /**
     * Marks a method whose invocations must run while holding a Redis distributed lock.
     * All time values are in milliseconds.
     *
     * @author yjy
     * @date 2020/5/8 9:53
     **/
    @Inherited
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface RedisLock {

        /** Lock name; when left empty, the lock name is derived from the class name and method name. */
        String value() default "";

        /** Custom business keys; they are parsed and appended to the lock name. */
        String[] keys() default {};

        /** Maximum time to wait for the lock. */
        long tryTimeout() default 10_000L;

        /** Lease time after which the lock expires on its own. */
        long lockTimeout() default 30_000L;

        /** Pause between acquisition attempts while waiting. */
        long loopInterval() default 10L;

        /** Custom error message reported when the lock cannot be acquired. */
        String errMessage() default "";

    }

    RedisLockAspect AOP配置类

    package com.idanchuang.component.redis.aspect;
    
    import com.idanchuang.component.base.exception.common.ErrorCode;
    import com.idanchuang.component.base.exception.core.ExFactory;
    import com.idanchuang.component.redis.annotation.RedisLock;
    import com.idanchuang.component.redis.helper.BusinessKeyHelper;
    import com.idanchuang.component.redis.util.DRedisLock;
    import org.aspectj.lang.ProceedingJoinPoint;
    import org.aspectj.lang.annotation.Around;
    import org.aspectj.lang.annotation.Aspect;
    import org.aspectj.lang.annotation.Pointcut;
    import org.aspectj.lang.reflect.MethodSignature;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.stereotype.Component;
    import org.springframework.util.StringUtils;
    
    import java.lang.reflect.Method;
    import java.util.concurrent.TimeUnit;
    
    /**
     * Aspect for methods with {@link RedisLock} annotation.
     * <p>
     * Every invocation of an annotated method runs inside a {@link DRedisLock}. The lock is
     * acquired before the method runs and released afterwards, even if the method throws.
     *
     * @author yjy
     */
    @Aspect
    @Component
    public class RedisLockAspect {

        private static final Logger log = LoggerFactory.getLogger(RedisLockAspect.class);

        @Pointcut("@annotation(com.idanchuang.component.redis.annotation.RedisLock)")
        public void redisLockAnnotationPointcut() {
        }

        /**
         * Runs the intercepted method while holding the lock described by its {@link RedisLock}.
         *
         * @param pjp the intercepted join point
         * @return whatever the intercepted method returns
         * @throws Throwable anything the intercepted method throws, or the exception built by
         *                   {@code ExFactory.throwWith(ErrorCode.CONFLICT, ...)} when the lock
         *                   cannot be acquired within {@link RedisLock#tryTimeout()}
         */
        @Around("redisLockAnnotationPointcut()")
        public Object invokeWithRedisLock(ProceedingJoinPoint pjp) throws Throwable {
            Method originMethod = resolveMethod(pjp);
            RedisLock annotation = originMethod.getAnnotation(RedisLock.class);
            if (annotation == null) {
                // Should not go through here.
                throw new IllegalStateException("Wrong state for RedisLock annotation");
            }
            DRedisLock lock = null;
            String lockName = getName(annotation.value(), originMethod);
            lockName += BusinessKeyHelper.getKeyName(pjp, annotation.keys());
            try {
                lock = new DRedisLock(lockName, annotation.lockTimeout(), annotation.loopInterval());
                // Wait until the lock is acquired or tryTimeout elapses.
                if (lock.tryLock(annotation.tryTimeout(), TimeUnit.MILLISECONDS)) {
                    return pjp.proceed();
                }
                String msg = "Get redisLock failed, lockName: " + lockName;
                log.warn(msg);
                // Prefer the caller-supplied message; fall back to the generated one when none is set.
                String errMessage = StringUtils.hasLength(annotation.errMessage()) ? annotation.errMessage() : msg;
                throw ExFactory.throwWith(ErrorCode.CONFLICT, errMessage);
            } finally {
                // Always release; DRedisLock.unlock() returns early if this instance never acquired the lock.
                if (lock != null) {
                    lock.unlock();
                }
            }
        }

        /**
         * Resolves the base lock name.
         *
         * @param lockName     name from {@link RedisLock#value()}, possibly empty
         * @param originMethod the intercepted method
         * @return {@code lockName} if non-empty, otherwise {@code SimpleClassName:methodName}
         * @author sxp
         * @date 2020/7/3 11:06
         */
        private String getName(String lockName, Method originMethod) {
            if (StringUtils.hasLength(lockName)) {
                return lockName;
            }
            return originMethod.getDeclaringClass().getSimpleName() + ":" + originMethod.getName();
        }

        /**
         * Finds the method declared on the target class (or one of its superclasses) that matches
         * the join point's signature.
         *
         * @throws IllegalStateException if no such method is found
         */
        private Method resolveMethod(ProceedingJoinPoint joinPoint) {
            MethodSignature signature = (MethodSignature) joinPoint.getSignature();
            Class<?> targetClass = joinPoint.getTarget().getClass();

            Method method = getDeclaredMethodFor(targetClass, signature.getName(),
                    signature.getMethod().getParameterTypes());
            if (method == null) {
                throw new IllegalStateException("Cannot resolve target method: " + signature.getMethod().getName());
            }
            return method;
        }

        /**
         * Get declared method with provided name and parameterTypes in given class and its super classes.
         * All parameters should be valid.
         *
         * @param clazz          class where the method is located
         * @param name           method name
         * @param parameterTypes method parameter type list
         * @return resolved method, null if not found
         */
        private Method getDeclaredMethodFor(Class<?> clazz, String name, Class<?>... parameterTypes) {
            try {
                return clazz.getDeclaredMethod(name, parameterTypes);
            } catch (NoSuchMethodException e) {
                Class<?> superClass = clazz.getSuperclass();
                if (superClass != null) {
                    return getDeclaredMethodFor(superClass, name, parameterTypes);
                }
            }
            return null;
        }

    }

    至此, 我们已经可以通过注解来实现接口的分布式锁能力

    /**
     * Demo service whose {@link #doSomething()} is annotated with {@link RedisLock}. When the
     * RedisLock aspect is active, callers share the lock {@code customLockName:88888222}.
     *
     * @author yjy
     * @date 2020/5/8 10:21
     **/
    @Component
    public class LockService {

        private static int a = 0;

        /**
         * Increments the counter, sleeps about one second while holding the lock, then prints it.
         * The lock lease is 60 s, and callers wait up to 20 s to acquire the lock.
         */
        @RedisLock(value = "customLockName:88888222", lockTimeout = 60000L, tryTimeout = 20000L)
        public void doSomething() {
            a++;
            try {
                Thread.sleep(1000L);
            } catch (InterruptedException e) {
                // Restore the interrupt flag instead of swallowing it.
                Thread.currentThread().interrupt();
            }
            System.out.println("a: " + a);
        }

    }

    以上简单的介绍了我们实现的Redis分布式锁, 其实它的功能不止介绍的这些

    它还支持线程内可重入 (依据创建 DRedisLock 实例时记录的 主机标识:线程id 判断, 因此同一个锁实例应在创建它的线程内加锁和释放), 支持超时自动释放锁, 注解模式支持解析参数对象来作为锁资源, 等等

    好了, 今天就到这里吧, 拜拜~

  • 相关阅读:
    HDU-6315 Naive Operations 线段树
    18牛客第二场 J farm
    POJ
    SPOJ
    codeforces 501C. Misha and Forest
    Codeforces 584C
    Domination
    HDU-3074 Multiply game
    Codefoeces-689D Friends and Subsequences
    Codeforces Round #486 (Div. 3)
  • 原文地址:https://www.cnblogs.com/imyjy/p/15789410.html
Copyright © 2020-2023  润新知