CountDownLatch是一个同步协助类,通常用于一个或多个线程等待,直到其他线程完成某项工作。
CountDownLatch使用一个计数值进行初始化,调用它提供的await()方法的线程会被阻塞直到该计数值减为0。减计数值的方法是countDown(),该方法可以在同一个线程中多次调用,也可以在多个线程中被调用,当计数值减为0时所有调用await()方法的线程被唤醒。
API
CountDownLatch的构造函数定义如下,count的值被赋值给state。
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
CountDownLatch的常用方法如下:
//当前线程阻塞,直到count/state的值变为0
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
//当前线程阻塞,直到count/state的值变为0或等待timeout的时间
public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
//将count/state的值减1
public void countDown() {
sync.releaseShared(1);
}
使用场景
CountDownLatch一般可以用于两个场景:
- 多个线程等待,随后并发执行;
- 单个线程等待,汇总合并多个线程的执行结果。
多个线程等待
这种场景下,通常是多个线程调用await()方法阻塞,直到其他线程调用countDownLatch()将count的值减为0,如下:
public static void main(String[] args) throws InterruptedException {
CountDownLatch countDownLatch = new CountDownLatch(1);
for (int i = 0; i < 5; i++) {
new Thread(() -> {
try {
//准备完毕……运动员都阻塞在这,等待号令
countDownLatch.await();
String parter = "【" + Thread.currentThread().getName() + "】";
System.out.println(parter + "开始执行……");
} catch (InterruptedException e) {
e.printStackTrace();
}
}).start();
}
Thread.sleep(2000);// 裁判准备发令
countDownLatch.countDown();// 发令枪:执行发令
}
单个线程等待
在很多场景中,主流程需要等待多个不同的任务完成后再处理结果,此时就要求主流程线程需要阻塞直到所有任务执行完成,如下:
public class CountDownLatchTest {
private static volatile int count = 0;
public static void main(String[] args) throws Exception {
CountDownLatch countDownLatch = new CountDownLatch(5);
for (int i = 0; i < 5; i++) {
final int index = i;
new Thread(() -> {
try {
Thread.sleep(1000 + ThreadLocalRandom.current().nextInt(1000));
count++;
System.out.println(Thread.currentThread().getName()+" finish task" + index );
countDownLatch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
}).start();
}
// 主线程在阻塞,当计数器==0,就唤醒主线程往下执行。
countDownLatch.await();
System.out.println("所有任务执行完成,count=" + count);
}
}
源码解析
这里主要介绍最常用的await()和countDown()方法的源码,以及CountDownLatch的原理。
await()
- await()
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
- acquireSharedInterruptibly(permits)方法是AQS中定义的共享锁获取锁的通用方法,实现如下:
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
- tryAcquireShared(arg)是提供给子类实现的模版方法,该方法在CountDownLatch中的实现如下:
protected int tryAcquireShared(int acquires) {
//如果state=0,则返回正值1,否则返回负值-1
return (getState() == 0) ? 1 : -1;
}
- 在第2步中,tryAcquireShared(arg)返回值小于0则调用doAcquireSharedInterruptibly(arg)方法阻塞当前线程,如果返回值大于等于0则方法执行结束,而在第3步的tryAcquireShared(arg)实现可以看出来,当state!=0时调用await()方法的线程都会阻塞,只有state=0时才不会阻塞。下面的测试程序和执行结果可以验证以上结论:
public static void main(String[] args) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
new Thread(() -> {
try {
System.out.println(Thread.currentThread().getName() + ":我阻塞了...");
latch.await();
System.out.println(Thread.currentThread().getName() + ":我被唤醒了...");
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
},"thread-branch").start();
Thread.sleep(100);
System.out.println("调用countDown()方法前,count=" + latch.getCount());
latch.countDown();
Thread.sleep(100);
System.out.println("调用countDown()方法后,count=" + latch.getCount());
latch.await();
System.out.println("这里如果执行了,说明主线程没有阻塞,count=" + latch.getCount());
}
- doAcquireSharedInterruptibly(arg)方法是AQS中的共享锁入队同步等待队列并阻塞通用方法,在Semaphore中详细介绍过,此处不再赘述。
countDown()
countDown()方法的主要作用是将state减1,当state=0时则需要唤醒所有在同步等待队列中阻塞的线程。
- countDown()实现如下:
public void countDown() {
sync.releaseShared(1);
}
- releaseShared()方法是AQS中提供的释放共享锁的通用方法,实现如下:
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
- tryReleaseShared()方法是AQS定义的模版方法,在CounDownLatch中的实现如下:
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
//判断如果state=0,则不需要再唤醒同步等待队列了,因为之前已经唤醒过了
if (c == 0)
return false;
//这里就是countDown()方法的主要作用,将state-1
int nextc = c-1;
//如果使用CAS修改state-1失败,则循环修改直到成功
if (compareAndSetState(c, nextc))
//修改state成功后,如果state=0,则返回true
return nextc == 0;
}
}
- 在第3步中,只有当前这次执行countDown()方法将state的值减为0后,才会返回true,此时才会去执行第2步中的doReleaseShared()方法,该方法将唤醒同步等待队列所有阻塞线程。
CountDownLatch的使用流程可以总结如下:
- 使用一个不小于0的整数初始化CountDownLatch,该整数值赋给state;
- 在state不为0时,所有调用await()方法的线程都会进入同步等待队列阻塞;当state=0后,再调用await()方法不会做任何操作;
- 每次调用countDown()方法时,先判断state是否已经是0了,如果是0则什么都不做直接返回,如果不是0则将state-1后使用CAS更新state;
- 如果更新失败,则循环调用CAS更新直到更新成功;
- 更新成功后,判断state是否为0,如果不是则什么都不做,此时countDown()方法只是将state-1;如果state=0,则需要唤醒同步等待队列的所有阻塞线程。至此,CountDownLatch执行完成。
CountDownLatch和Semaphore
从上面可以看出CountDownLatch和Semaphore都是通过AQS的共享锁实现的,虽然它们的实现效果截然不同,但是比较它们的不同可以帮助我们记忆它们各自的实现。
- Semaphore在调用acquire()方法获取许可证时将state-1,如果state=0则进入同步等待队列阻塞;CountDownLatch在调用await()方法时,只要state!=0,线程都会进入同步等待队列阻塞;
- Semaphore在调用release()方法时无论当前线程是否获取过许可证,许可证state的值都会+1(甚至可以突破初始化时给的值),并调用doReleaseShared()方法唤醒同步等待队列首节点,如果许可证足够则会一直向后唤醒;CountDownLatch在调用countDown()方法时,如果state==0则什么都不做,如果state-1!=0则只更新state,如果state-1==0则会唤醒同步等待队列所有阻塞线程。
以上就决定了Semaphore的作用是限流,CountDownLatch的作用是协助线程同步执行。
需要注意的是,CountDownLatch只能使用一次,即当调用countDown()将state减为0后,当前CountDownLatch对象就没用了。如果想要达到重复使用的目的,可以选择另一个功能较CountDownLatch更强大的CyclicBarrier。