CountDownLatch 详解
2024-11-21 09:25:53 # Technical # JavaConcurrency

简单示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class CountDownLatchTest {

private static final CountDownLatch startSignal = new CountDownLatch(1);
private static final CountDownLatch doneSignal = new CountDownLatch(10);
private static final Random random = new Random();

public static void main(String[] args) throws InterruptedException {
for (int i = 0; i < 10; i++) {
new Thread(new Worker(i + 1)).start();
}
// 等待所有线程都启动完成
TimeUnit.SECONDS.sleep(1);
// 启动
startSignal.countDown();
// 等待 Workers “工作”结束
doneSignal.await();
System.out.println("Done");
}


private static class Worker implements Runnable {
private final int i;

public Worker(int i) {
this.i = i;
}

@Override
public void run() {
// 等待开始信号
try {
startSignal.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
long start = System.currentTimeMillis();
System.out.println("Worker-" + i + " started");
// 模拟工作
try {
TimeUnit.SECONDS.sleep(random.nextInt(5) + 1);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("Worker-" + i + " done! cost " + (System.currentTimeMillis() - start) + "ms");
doneSignal.countDown();
}
}
}

这里通过两个 CountDownLatch 实现了多线程的同步启动以及完全结束的控制

这里再次注意:初始化线程时,不要用 Worker::new

new Thread(Worker::new).start 等同于:

1
2
3
4
5
6
new Thread(new Runnable() {
@Override
public void run() {
new Worker();
}
}).start();

源码分析

相较于 ReentranLock 和 ReentranReadWriteLock,CountDownLatch 的实现算是「十分简单啦」

CountDownLatch 中只有一个变量

1
private final Sync sync

和 ReentrantLock 它们类似,这个 sync 也是一个继承了 AQS 的内部类

构造函数中会初始化这个 sync

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

先看下 CountDownLatch 对外暴露的方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// 等待计数器为0
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

// 有超时时间的等待
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

// 计数器减 1
public void countDown() {
sync.releaseShared(1);
}

// 获取当前计数
public long getCount() {
return sync.getCount();
}

public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}

所有的实现都是通过 Sync 实现的,所以直接看 Sync 就行了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

// 构造函数,初始计数 count
Sync(int count) {
// AQS 的方法
setState(count);
}

int getCount() {
return getState();
}

// 尝试获取共享锁,await 会间接调用这个方法
// 如果 state = 0 说明计数完成,可以继续执行,否则需要等待
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

// 尝试释放共享锁,countDown 会间接调用这个方法
// 通过无限循环和 CAS 保证最终操作成功
protected boolean tryReleaseShared(int releases) {
for (;;) {
// 获取当前计数
int c = getState();
// 如果已经是 0,无法减少
if (c == 0)
return false;
// 减一
int nextc = c - 1;
// CAS 更新状态
if (compareAndSetState(c, nextc))
// 更新后如果状态为 0,返回 true(完全释放共享锁)
return nextc == 0;
}
}
}

这个 Sync 内部类的逻辑也是很简单的,整体 CountDownLatch 的实现还需结合 AQS 来看

初始化

首先在 CountDownLatch 的构造函数中初始化 Sync

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

然后在内部类 Sync 的构造函数中初始化 AQS 中的状态

1
2
3
4
// CountDownLatch.Sync
Sync(int count) {
setState(count);
}

通过给 state 设置值,也是设置了 Sync 共享锁的重入次数(计数器次数)

await

CountDownLatch 中调用 await 方法

1
2
3
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

这个 acquireSharedInterruptibly 方法并未在 CountDownLatch.Sync 中重写,直接使用的是 AQS 的

1
2
3
4
5
6
7
8
9
10
11
12
13
// java.util.concurrent.locks.AbstractQueuedSynchronizer
// 可被中断地获取共享锁
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果当前线程被中断,直接抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 尝试获取共享锁,这个方法 AQS 并未实现,交由子类来实现,所以这里是会调用 CountDownLat.Sync
// 逻辑很简单,就是获取状态,状态不为 0 就返回 -1
if (tryAcquireShared(arg) < 0)
// 将当前线程加入阻塞队列,暂停当前线程
doAcquireSharedInterruptibly(arg);
}

countDown

CountDownLatch 中调用 countDown 方法

1
2
3
4
public void countDown() {
// 释放共享锁,这个方法是 AQS 实现的
sync.releaseShared(1);
}

进入到 AQS 的 releaseShared 方法

1
2
3
4
5
6
7
8
9
10
public final boolean releaseShared(int arg) {
// 尝试释放共享锁,这个 AQS 没实现,所以去到了 CountDownLatch.Sync
// 逻辑比较简单,就是自旋尝试将重入次数 - 1
if (tryReleaseShared(arg)) {
// 如果重入次数为 0 了,就直接释放掉这个共享锁
doReleaseShared();
return true;
}
return false;
}