一、使用场景
用于CPU密集型的任务,通过把任务进行拆分,拆分成多个小任务去执行,然后小任务执行完毕后再把每个小任务执行的结果合并起来,这样就可以节省时间。
CPU密集型(CPU-bound):CPU密集型也叫计算密集型,指的是系统的硬盘、内存性能相对CPU要好很多,此时,系统运作大部分的状况是CPU Loading 100%,CPU要读/写I/O(硬盘/内存),I/O在很短的时间就可以完成,而CPU还有许多运算要处理,CPU Loading很高。
例如:大部份时间用来做计算、逻辑判断等CPU动作的程序称之CPU bound。
线程数一般设置为:线程数 = CPU核数+1 (现代CPU支持超线程)
IO密集型(I/O bound):IO密集型指的是系统的CPU性能相对硬盘、内存要好很多,此时,系统运作,大部分的状况是CPU在等I/O (硬盘/内存) 的读/写操作,此时CPU Loading并不高。
例如:读取本地文件、读取redis缓存、读取数据库等操作
线程数一般设置为:线程数 = ((线程等待时间+线程CPU时间)/线程CPU时间 )* CPU数目
二、简单使用
问题:计算1至10000000的正整数之和。
方案一:for循环解决
public static void main(String[] args) {
long sum = 0;
long start = System.currentTimeMillis();
for (int i = 1; i <= 10000000;i++) {
sum += i;
}
System.out.println("结果为:" + sum);
System.out.println("耗时为:" + (System.currentTimeMillis() - start));
}
方案二:采用并行流(JDK8以后)
public static void main(String[] args) {
long start = System.currentTimeMillis();
long sum = LongStream.rangeClosed(0, 10000000L).parallel().reduce(0, Long::sum);
System.out.println("结果为:" + sum);
System.out.println("耗时为:" + (System.currentTimeMillis() - start));
}
}
方案三:ExecutorService多线程方式实现
package concurrency.threadpool;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.stream.LongStream;
public class ExecutorTest {
int availableProcessors = Runtime.getRuntime().availableProcessors();
ExecutorService executorService = Executors.newFixedThreadPool(availableProcessors);
public static void main(String[] args) {
long[] nums = LongStream.rangeClosed(1, 10000000L).toArray();
long start = System.currentTimeMillis();
System.out.println("结果为:" + new ExecutorTest().sumVal(nums));
System.out.println("耗时为:" + (System.currentTimeMillis() - start));
}
private static class SumTask implements Callable<Long> {
private long[] nums;
private int from;
private int to;
public SumTask(long[] nums, int from, int to) {
this.nums = nums;
this.from = from;
this.to = to;
}
@Override
public Long call() throws Exception {
long sum = 0;
for (int i = from; i <= to; i++) {
sum += nums[i];
}
return sum;
}
}
private long sumVal(long[] nums) {
List<Future<Long>> results = new ArrayList<>();
int part = nums.length / availableProcessors;
for (int i = 0; i < availableProcessors; i++) {
int from = i * part;
int to = (i == availableProcessors - 1) ? nums.length - 1 : (i + 1) * part - 1;
results.add(executorService.submit(new SumTask(nums, from, to)));
}
long sum = 0;
for (Future<Long> future : results) {
try {
sum += future.get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
return sum;
}
}
方案四:采用ForkJoinPool(Fork/Join)
package concurrency.threadpool;
import java.util.concurrent.*;
import java.util.stream.LongStream;
public class ForkJoinTest {
int availableProcessors = Runtime.getRuntime().availableProcessors();
ForkJoinPool pool = new ForkJoinPool();
public static void main(String[] args) {
long[] nums = LongStream.rangeClosed(1, 10000000).toArray();
long start = System.currentTimeMillis();
System.out.println("结果为:" + new ForkJoinTest().sumVal(nums));
System.out.println("耗时为:" + (System.currentTimeMillis() - start));
}
private static class SumTask extends RecursiveTask<Long> {
private long[] nums;
private int from;
private int to;
public SumTask(long[] nums, int from, int to) {
this.nums = nums;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
// 当需要计算的数字个数小于6时,直接采用for loop方式计算结果
if (to - from < 6) {
long sum = 0;
for (int i = from; i <= to; i++) {
sum += nums[i];
}
return sum;
} else { // 否则,把任务一分为二,递归拆分(注意此处有递归)到底拆分成多少分 需要根据具体情况而定
int middle = (from + to) / 2;
SumTask taskLeft = new SumTask(nums, from, middle);
SumTask taskRight = new SumTask(nums, middle + 1, to);
taskLeft.fork();
taskRight.fork();
return taskLeft.join() + taskRight.join();
}
}
}
private long sumVal(long[] nums) {
Long result = pool.invoke(new SumTask(nums, 0, nums.length - 1));
pool.shutdown();
return result;
}
}
总结:
1.ForkJoinPool 不是为了替代 ExecutorService,而是它的补充,在某些应用场景下性能比 ExecutorService 更好(例:从数据库拉取了千亿万的数据到本地,然后进行排序 )。
2. ForkJoinPool 主要用于实现“分而治之”的算法,特别是分治之后递归调用的函数,例如 quick sort 等。
3. ForkJoinPool 最适合的是计算密集型的任务,如果存在 I/O,线程间同步,sleep() 等会造成线程长时间阻塞的情况时,最好配合使用 ManagedBlocker。
三、整体流程
1.任务入队
task1还能继续拆分,则调用fork方法进行拆分,
2.任务执行
worker-thread1的task1执行完成,出队
此时,worker-thread1会去问问worker-thread0是否需要帮忙,会从队头获取任务进行执行,而worker-thread0是从队尾获取任务进行执行。这就是“工作窃取算法”(工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。)。
四、源码解析
使用 ForkJoin 框架,必须首先创建一个 ForkJoin 任务。它提供在任务中执行 fork() 和 join() 操作的机制,通常情况下我们不需要直接继承 ForkJoinTask 类,而只需要继承它的子类,Fork/Join 框架提供了以下两个子类:
RecursiveAction:用于没有返回结果的任务。(比如写数据到磁盘,然后就退出了。 一个RecursiveAction可以把自己的工作分割成更小的几块, 这样它们可以由独立的线程或者CPU执行。 我们可以通过继承来实现一个RecursiveAction)
RecursiveTask :用于有返回结果的任务。(可以将自己的工作分割为若干更小任务,并将这些子任务的执行合并到一个集体结果。 可以有几个水平的分割和合并)
4-1 ForkJoinPool构造函数
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(‐parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
参数解释:
1>parallelism:并行度( the parallelism level),默认情况下跟我们机器的cpu个数保持一致,使用 Runtime.getRuntime().availableProcessors()可以得到我们机器运行时可用的CPU个数。
2>factory:创建新线程的工厂( the factory for creating new threads)。默认情况下使用
ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory。
3handler:线程异常情况下的处理器(Thread.UncaughtExceptionHandler handler),该处理器在线程执行任务时由于某些无法预料
到的错误而导致任务线程中断时进行一些处理,默认情况为null。
4>asyncMode:这个参数要注意,在ForkJoinPool中,每一个工作线程都有一个独立的任务队列,asyncMode表示工作线程内的任务队列是采用何种方式进行调度,可以是先进先出FIFO,也可以是后进先出LIFO。如果为true,则线程池中的工作线程则使用先进先出方式进行任务调度,默认情况下是false。
4-2 ForkJoinTask fork 方法
将任务推入当前工作线程的工作队列中。
public final ForkJoinTask<V> fork() {
Thread var1;
if ((var1 = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
((ForkJoinWorkerThread)var1).workQueue.push(this);
} else {
ForkJoinPool.common.externalPush(this);
}
return this;
}
package java.util.concurrent;
import java.lang.Thread.UncaughtExceptionHandler;
import java.security.AccessControlContext;
import java.security.CodeSource;
import java.security.PermissionCollection;
import java.security.ProtectionDomain;
import java.util.concurrent.ForkJoinPool.WorkQueue;
import sun.misc.Unsafe;
/**
* 线程池中的每个工作线程(ForkJoinWorkerThread)都有一个自己的任务队列(WorkQueue),工作线程优先处理自身队列中的任务(LIFO或FIFO顺序,由线程池构造时的参数 mode 决定),自身队列为空时,以FIFO的顺序随机窃取其它队列中的任务。
*/
public class ForkJoinWorkerThread extends Thread {
final ForkJoinPool pool; // 该工作线程归属的线程池
final WorkQueue workQueue; // 指定的队列
private static final Unsafe U;
private static final long THREADLOCALS;
private static final long INHERITABLETHREADLOCALS;
private static final long INHERITEDACCESSCONTROLCONTEXT;
protected ForkJoinWorkerThread(ForkJoinPool var1) {
super("aForkJoinWorkerThread"); // 指定工作线程名称
this.pool = var1;
this.workQueue = var1.registerWorker(this); // 将自己注册到线程池中
}
ForkJoinWorkerThread(ForkJoinPool var1, ThreadGroup var2, AccessControlContext var3) {
super(var2, (Runnable)null, "aForkJoinWorkerThread");
U.putOrderedObject(this, INHERITEDACCESSCONTROLCONTEXT, var3);
this.eraseThreadLocals();
this.pool = var1;
this.workQueue = var1.registerWorker(this);
}
public ForkJoinPool getPool() {
return this.pool;
}
public int getPoolIndex() {
return this.workQueue.getPoolIndex();
}
protected void onStart() {
}
protected void onTermination(Throwable var1) {
}
public void run() {
if (this.workQueue.array == null) {
Throwable var1 = null;
try {
// 空方法,待用户自己实现
this.onStart();
// 执行队列中的task任务
this.pool.runWorker(this.workQueue);
} catch (Throwable var40) {
var1 = var40;
} finally {
try {
// 空方法,待用户自己实现
this.onTermination(var1);
} catch (Throwable var41) {
if (var1 == null) {
var1 = var41;
}
} finally {
this.pool.deregisterWorker(this, var1);
}
}
}
}
final void eraseThreadLocals() {
U.putObject(this, THREADLOCALS, (Object)null);
U.putObject(this, INHERITABLETHREADLOCALS, (Object)null);
}
void afterTopLevelExec() {
}
static {
try {
U = Unsafe.getUnsafe();
Class var0 = Thread.class;
THREADLOCALS = U.objectFieldOffset(var0.getDeclaredField("threadLocals"));
INHERITABLETHREADLOCALS = U.objectFieldOffset(var0.getDeclaredField("inheritableThreadLocals"));
INHERITEDACCESSCONTROLCONTEXT = U.objectFieldOffset(var0.getDeclaredField("inheritedAccessControlContext"));
} catch (Exception var1) {
throw new Error(var1);
}
}
static final class InnocuousForkJoinWorkerThread extends ForkJoinWorkerThread {
private static final ThreadGroup innocuousThreadGroup = createThreadGroup();
private static final AccessControlContext INNOCUOUS_ACC = new AccessControlContext(new ProtectionDomain[]{new ProtectionDomain((CodeSource)null, (PermissionCollection)null)});
InnocuousForkJoinWorkerThread(ForkJoinPool var1) {
super(var1, innocuousThreadGroup, INNOCUOUS_ACC);
}
void afterTopLevelExec() {
this.eraseThreadLocals();
}
public ClassLoader getContextClassLoader() {
return ClassLoader.getSystemClassLoader();
}
public void setUncaughtExceptionHandler(UncaughtExceptionHandler var1) {
}
public void setContextClassLoader(ClassLoader var1) {
throw new SecurityException("setContextClassLoader");
}
private static ThreadGroup createThreadGroup() {
try {
Unsafe var0 = Unsafe.getUnsafe();
Class var1 = Thread.class;
Class var2 = ThreadGroup.class;
long var3 = var0.objectFieldOffset(var1.getDeclaredField("group"));
long var5 = var0.objectFieldOffset(var2.getDeclaredField("parent"));
ThreadGroup var8;
for(ThreadGroup var7 = (ThreadGroup)var0.getObject(Thread.currentThread(), var3); var7 != null; var7 = var8) {
var8 = (ThreadGroup)var0.getObject(var7, var5);
if (var8 == null) {
return new ThreadGroup(var7, "InnocuousForkJoinWorkerThreadGroup");
}
}
} catch (Exception var9) {
throw new Error(var9);
}
throw new Error("Cannot create ThreadGroup");
}
}
}
4-3 ForkJoinTask join 方法
public final V join() {
int var1;
if ((var1 = this.doJoin() & -268435456) != -268435456) {
this.reportException(var1);
}
return this.getRawResult();
}
private int doJoin() {
int var1;
Thread var2;
ForkJoinWorkerThread var3;
WorkQueue var4;
return (var1 = this.status) < 0 ? var1 : ((var2 = Thread.currentThread()) instanceof ForkJoinWorkerThread ? ((var4 = (var3 = (ForkJoinWorkerThread)var2).workQueue).tryUnpush(this) && (var1 = this.doExec()) < 0 ? var1 : var3.pool.awaitJoin(var4, this, 0L)) : this.externalAwaitDone());
}
private void reportException(int var1) {
if (var1 == -1073741824) {
throw new CancellationException();
} else {
if (var1 == -2147483648) {
rethrow(this.getThrowableException());
}
}
}
工作流程:
1.检查调用 join() 的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。
2. 查看任务的完成状态,如果已经完成,直接返回结果。
3. 如果任务尚未完成,但处于自己的工作队列内,则完成它。
4. 如果任务已经被其他的工作线程偷走,则窃取这个小偷的工作队列内的任务(以 FIFO 方式),执行,以期帮助它早日完成欲 join 的任务。
5. 如果偷走任务的小偷也已经把自己的任务全部做完,正在等待需要 join 的任务时,则找到小偷的小偷,帮助它完成它的任务。
6. 递归地执行第5步。