CountDownLatch
CountDownLatch
基于AQS实现的同步器,允许一个或者多个线程通过await()
方法进入阻塞等待,直到一个或者多个线程执行countDown()
完成。CountDownLatch
在创建时需要传入一个count
值,一旦某个或者多个线程调用了await()
方法,那么需要等待count
值减为0,才能继续执行。
countDown()
方法每执行一次,count(state)值减1,直到减为0。一个线程可以多次调用countDown()
方法,每次调用都会造成count减1
CountDownLatch
在RocketMQ底层通信被大量使用,实现远程调用异步转同步。Netty Client
发送消息之前创建一个ResponseFuture
,ReponseFuture
中有一个CountDownLatch
属性,发送消息之后调用await()
,等待response,当接收到响应之后,调用对应ResponseFuture
中CountDownLatch#countDown
,唤醒阻塞线程。
内部类AQS实现
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
构造函数
public CountDownLatch(int count) {
// count不能为负数
if (count < 0) throw new IllegalArgumentException("count < 0");
// 创建同步器,设置state为count
this.sync = new Sync(count);
}
await
public void await() throws InterruptedException {
// AQS#acquireSharedInterruptibly -> Sync#tryAcquireShared(如果state=0 返回1,立即返回,线程继续向下执行,如果state != 0, 返回-1,线程进入同步队列,阻塞排队)
sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 如果state != 0,tryAcquireShared()方法返回-1,说明需要等待其他线程执行countDown()方法,线程进入同步队列阻塞
// 如果state = 0,tryAcquireShared()方法返回1,线程立即返回,继续向下执行
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 进入同步队列阻塞
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
// 自旋等待state = 0,等待其他线程执行完毕
for (;;) {
final Node p = node.predecessor();
if (p == head) {
// 如果state = 0,表明其他同步线程执行完毕,线程阻塞结束
int r = tryAcquireShared(arg);
if (r >= 0) {
// 更新头节点为自己,并向后唤醒其他阻塞的线程
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
countDown
/**
* count(state)值减1,当减为0时,由于await调用阻塞的线程将被唤醒继续执行
*/
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) { // 将count值减-1,如果count值减1后等于0,返回true,
// count值减1后等于0,唤醒在同步队列上等待的第一个线程,第一个线程会向后传播,唤醒后驱节点(doAcquireSharedInterruptibly)
doReleaseShared();
return true;
}
return false;
}
/**
* 自旋 + CAS完成更新
*/
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
/**
* count(state)值减1后等于0,会调用该方法,该方法唤醒在同步队列上等待的第一个线程
*/
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
获取count
public long getCount() {
return sync.getCount();
}