概述
Fork/Join是Java7提供的一个用于并行执行任务的框架,是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork负责把一个大任务切分为若干并行执行的子任务,Join负责合并这些子任务的执行结果,最后得到这个大任务的结果。
使用类似MapReduce的分治思想,先分割再合并,故而包括两个关键步骤:
- 分割任务。需要有一个fork类来把大任务分割成子任务,有可能子任务还是很大,所以还需要不停的分割,直到分割出的子任务足够小。
- 执行任务并合并结果。分割的子任务分别放在双端队列里,然后几个启动线程分别从双端队列里获取任务执行。子任务执行完的结果都统一放在一个队列里,启动一个线程从队列里拿数据,然后合并这些数据。
Fork/Join使用两个类来完成以上两件事情:
- ForkJoinTask:要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供以下两个子类:
- RecursiveAction:用于没有返回结果的任务。
- RecursiveTask:用于有返回结果的任务。
- ForkJoinPool:ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。
ForkJoinTask与一般的任务的主要区别在于它需要实现compute方法,在这个方法里,首先需要判断任务是否足够小,如果足够小就直接执行任务。如果不足够小,就必须分割成两个子任务,每个子任务在调用fork方法时,又会进入compute方法,看看当前子任务是否需要继续分割成孙任务,如果不需要继续分割,则执行当前子任务并返回结果。使用join方法会等待子任务执行完并得到其结果。
实例
public class ForkJoinMergeSort {
public static void main(String[] args) {
int length = 200000000;
int[] array = new int[length];
Random random = new Random();
for (int i = 0; i < length; i++) {
int x = random.nextInt(1000000000);
array[i] = x;
}
long start = System.currentTimeMillis();
int[] tmp = new int[length];
MergeSort task = new MergeSort(array, tmp, 0, length - 1);
// fork/join框架测试
ForkJoinPool pool = new ForkJoinPool(8);
pool.invoke(task);
if (task.isCompletedNormally()) {
System.out.println("fork/join timing: " + (System.currentTimeMillis() - start) + "ms");
}
// 单线程测试
task.sort(array, tmp, 0, length-1);
System.out.println("single thread timing: " + (System.currentTimeMillis() - start) + "ms");
}
}
@AllArgsConstructor
class MergeSort extends RecursiveAction {
private final int[] array;
private final int[] tmp;
private final int first;
private final int last;
@Override
protected void compute() {
// 当排序项分解成少于1000时直接执行归并排序算法
if (last - first < 1000) {
sort(array, tmp, first, last);
} else {
// 当排序项大于1000时,将数组分成两部分(由框架根据条件自动递归分解,直到项数少于1000为止)
int middle = (first + last) / 2;
MergeSort t1 = new MergeSort(array, tmp, first, middle);
MergeSort t2 = new MergeSort(array, tmp, middle + 1, last);
invokeAll(t1, t2);
// 递归归并排序被分解的两组数字
merge(array, tmp, first, middle + 1, last);
}
}
public void sort(int[] array, int[] tmp, int first, int last) {
if (first < last) {
int middle = (first + last) / 2;
sort(array, tmp, first, middle);
sort(array, tmp, middle + 1, last);
merge(array, tmp, first, middle + 1, last);
}
}
private void merge(int[] array, int[] tmp, int leftStart, int rightStart, int rightEnd) {
int leftEnd = rightStart - 1;
int tmpPos = leftStart;
int total = rightEnd - leftStart + 1;
while (leftStart <= leftEnd && rightStart <= rightEnd) {
if (array[leftStart] <= array[rightStart]) {
tmp[tmpPos++] = array[leftStart++];
} else {
tmp[tmpPos++] = array[rightStart++];
}
}
while (leftStart <= leftEnd) {
tmp[tmpPos++] = array[leftStart++];
}
while (rightStart <= rightEnd) {
tmp[tmpPos++] = array[rightStart++];
}
for (int i = 0; i < total; i++, rightEnd--) {
array[rightEnd] = tmp[rightEnd];
}
}
}
输出:
fork/join timing: 19109ms
single thread timing: 38884ms
原理
并行度
即parallelism参数,也就是初始化ForkJoinPool时传入的参数,对应的构造函数如下:
public ForkJoinPool(int parallelism) {
this(parallelism, defaultForkJoinWorkerThreadFactory, null, false, 0, MAX_CAP, 1, null, DEFAULT_KEEPALIVE, TimeUnit.MILLISECONDS);
}
parallelism参数指定线程池中可用的最大并行线程数,基本上就是允许同时运行的线程数量,会影响任务的并行度和性能。
提高parallelism的直观预期是提高并行度,从而减少任务执行时间,但实际情况可能并非如此,原因有以下几点:
- 任务分解粒度:如果任务过于简单,分解为多个子任务时可能不值得产生额外的线程开销。每个线程的上下文切换和管理都有成本,线程过多反而会导致性能下降;
- 任务间依赖:如果任务之间有依赖关系,增加并行度可能导致线程等待其他线程完成,从而减少并行执行的效果;
- 资源竞争:如果任务频繁访问共享资源(如内存、文件系统等),会导致线程间的竞争和锁争用,降低整体性能;
- CPU核心数限制:parallelism不应超过可用CPU核心数。过多的线程会导致线程上下文切换,降低效率。理想情况下,parallelism设置为核心数的1到2倍通常能获得最佳性能;
- 工作窃取机制:ForkJoinPool使用工作窃取算法来分配任务。增加parallelism可能会导致一些线程空闲,而其他线程忙于执行任务。
工作窃取
work-stealing,采用工作窃取算法来实现,核心是指某个线程从其他队列里窃取任务来执行。为减少线程间的竞争,会把这些子任务分别放到不同的队列里,然后为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应。假设此时A线程已经把自己A队列里的所有子任务执行完毕,而B线程还没将对应的B队列里子任务执行完毕,此时A线程会去B线程的队列里窃取一个子任务来执行。在窃取操作中由于A、B线程同时访问B线程对应的子任务队列,为减少AB之间的竞争,通常使用双端队列,B永远从双端队列的头部获取任务执行,而A从尾部获取任务执行。
优势:充分利用线程进行并行计算,提高执行效率,以空间换时间。
缺点:
- 某些情况下存在竞争,如双端队列只剩一个任务时;
- 消耗更多的系统资源(创建多个线程和多个双端队列)。
核心类
- ForkJoinPool
ForkJoinPool自动计算线程池参数,且提供工作窃取算法来管理这些任务。如果有线程空闲,则会从其它线程的队列尾中窃取一个任务给空闲线程来运行。ForkJoinPool由ForkJoinTask数组和ForkJoinWorkerThread数组组成,ForkJoinTask数组负责存放程序提交给ForkJoinPool的任务,而ForkJoinWorkerThread数组负责执行这些任务。 - ForkJoinTask:抽象泛型类。是Fork/Join任务的一个抽象,你需要继承此类,然后定义自己的计算逻辑。任务的创建就是通过此类中的fork()方法来实现的。这里说的任务几乎类似Thread类创建的那些普通线程,但更轻量级。因为它可以使用ForkJoinPool中少量有限的线程来管理大量的任务,所以它要比Thread类创建的线程更轻量。fork()方法异步执行任务,join方法可以一直等待到任务执行完毕。invoke()方法把fork和join两个操作合二为一成一个单独的调用。代表fork/join里面任务类型,一般用它的两个子类RecursiveTask、RecursiveAction。任务的处理逻辑包括任务的切分都集中在compute()方法里面。
- RecursiveTask:ForkJoinTask的子类,也是抽象泛型类。通过重载RecursiveTask类的compute方法来实现Fork-Join的逻辑。在compute方法里,要实现两件事,Fork-Join就是要先fork出RecursiveTask对象的子任务,然后将它们join在一起。有返回值。
- RecursiveAction:没有返回值
- ForkJoinPool.WorkQueue:任务队列
- ForkJoinWorkerThread:fork/join里面真正干活的线程,里面有一个ForkJoinPool.WorkQueue的队列存放着它要干的活,接活之前它要向ForkJoinPool注册(registerWorker),拿到相应的workQueue。然后就从workQueue里面拿任务出来处理。依附于ForkJoinPool而存活,如果ForkJoinPool销毁,它也会跟着结束。
源码分析
源码基于JDK22。
任务状态
任务状态有四种:DONE(已完成,不一定是成功),ABNORMAL(不正常),和THROWN(出现异常),HAVE_EXCEPTION(异常)。对应源码如下:
static final int DONE = 1 << 31; // must be negative
static final int ABNORMAL = 1 << 16;
static final int THROWN = 1 << 17;
static final int HAVE_EXCEPTION = DONE | ABNORMAL | THROWN;
fork
ForkJoinTask的fork方法源码如下:
public final ForkJoinTask<V> fork() {
Thread t;ForkJoinWorkerThread wt;
ForkJoinPool p; ForkJoinPool.WorkQueue q; boolean internal;
// 先判断当前线程是否是ForkJoinWorkerThread的实例,是则将任务push到当前线程所维护的双端队列中
if (internal = (t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
q = (wt = (ForkJoinWorkerThread)t).workQueue;
p = wt.pool;
}
else
q = (p = ForkJoinPool.common).externalSubmissionQueue();
q.push(this, p, internal);
return this;
}
解析:调用fork方法时,会调用ForkJoinPool.WorkQueue的push方法将任务放进队列,然后立即返回结果。push方法源码如下:
final void push(ForkJoinTask<?> task, ForkJoinPool pool, boolean internal) {
int s = top, b = base, cap, m, p, room, newCap; ForkJoinTask<?>[] a;
if ((a = array) == null || (cap = a.length) <= 0 || (room = (m = cap - 1) - (s - b)) < 0) {
// could not resize
if (!internal)
unlockPhase();
throw new RejectedExecutionException("Queue capacity exceeded");
}
top = s + 1;
long pos = slotOffset(p = m & s);
if (!internal)
U.putReference(a, pos, task); // inside lock
else
U.getAndSetReference(a, pos, task); // fully fenced
if (room == 0 && (newCap = cap << 1) > 0) {
ForkJoinTask<?>[] newArray = null;
try { // resize for next time
newArray = new ForkJoinTask<?>[newCap];
} catch (OutOfMemoryError ex) {
}
if (newArray != null) { // else throw on next push
int newMask = newCap - 1; // poll old, push to new
for (int k = s, j = cap; j > 0; --j, --k) {
ForkJoinTask<?> u;
if ((u = (ForkJoinTask<?>)U.getAndSetReference(a, slotOffset(k & m), null)) == null)
break; // lost to pollers
newArray[k & newMask] = u;
}
updateArray(newArray); // fully fenced
}
a = null; // always signal
}
if (!internal)
unlockPhase();
if ((a == null || a[m & (s - 1)] == null) && pool != null)
pool.signalWork(a, p);
}
push方法把当前任务存放在ForkJoinPool.WorkQueue里的ForkJoinTask<?>
数组,然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务。ForkJoinPool.signalWork()方法如下:
final void signalWork(ForkJoinTask<?>[] a, int k) {
int pc = parallelism;
for (long c = ctl;;) {
WorkQueue[] qs = queues;
long ac = (c + RC_UNIT) & RC_MASK, nc;
int sp = (int)c, i = sp & SMASK;
if (qs == null || qs.length <= i)
break;
WorkQueue w = qs[i], v = null;
if (sp == 0) {
if ((short)(c >>> TC_SHIFT) >= pc)
break;
nc = ((c + TC_UNIT) & TC_MASK);
}
else if ((short)(c >>> RC_SHIFT) >= pc || (v = w) == null)
break;
else
nc = (v.stackPred & LMASK) | (c & TC_MASK);
if (c == (c = compareAndExchangeCtl(c, nc | ac))) {
if (v == null)
createWorker();
else {
v.phase = sp;
if (v.parking != 0)
U.unpark(v.owner);
}
break;
}
if (a != null && k >= 0 && k < a.length && a[k] == null)
break;
}
}
join
ForkJoinTask的join方法源码如下:
public final V join() {
int s;
if ((((s = status) < 0 ? s : awaitDone(false, 0L)) & ABNORMAL) != 0)
reportException(false);
return getRawResult();
}
就一行代码,如果状态不为0,表示异常,调用方法reportException。awaitDone表示阻塞当前线程并等待获取结果。
private int awaitDone(boolean interruptible, long deadline) {
ForkJoinWorkerThread wt; ForkJoinPool p; ForkJoinPool.WorkQueue q;
Thread t; boolean internal; int s;
if (internal = (t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
p = (wt = (ForkJoinWorkerThread)t).pool;
q = wt.workQueue;
}
else
q = ForkJoinPool.externalQueue(p = ForkJoinPool.common);
return (((s = (p == null) ? 0 :
((this instanceof CountedCompleter) ?
p.helpComplete(this, q, internal) :
(this instanceof InterruptibleTask) && !internal ? status :
p.helpJoin(this, q, internal))) < 0)) ? s :
awaitDone(internal ? p : null, s, interruptible, deadline);
}
ForkJoinPool的helpComplete和helpJoin两个方法可以看到工作窃取算法的思想,源码优点长,略。大致思路是通过for循环加if条件判断来分担任务。
异常处理
ForkJoinTask在执行时可能会抛出异常,但是没办法在主线程里直接捕获异常,ForkJoinTask提供isCompletedAbnormally()
方法来检查任务是否已经抛出异常或已经被取消,可通过ForkJoinTask.getException()
方法获取异常:
if(task.isCompletedAbnormally()) {
System.out.println(task.getException());
}
getException方法返回Throwable对象,如果任务被取消则返回CancellationException。如果任务没有完成或者没有抛出异常则返回null。
适用场景
分治法非常适合解决以下问题:
- 二分搜索
- 大整数乘法
- Strassen矩阵乘法
- 棋盘覆盖
- 合并排序
- 快速排序
- 线性时间选择
- 汉诺塔
在真实业务开发中,也有很多场景:
- 做报表导出时,大量数据的导出处理;
- 做BI时,大量的数据迁移清洗作业等。
参考
- fork/join全面剖析
- how-to-specify-forkjoinpool-for-java-8-parallel-stream