1,信号量(Stemaphore)
Semaphore也就是信号量,提供了资源数量的并发访问控制,其使用代码很简单,如下所示:
// 一开始有5份共享资源。第二个参数表示是否是公平
// 公平锁排队,非公平锁竞争
Semaphore myResources = new Semaphore(5, true);
// 工作线程每获取一份资源,就在该对象上记下来
// 在获取的时候是按照公平的方式还是非公平的方式,就要看上一行代码的第二个参数了。
// 一般非公平抢占效率较高。
myResources.acquire();
// 工作线程每归还一份资源,就在该对象上记下来
// 此时资源可以被其他线程使用
myResources.release();
/*
释放指定数目的许可,并将它们归还给信标。 可用许可数加上该指定数目。
如果线程需要获取N个许可,在有N个许可可用之前,该线程阻塞。
如果线程获取了N个许可,还有可用的许可,则依次将这些许可赋予等待获取许可的其他线程。
*/
semaphore.release(2);
/*
从信标获取指定数目的许可。如果可用许可数目不够,则线程阻塞,直到被中断。
该方法效果与循环相同,
for (int i = 0; i < permits; i++) acquire();
只不过该方法是原子操作。
如果可用许可数不够,则当前线程阻塞,直到:(二选一)
1. 如果其他线程释放了许可,并且可用的许可数满足当前线程的请求数字;
2. 其他线程中断了当前线程。
permits – 要获取的许可数
*/
semaphore.acquire(3);
当初始化的资源个数为1时,Semaphore退化为排他锁,正因为如此,Semaphore的实现原理和锁类似,是基于AQS,有公平和非公平之分。Semaphore相关类的继承体系如下:
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void release() {
sync.releaseShared(1);
}
由于Semaphore和锁的实现原理基本相同。资源总数即state的初始值,在acquire里对state变量进行CAS减操作,减到0后,线程阻塞;在release里面对变量进行CAS操作。
public abstract class AbstractQueuedSynchronizer
extends AbstractOwnableSynchronizer
implements java.io.Serializable {
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
}
public class Semaphore {
abstract static class Sync extends AbstractQueuedSynchronizer {
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
}
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
}
package java.lang.invoke;
public abstract class VarHandle {
// ...
// CAS,原子操作
public final native
@MethodHandle.PolymorphicSignature
@HotSpotIntrinsicCandidate
boolean compareAndSet(Object... args);
// ...
}
2,等待完成(CountDownLatch)
假设一个主线程要等待5个worker线程执行完才能退出,可以使用CountDownLatch来实现:
import java.util.Random;
import java.util.concurrent.CountDownLatch;
public class MyThread extends Thread {
private final CountDownLatch latch;
private final Random random = new Random();
public MyThread(String name,CountDownLatch latch){
super(name);
this.latch = latch;
}
@Override
public void run() {
try {
Thread.sleep(random.nextInt(2000));
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName()+"执行完毕.");
// latch计数减一
latch.countDown();
}
}
import java.util.concurrent.CountDownLatch;
public class Main {
public static void main(String[] args) throws InterruptedException {
CountDownLatch latch = new CountDownLatch(5);
for (int i = 0; i < 5; i++) {
new MyThread("线程"+(i+1),latch).start();
}
// main线程等待
latch.await();
System.out.println("main线程执行结束");
}
}
下图为CountDownLatch相关类的继承层次,CountDownLatch原理和Semaphore原理类似,同样是基于AQS,不过没有公平和非公平之分。
await()调用的是AQS的模板方法,CountDownLatch.Sync重新实现了tryAcquireShared方法。从tryAcquireShared(…)方法的实现看,只要state!=0,调用await()方法的线程便会被放入AQS的阻塞队列,进入阻塞状态。
public void await() throws InterruptedException {
// AQS的模板方法
sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 被CountDownLatch.Sync实现
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
countDown()调用的AQS的模板方法releaseShared(),里面的tryReleaseShared(…)由CountDownLatch.Sync实现。从上面代码看,只有state=0,tryReleaseShared(…)才会返回true,然后doReleaseShared(…),一次性唤醒队列中所有阻塞的线程。
public void countDown() {
sync.releaseShared(1);
}
// AQS的模板方法
public final boolean releaseShared(int arg) {
// 有CountDownLatch.Sync实现
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c - 1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
总结:由于基于AQS阻塞队列来实现的,所以可以让多个线程都阻塞在state=0条件上,通过countDown()一直减state,减到0后一次性唤醒所有线程。
3,阶段性同步(CyclicBarrier)
CyclicBarrier类用于协调多个线程同步执行操作的场合。
CyclicBarrier barrier = new CyclicBarrier(5);
barrier.await();
考虑这样一个场景:10个工程师一起来公司应聘,招聘方式分为
笔试和面试。首先,要等人到齐后,开始笔试;笔试结束之后,再一起参加面试。把10个人看作10个线程,10个线程之间的同步过程如下:
在整个过程中,有2个同步点:第1个同步点,要等所有应聘者都到达公司,再一起开始笔试;第2 个同步点,要等所有应聘者都结束笔试,之后一起进入面试环节。
import java.util.Random;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
public class MyThread extends Thread {
private final CyclicBarrier barrier;
private final Random random = new Random();
MyThread(String name,CyclicBarrier barrier){
super(name);
this.barrier = barrier;
}
@Override
public void run() {
try {
System.out.println(Thread.currentThread().getName()+" - 向公司出发");
Thread.sleep(random.nextInt(5000));
System.out.println(Thread.currentThread().getName()+" - 已经到达公司");
// 等待其他线程该阶段结束
barrier.await();
System.out.println(Thread.currentThread().getName()+" - 开始笔试");
Thread.sleep(random.nextInt(5000));
System.out.println(Thread.currentThread().getName()+" - 笔试结束");
// 等待其他线程该阶段结束
barrier.await();
System.out.println(Thread.currentThread().getName()+" - 开始面试");
Thread.sleep(random.nextInt(5000));
System.out.println(Thread.currentThread().getName()+" - 面试结束");
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
}
}
import java.util.concurrent.CyclicBarrier;
public class Main {
public static void main(String[] args) {
CyclicBarrier cyclicBarrier = new CyclicBarrier(5, new Runnable() {
@Override
public void run() {
System.out.println("该阶段结束");
}
});
for (int i = 0; i < 5; i++) {
new MyThread("线程-"+(i+1),cyclicBarrier).start();
}
}
}
CyclicBarrier基于ReetrantLock+Condition实现:
public class CyclicBarrier {
private final ReentrantLock lock = new ReentrantLock();
// 用于线程之间相互唤醒
private final Condition trip = lock.newCondition();
// 线程总数
private final int parties;
private int count;
private Generation generation = new Generation();
// ...
}
下面详细介绍CyclicBarrier的实现原理,先看构造方法:
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
// 参与方数量
this.parties = parties;
this.count = parties;
// 当所有线程被唤醒,执行barrierCommand表示Runnable
this.barrierCommand = barrierAction;
}
接下来看一下await()方法的实现过程:
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
// 响应中断
if (Thread.interrupted()) {
// 唤醒所有阻塞的线程
breakBarrier();
throw new InterruptedException();
}
// 每个线程调用一次await(),count都要减1
int index = --count;
// 当count减到0时,此线程唤醒其他所有线程
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
for (;;) {
try {
if (!timed)
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
private void nextGeneration() {
// signal completion of last generation
trip.signalAll();
// set up next generation
count = parties;
generation = new Generation();
}
- CyclicBarrier是可以被重用的。以上应聘场景为例,来了10个线程,这10个线程互相等待,到期后一起被唤醒,各自执行接下来的逻辑。然后,这10个线程继续互相等待,到齐后再一起被唤醒。每一轮被称为一个Generation,就是一次同步点。
- CyclicBarrier会响应中断。10个线程没有到齐,如果有线程收到中断信号,所有阻塞的线程也会被唤醒,就是上面的breakBarrier()方法。然后count被重置为初始值(parties),重新开始。
- 上面的回调方法,barrierAction只会被第10个线程执行1次(在唤醒其他9个线程之前),而不是10个线程每个都执行1次。
4,交换数据(Exchanger)
Exchanger用于线程之间交换数据,其使用代码很简单,是一个exchange(…)方法,使用示例如下:
import java.util.Random;
import java.util.concurrent.Exchanger;
public class Main {
private static final Random random = new Random();
public static void main(String[] args) {
// 建一个多线程共用的exchange对象
// 把exchange对象传给3个线程对象。每个线程在自己的run方法中调用exchange,把自己的数据作为参数
// 传递进去,返回值是另外一个线程调用exchange传进去的参数
Exchanger<String> exchanger = new Exchanger<>();
new Thread("线程1"){
@Override
public void run() {
while (true){
try {
final String otherData = exchanger.exchange("交换数据1");
System.out.println(Thread.currentThread().getName() + "得到 <==" +otherData);
Thread.sleep(random.nextInt(2000));
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}.start();
new Thread("线程2"){
@Override
public void run() {
while (true){
try {
final String otherData = exchanger.exchange("交换数据2");
System.out.println(Thread.currentThread().getName() + "得到 <==" +otherData);
Thread.sleep(random.nextInt(2000));
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}.start();
new Thread("线程3"){
@Override
public void run() {
while (true){
try {
final String otherData = exchanger.exchange("交换数据3");
System.out.println(Thread.currentThread().getName() + "得到 <==" +otherData);
Thread.sleep(random.nextInt(2000));
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}.start();
}
}
在上面的例子中,3个线程并发的调用exchange(…),会两两交互数据,如1/2、1/3、2/3。
Exchanger的核心机制和Lock一样,也是CAS+park/unpark。首先,在Exchanger内部,有两个内部类:Participant和Node,代码如下:
// 添加了Contended注解,表示伪共享与缓存填充
@jdk.internal.vm.annotation.Contended static final class Node {
int index; // Arena index
int bound; // Last recorded value of Exchanger.bound
int collides; // 本次绑定中,CAS操作失败次数
int hash; // 自旋伪随机
Object item; // 本线程要交换的数据
volatile Object match; // 对方线程交换来的数据
// 当前线程
volatile Thread parked; // 当前线程阻塞的时候设置该属性,不阻塞为null
}
static final class Participant extends ThreadLocal<Node> {
public Node initialValue() { return new Node(); }
}
每个线程在调用exchange(…)方法交换数据的时候,会先创建一个Node对象。
这个Node对象就是对该线程的包装,里面包含了3个重要字段:第一个是该线程要交换的数据,第二个是对方线程交换来的数据,最会一个是该线程本身。
一个Node只能支持2个线程之间交换数据,要实现多个线程并行的交换数据,需要多个Node,因此Exchanger里面定义了Node数组:
private volatitle Node[] arena;
明白了大致思路,下面来看exchange(V x)方法的详细实现:
@SuppressWarnings("unchecked")
public V exchange(V x) throws InterruptedException {
Object v;
Object item = (x == null) ? NULL_ITEM : x; // translate null args
if ((arena != null ||
(v = slotExchange(item, false, 0L)) == null) &&
((Thread.interrupted() || // disambiguates null return
(v = arenaExchange(item, false, 0L)) == null)))
throw new InterruptedException();
return (v == NULL_ITEM) ? null : (V)v;
}
- 如果arena不是null,表示启用了arena方式交换数据。
- 如果arena不是null,并且线程被中断,则抛异常;
- 如果arena不是null,并且arenaExchange的返回值为null,则抛异常。对方线程交换来的null值是封装为NULL_ITEM对象的,而不是null。
- 如果slotExchange的返回值是null,并且线程被中断,则抛异常。
- 如果slotExchange的返回值是null,并且arenaExchange的返回值是null,则抛异常。
slotExchange的实现:
/**
* Exchange function used until arenas enabled. See above for explanation.
* 如果不启用arena,则使用该方法进行线程间数据交换
* @param item 需要交换的数据
* @param timed 是否计时等待,true表示是计时等待
* @param ns 如果计时等待,该值表示最大等待的时长
* @return 对方线程交换来的数据;如果等待超时或线程中断,或者启用了arena,则返回null
*/
private final Object slotExchange(Object item, boolean timed, long ns) {
// participant在初始化的时候这只初始值为new Node()
// 获取本线程要交换的数据节点
Node p = participant.get();
// 获取当先线程
Thread t = Thread.currentThread();
// 如果线程被中断,则返回null
if (t.isInterrupted()) // preserve interrupt status so caller can recheck
return null;
for (Node q;;) {
// 如果slot非空,表明有其他线程在等待该线程交换数据
if ((q = slot) != null) {
// CAS操作,将当前线程的slot由slot设置为null
// 如果操作成功,则执行if中的语句
if (SLOT.compareAndSet(this, q, null)) {
// 获取对方线程交换来的数据
Object v = q.item;
// 设置要交换的数据
q.match = item;
// 获取q中阻塞的线程对象
Thread w = q.parked;
if (w != null)
// 如果对方阻塞的线程非空,则唤醒阻塞的线程
LockSupport.unpark(w);
return v;
}
// create arena on contention, but continue until slot null
// 创建arena用于处理多个线程需要交换数据的场合,防止slot冲突
if (NCPU > 1 && bound == 0 &&
BOUND.compareAndSet(this, 0, SEQ))
arena = new Node[(FULL + 2) << ASHIFT];
}
// 如果arena不是null,需要调用者调用arenaExchange方法接着获取对方线程交换来的数据
else if (arena != null)
return null; // caller must reroute to arenaExchange
else {
// 如果slot为null,表示对方没有线程等待该线程交换数据
// 设置要交换的本方数据
p.item = item;
// 设置当前线程要交换的数据到slot
// CAS操作,如果设置失败,则进入下一轮for循环
if (SLOT.compareAndSet(this, null, p))
break;
p.item = null;
}
}
// await release
// 没有对方线程等待交换数据,将当前线程要交换的数据放到slot中,是一个Node对象
// 然后阻塞,等待唤醒
int h = p.hash;
// 如果是计时等待交换,则计算超时时间;否则设置为0
long end = timed ? System.nanoTime() + ns : 0L;
// 如果CPU核心数大于1,则使用SPINS数,自旋;否则为1,没有必要自旋。
int spins = (NCPU > 1) ? SPINS : 1;
Object v;
while ((v = p.match) == null) {
if (spins > 0) {
h ^= h << 1; h ^= h >>> 3; h ^= h << 10;
if (h == 0)
h = SPINS | (int)t.getId();
else if (h < 0 && (--spins & ((SPINS >>> 1) - 1)) == 0)
Thread.yield();
}
else if (slot != p)
spins = SPINS;
else if (!t.isInterrupted() && arena == null &&
(!timed || (ns = end - System.nanoTime()) > 0L)) {
p.parked = t;
if (slot == p) {
if (ns == 0L)
LockSupport.park(this);
else
LockSupport.parkNanos(this, ns);
}
p.parked = null;
}
else if (SLOT.compareAndSet(this, p, null)) {
v = timed && ns <= 0L && !t.isInterrupted() ? TIMED_OUT : null;
break;
}
}
MATCH.setRelease(p, null);
p.item = null;
p.hash = h;
return v;
}
arenaExchange的实现:
/**
* Exchange function when arenas enabled. See above for explanation.
* 当启用arenas时,使用该方法进行线程间的数据交换
* @param item 本线程要交换的非null数据
* @param timed 如果需要计时等待,则设置为true
* @param ns 表示计时等待的最大时长
* @return 对方线程交换来的数据。如果线程被中断,或者等待超时,则返回null
*/
private final Object arenaExchange(Object item, boolean timed, long ns) {
Node[] a = arena;
int alen = a.length;
Node p = participant.get();
for (int i = p.index;;) { // access slot at i
int b, m, c;
int j = (i << ASHIFT) + ((1 << ASHIFT) - 1);
if (j < 0 || j >= alen)
j = alen - 1;
Node q = (Node)AA.getAcquire(a, j);
if (q != null && AA.compareAndSet(a, j, q, null)) {
Object v = q.item; // release
q.match = item;
Thread w = q.parked;
if (w != null)
LockSupport.unpark(w);
return v;
}
else if (i <= (m = (b = bound) & MMASK) && q == null) {
p.item = item; // offer
if (AA.compareAndSet(a, j, null, p)) {
long end = (timed && m == 0) ? System.nanoTime() + ns : 0L;
Thread t = Thread.currentThread(); // wait
for (int h = p.hash, spins = SPINS;;) {
Object v = p.match;
if (v != null) {
MATCH.setRelease(p, null);
p.item = null; // clear for next use
p.hash = h;
return v;
}
else if (spins > 0) {
h ^= h << 1; h ^= h >>> 3; h ^= h << 10; // xorshift
if (h == 0) // initialize hash
h = SPINS | (int)t.getId();
else if (h < 0 && // approx 50% true
(--spins & ((SPINS >>> 1) - 1)) == 0)
Thread.yield(); // two yields per wait
}
else if (AA.getAcquire(a, j) != p)
spins = SPINS; // releaser hasn't set match yet
else if (!t.isInterrupted() && m == 0 &&
(!timed ||
(ns = end - System.nanoTime()) > 0L)) {
p.parked = t; // minimize window
if (AA.getAcquire(a, j) == p) {
if (ns == 0L)
LockSupport.park(this);
else
LockSupport.parkNanos(this, ns);
}
p.parked = null;
}
else if (AA.getAcquire(a, j) == p &&
AA.compareAndSet(a, j, p, null)) {
if (m != 0) // try to shrink
BOUND.compareAndSet(this, b, b + SEQ - 1);
p.item = null;
p.hash = h;
i = p.index >>>= 1; // descend
if (Thread.interrupted())
return null;
if (timed && m == 0 && ns <= 0L)
return TIMED_OUT;
break; // expired; restart
}
}
}
else
p.item = null; // clear offer
}
else {
if (p.bound != b) { // stale; reset
p.bound = b;
p.collides = 0;
i = (i != m || m == 0) ? m : m - 1;
}
else if ((c = p.collides) < m || m == FULL ||
!BOUND.compareAndSet(this, b, b + SEQ + 1)) {
p.collides = c + 1;
i = (i == 0) ? m : i - 1; // cyclically traverse
}
else
i = m + 1; // grow
p.index = i;
}
}
}
5,Phaser(CyclicBarrier+CountDownLatch)
5.1,用Phaser替代CyclicBarrier和CountDownLatch
从JDK7开始,新增了一个同步工具类Phaser,其功能比CyclicBarrier和CountDownLatch更加强大。
用Phaser替代CountDownLatch: 在CountDownLatch中,主要是2个方法:await()和countDown()。在Phaser中,与之对应的方法是awaitAdvance(int n)和arrive()。
import java.nio.file.attribute.UserPrincipal;
import java.time.Year;
import java.util.Random;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
public class Main {
public static void main(String[] args) {
Phaser phaser = new Phaser(5);
for (int i = 0; i < 5; i++) {
new Thread("线程-"+(i+1)){
private final Random random = new Random();
@Override
public void run() {
System.out.println(getName()+" - 开始运行");
try {
Thread.sleep(random.nextInt(1000));
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(getName() + " - 运行结束");
phaser.arrive();
}
}.start();
}
System.out.println("线程启动完毕");
System.out.println(phaser.getPhase());
phaser.awaitAdvance(0);
System.out.println("线程运行结束");
}
}
用Phaser代替CyclicBarrier: arriveAndAwaitAdvance()就是arrive()与awaitAdvance(int)的组合,表示“我自己已到达这个同步点,同时要等待所有人都到达这个同步点,然后再一起前行”。
import java.util.concurrent.Phaser;
public class Main1 {
public static void main(String[] args) {
Phaser phaser = new Phaser(5);
for (int i = 0; i < 5; i++) {
new MyThread("线程-"+(i+1),phaser).start();
}
phaser.awaitAdvance(0);
System.out.println("线程运行结束");
}
}
import java.util.Random;
import java.util.concurrent.Phaser;
public class MyThread extends Thread {
private final Phaser phaser;
private final Random random = new Random();
MyThread(String name,Phaser phaser){
super(name);
this.phaser = phaser;
}
@Override
public void run() {
System.out.println(getName() + " - 开始向公司出发");
slowly();
System.out.println(getName() + " - 已经到达公司");
// 到达同步点,等待其他线程
phaser.arriveAndAwaitAdvance();
System.out.println(getName() + " - 开始笔试");
slowly();
System.out.println(getName() + " - 笔记结束");
// 到达同步点,等待其他线程
phaser.arriveAndAwaitAdvance();
System.out.println(getName() + " - 开始面试");
slowly();
System.out.println(getName() + " - 面试结束");
}
private void slowly() {
try {
Thread.sleep(random.nextInt(1000));
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
5.2,Phaser新特性
动态调整线程个数: CyclicBarrier所要同步的线程个数是在构造方法中指定的,之后不能修改。而Phaser可以再运行期间动态的调整要同步的线程个数。Phaser提供了下面这些方法来增加、减少所要同步的线程个数。
register() // 注册一个
bulkRegister(int parties) // 注册多个
arriveAndDeregister() // 解除注册
层次Phaser: 多个Phaser可以组成如下图所示的树状结构,可以通过在构造方法中传入父Phaser来实现。
public Phaser(Phaser parent, int parties) {
//...
}
先简单看一下Phaser内部关于树状结构的存储,如下所示:
public class Phaser{
private final Phaser parent;
...
}
可以发现,在Phaser的内部结构中,每个Phaser记录了自己的父节点,但没有记录自己的子节点列表。所以,每个Phaser知道自己的父节点是谁,但父节点并不知道自己有多少子节点,对父节点的操作,是通过子节点来实现的。
树状的Phaser怎么使用?如下代码可以组成如下图所示的树状Phaser。
Phaser root = new Phaser(2);
Phaser c1 = new Phaser(root, 3);
Phaser c2 = new Phaser(root, 2);
Phaser c3 = new Phaser(c1, 0);
本来root有两个参与者,然后为其加入了两个子Phaser(c1, c2),每个子Phaser会算作1个参与者,root的参与者就变成2+2=4个。c1本来有3个参与者,为其加入了一个子Phaser c3,参与者数量编程3+1=4个。c3的参与者初始为0,后续可以通过调用register()方法加入。
对于树状Phaser上的每个检点来说,可以当作一个独立的Phaser来看待,其运作机制和一个单独的Phaser是一样的。父Phaser并不用感知子Phaser的存在,当子Phaser中注册的参与者数量大于0时,会把自己向父节点注册;当子Phaser中注册的参与者数量等于0时,会自动向父节点解除注册。父Phaser把子Phaser当作一个正常参与的线程。
5.3,state变量解析
Phaser没有基于AQS来实现,但具备AQS的核心特性:state变量、CAS操作、阻塞队列。先从state变量说起。
private volatile long state;
这个64位的state变量被拆成4部分,下图为state变量各部分:
最高位0表示未同步,1表示同步完成,初始最高位为0。Phaser提供了一些列的成员方法来从state中获取上图中的几个数字,如下所示:
private static final int EMPTY = 1;
private static final int PHASE_SHIFT = 32;
private static final int PARTIES_SHIFT = 16;
private static final int UNARRIVED_MASK = 0xffff;
//获取当前的轮数
public final int getPhase() {
//当前轮数同步完成,返回值是一个负数(最高位为1)
//当前phase未完成,返回值是一个负数(最高位为1)
return (int)(root.state >>> PHASE_SHIFT);
}
//当前轮数同步完成,最高位为1
public boolean isTerminated() {
return root.state < 0L;
}
//获取总注册线程数
public int getRegisteredParties() {
return partiesOf(state);
}
//先把state转为32位int,再右移16位
private static int partiesOf(long s) {
return (int)s >>> PARTIES_SHIFT;
}
//获取未到达的线程数
public int getUnarrivedParties() {
return unarrivedOf(reconcileState());
}
//截取低16位
private static int unarrivedOf(long s) {
int counts = (int)s;
return (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
}
state变量在构造方法中赋值:
public Phaser(Phaser parent, int parties) {
if (parties >>> PARTIES_SHIFT != 0)
// 如果parties数超出了最大个数(2的16次方),抛异常
throw new IllegalArgumentException("Illegal number of parties");
// 初始化轮数为0
int phase = 0;
this.parent = parent;
if (parent != null) {
final Phaser root = parent.root;
// 父节点各根节点就是自己的根节点
this.root = root;
// 父节点的evenQ就是自己的evenQ
this.evenQ = root.evenQ;
// 父节点的oddQ就是自己的oddQ
this.oddQ = root.oddQ;
// 如果参与者不是0,则向附加点注册自己
if (parties != 0)
phase = parent.doRegister(1);
}
else {
// 如果父节点为null,则自己就是root节点
this.root = this;
// 创建奇数节点
this.evenQ = new AtomicReference<QNode>();
// 创建偶数节点
this.oddQ = new AtomicReference<QNode>();
}
this.state = (parties == 0) ? (long)EMPTY :
((long)phase << PHASE_SHIFT) | // 位或操作,赋值state。最高位为0,表示同步未完成
((long)parties << PARTIES_SHIFT) |
((long)parties);
}
- 当parties = 0时,state被赋予一个EMPTY常量,常量为1;
- 当parties != 0时,把phaser值左移32位;把parties左移16位;然后parties也作为最低的16位,3个值做或操作,赋值给state。
5.4,阻塞与唤醒(Treiber Stack)
基于上述的state变量,对其进行CAS操作,并进行相应的阻塞与唤醒。如下图所示,右边的主线程会调用awaitAdvance()进行阻塞;左边的arrive()会对state进行CAS的雷减操作,当未到达的线程数减到0,唤醒右边阻塞的主线程。
在这里,阻塞使用的是一个称为Treiber Stack的数据结构,而不是AQS的双向链表。Treiber Stack是一个无锁的栈,他是一个单向链表,出栈、入栈都在链表头部,所以只需要一个head指针,而不需要tail指针,如下:
static final class QNode implements ForkJoinPool.ManagedBlocker {
final Phaser phaser;
final int phase;
final boolean interruptible;
final boolean timed;
boolean wasInterrupted;
long nanos;
final long deadline;
volatile Thread thread; // 每个Node节点对应一个线程
QNode next; // 下一个节点的引用
...
}
//两个引用表示链表头部,可以避免线程冲突。
private final AtomicReference<QNode> evenQ;
private final AtomicReference<QNode> oddQ;
为了减少并发冲突,这里定义了2个链表,也就是2个Treiber Stack。当phase为奇数轮的时候,阻塞线程放在oddQ里面;当phase为偶数轮的时候,阻塞线程放在evenQ里面。代码如下:
private void releaseWaiters(int phase) {
QNode q; // first element of queue
Thread t; // its thread
//选择链表
AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
while ((q = head.get()) != null &&
q.phase != (int)(root.state >>> PHASE_SHIFT)) {
if (head.compareAndSet(q, q.next) &&
(t = q.thread) != null) {
q.thread = null;
LockSupport.unpark(t);
}
}
}
5.5,arrive()方法
arrive()方法是如何对state变量进行操作,又是如何唤醒线程的。
private static final int ONE_ARRIVAL = 1;
private static final int ONE_PARTY = 1 << PARTIES_SHIFT;
private static final int ONE_DEREGISTER = ONE_ARRIVAL|ONE_PARTY;
private static final int PARTIES_SHIFT = 16;
public int arrive() {
return doArrive(ONE_ARRIVAL);
}
public int arriveAndDeregister() {
return doArrive(ONE_DEREGISTER);
}
arrive()和arriveAndDeregister()内部调用的都是doArrive(boolean)方法。区别在于前者只是把“未达到线程数”减1;后者则把“未到达线程数”和“下一轮的总线程数”都减1。
doArrive(boolean)方法的实现:
private int doArrive(int adjust) {
final Phaser root = this.root;
for (;;) {
long s = (root == this) ? state : reconcileState();
int phase = (int)(s >>> PHASE_SHIFT);
if (phase < 0)
return phase;
int counts = (int)s;
// 获取未到达线程数
int unarrived = (counts == EMPTY) ? 0 : (counts & UNARRIVED_MASK);
// 如果未到达线程数小于等于0,抛异常。
if (unarrived <= 0)
throw new IllegalStateException(badArrive(s));
// CAS操作,将state的值减去adjust
if (STATE.compareAndSet(this, s, s-=adjust)) {
// 如果未达到线程为1
if (unarrived == 1) {
long n = s & PARTIES_MASK; // base of next state
int nextUnarrived = (int)n >>> PARTIES_SHIFT;
if (root == this) {
if (onAdvance(phase, nextUnarrived))
n |= TERMINATION_BIT;
else if (nextUnarrived == 0)
n |= EMPTY;
else
n |= nextUnarrived;
int nextPhase = (phase + 1) & MAX_PHASE;
n |= (long)nextPhase << PHASE_SHIFT;
STATE.compareAndSet(this, s, n);
releaseWaiters(phase);
}
// 如果下一轮的未到达线程数为0
else if (nextUnarrived == 0) { // propagate deregistration
phase = parent.doArrive(ONE_DEREGISTER);
STATE.compareAndSet(this, s, s | EMPTY);
}
else
// 否则调用父节点doArrive方法,传递参数1,表示当前节点已完成
phase = parent.doArrive(ONE_ARRIVAL);
}
return phase;
}
}
}
关于方面的方法,有以下几点说明:
- 定义两个常量:当deregister = false 时,只有最低的16位需要减1,adj=ONE_ARRIVAL;当deregister=true时,低32位中的2个16位都需要减1,adj=ONE_ARRIVAL|ONE_PARTY。
private static final int ONE_ARRIVAL = 1;
private static final int ONE_PARTY = 1 << PARTIES_SHIFT;
- 把未到达线程数减1:减了之后,如果还未到0,什么都不做,直接返回。如果到0,会做2件事:第一,重置state,把state的未到达线程个数重置到总的注册的线程数中,同时phase加1;第二,唤醒队列中的线程。
private void releaseWaiters(int phase) {
QNode q; // first element of queue
Thread t; // its thread
//选择链表
AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
//遍历整个栈,只要栈当中节点的phase不等于当前Phaser的phase,说明该节点不是当前轮的,而是前一轮的,应该被释放并唤醒。
while ((q = head.get()) != null &&
q.phase != (int)(root.state >>> PHASE_SHIFT)) {
if (head.compareAndSet(q, q.next) &&
(t = q.thread) != null) {
q.thread = null;
LockSupport.unpark(t);
}
}
}
5.6,awaitAdvance()方法
public int awaitAdvance(int phase) {
final Phaser root = this.root;
//当只有一个Phaser,没有树状结构时,root就是this。
long s = (root == this) ? state : reconcileState();
int p = (int)(s >>> PHASE_SHIFT);
if (phase < 0)
//phase已经结束,无需阻塞,直接返回。
return phase;
if (p == phase)
//阻塞在phase这一轮上
return root.internalAwaitAdvance(phase, null);
return p;
}
下面的while循环中有4个分支:
- 初始的时候,node==null,进入第一个分支进行自旋,自旋次数满足之后,会新建一个QNode节点;
- 之后执行第3、第4个分支,分别把该节点入栈并阻塞。
private int internalAwaitAdvance(int phase, QNode node) {
// assert root == this;
releaseWaiters(phase-1); // ensure old queue clean
boolean queued = false; // true when node is enqueued
int lastUnarrived = 0; // to increase spins upon change
int spins = SPINS_PER_ARRIVAL;
long s;
int p;
while ((p = (int)((s = state) >>> PHASE_SHIFT)) == phase) {
if (node == null) { // 不可中断模式的自旋
int unarrived = (int)s & UNARRIVED_MASK;
if (unarrived != lastUnarrived &&
(lastUnarrived = unarrived) < NCPU)
spins += SPINS_PER_ARRIVAL;
boolean interrupted = Thread.interrupted();
if (interrupted || --spins < 0) { // 自旋结束,建一个节点,之后进入阻塞
node = new QNode(this, phase, false, false, 0L);
node.wasInterrupted = interrupted;
}
else
Thread.onSpinWait();
}
else if (node.isReleasable()) // 从阻塞唤醒,退出while循环
break;
else if (!queued) { // push onto queue
AtomicReference<QNode> head = (phase & 1) == 0 ? evenQ : oddQ;
QNode q = node.next = head.get();
if ((q == null || q.phase == phase) &&
(int)(state >>> PHASE_SHIFT) == phase) // avoid stale enq
queued = head.compareAndSet(q, node); // 节点入栈
}
else {
try {
ForkJoinPool.managedBlock(node); // 调用node.block()阻塞
} catch (InterruptedException cantHappen) {
node.wasInterrupted = true;
}
}
}
if (node != null) {
if (node.thread != null)
node.thread = null; // avoid need for unpark()
if (node.wasInterrupted && !node.interruptible)
Thread.currentThread().interrupt();
if (p == phase && (p = (int)(state >>> PHASE_SHIFT)) == phase)
return abortWait(phase); // possibly clean up on abort
}
releaseWaiters(phase);
return p;
}
这里调用了ForkJoinPool.managedBlock(ManagedBlocker blocker)方法,目的是把node对应的线程阻塞。ManagedBlocker时ForkJoinPool里面的一个接口,定义如下:
public static interface ManagedBlocker {
boolean block() throws InterruptedException;
boolean isReleasable();
}
QNode实现了该接口,实现原理还是park(),如下所示。之所以没有直接使用park()/unpark()来实现阻塞、唤醒,而是封装了ManagedBlocker这一层,主要是处于使用上的方便考虑。一方面是park()可能被中断唤醒,另一方面是带超时时间的park(),把这二者都封装在一起。
static final class QNode implements ForkJoinPool.ManagedBlocker {
final Phaser phaser;
final int phase;
final boolean interruptible;
final boolean timed;
boolean wasInterrupted;
long nanos;
final long deadline;
volatile Thread thread; // nulled to cancel wait
QNode next;
QNode(Phaser phaser, int phase, boolean interruptible,
boolean timed, long nanos) {
this.phaser = phaser;
this.phase = phase;
this.interruptible = interruptible;
this.nanos = nanos;
this.timed = timed;
this.deadline = timed ? System.nanoTime() + nanos : 0L;
thread = Thread.currentThread();
}
public boolean isReleasable() {
if (thread == null)
return true;
if (phaser.getPhase() != phase) {
thread = null;
return true;
}
if (Thread.interrupted())
wasInterrupted = true;
if (wasInterrupted && interruptible) {
thread = null;
return true;
}
if (timed) {
if (nanos > 0L) {
nanos = deadline - System.nanoTime();
}
if (nanos <= 0L) {
thread = null;
return true;
}
}
return false;
}
public boolean block() {
if (isReleasable())
return true;
else if (!timed)
LockSupport.park(this);
else if (nanos > 0L)
LockSupport.parkNanos(this, nanos);
return isReleasable();
}
}