ForkJoinPool
ForkJoinPool是一种“分治算法”的多线程并行计算框架,自Java7引入。它将一个大的任务分为若干个子任务,这些子任务分别计算,然后合并出最终结果。
ForkJoinPool比普通的线程池可以更好地实现计算的负载均衡,提高资源利用率。
创建ForkJoinPool
构造方法
共有三个public的构造方法,最多的有4个参数,分别是
并行度、工作线程工厂,线程未捕获异常的处理器、工作队列模式(FIFO或LIFO,默认是LIFO)、工作线程名称前缀。一般在使用无参或一个参数的构造方法即可,(或者使用commonPool),如果需要定制线程继承ForkJoinWorkerThread,则使用4个参数的构造方法。
//ForkJoinPool.commonPool();
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism) {
this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism),
checkFactory(factory), handler, asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission();
}
private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode, String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
RecursiveAction和RecursiveTask
这两个类都是ForkJoinTask子类,用于实现子任务的逻辑。区别是
前者没有返回值,后者有返回值。使用时,针对不同类型的任务,可以分别继承这两个类,实现其compute方法。
使用
案例1(RecursiveAction):快速排序
基本思想:
1、利用数组的某个元素(一般取第一个元素)把数组划分成两半,左边子数组里面的元素小于等于该元素,右边子数组里面的元素大于等于该元素。
2、对左右的2个子数组分别排序。
将数组划分为两部分后,对子数组分别排序是独立的子问题,这个过程可以递归分解子问题,所以可以利用多个线程分别为子数组排序。
package com.example.demo;
import org.junit.Test;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.StringJoiner;
import java.util.concurrent.*;
import java.util.stream.Stream;
public class ForkJoinTest {
ForkJoinPool pool = ForkJoinPool.commonPool();
SecureRandom random = new SecureRandom();
@Test
public void testSort() throws ExecutionException, InterruptedException {
StringJoiner before = new StringJoiner(",","[","]");
//20个数的数组
int[] arr = getRandomIntArray(20);
Arrays.stream(arr).mapToObj(String::valueOf).forEach(before::add);
System.out.println(before.toString());
Instant start = Instant.now();
ForkJoinTask<Void> task = pool.submit(new QuickSortTask(arr));
//阻塞直到完成排序
task.get();
long i = Duration.between(start,Instant.now()).get(ChronoUnit.NANOS);
System.out.println("排序时间:" + i + "纳秒");
StringJoiner after = new StringJoiner(",","[","]");
Arrays.stream(arr).mapToObj(String::valueOf).forEach(after::add);
System.out.println(after.toString());
}
private int[] getRandomIntArray(int count) {
int bound = count * 10;
int[] array = new int[count];
for (int i = 0; i < count; i++){
array[i] = random.nextInt(bound);
}
return array;
}
public static class QuickSortTask extends RecursiveAction {
private int start;
private int end;
private int[] array;
public QuickSortTask(int[] array){
this.array = array;
this.start = 0;
this.end = array.length-1;
}
public QuickSortTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected void compute() {
int mid = part(array, start, end);
//当左边还有元素时
if (mid != start) {
QuickSortTask task1 = new QuickSortTask(array, start, mid - 1);
task1.fork();
task1.join();
}
//当右边还有元素时
if (mid != end) {
QuickSortTask task2 = new QuickSortTask(array, mid + 1, end);
task2.fork();
task2.join();
}
}
/**
* <p>返回基准值的下标,基准值左的元素都小于等于基准值,基准值右的元素大于等于基准值</p>
* @param array
* @param start
* @param end
* @return
*/
private int part(int[] array, int start, int end) {
int i = start, j = end;
//基准值的下标
int base = start;
//左右扫描相遇时结束
while (i < j) {
//从右向左扫描,如果当前值比基准值小,则置换,已经置换过的元素不再扫描(j的右边)
while (i < j && array[j] >= array[base]) {
j--;
}
if (i < j) {
swap(array, j, base);
base = j;
}
//从左向右扫描,如果当前值比基准值大,则置换,已经置换过的元素不再扫描(i的左边)
while (i < j && array[i] <= array[base]) {
i++;
}
if (i < j) {
swap(array,i, base);
base = i;
}
}
return base;
}
private void swap(int[] array, int x, int y) {
if (x != y) {
int temp = array[x];
array[x] = array[y];
array[y] = temp;
}
}
}
}
案例2(RecursiveTask):求和
例如从1加到100,如果不用高斯的方法,可以用程序实现累加,将数拆分成小组,每个小组互相独立,每个小组组内分别相加,最后把组的结果相加,这个过程可以使用ForkJoin。
RecursiveTask<T>
package com.example.demo;
import org.junit.Test;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.StringJoiner;
import java.util.concurrent.*;
import java.util.stream.Stream;
public class ForkJoinTest {
ForkJoinPool pool = ForkJoinPool.commonPool();
@Test
public void testSum() throws ExecutionException, InterruptedException {
ForkJoinTask<Integer> submit = pool.submit(new SumTask(1, 100));
System.out.println(submit.get());
}
public static class SumTask extends RecursiveTask<Integer>{
private int startNum;
private int endNum;
//决定当子任务处理的元素个数小于此值时不再切分任务,直接进行计算
private static final int THRESHOLD = 10;
public SumTask(int startNum, int endNum){
this.startNum = startNum;
this.endNum = endNum;
}
@Override
protected Integer compute() {
int sum = 0;
if (endNum - startNum + 1 < THRESHOLD){
for (int i = startNum; i <= endNum; i++){
sum += i;
}
return sum;
}
int mid = split(startNum, endNum);
SumTask task1 = new SumTask(startNum, mid);
SumTask task2 = new SumTask(mid + 1, endNum);
ForkJoinTask<Integer> fork1 = task1.fork();
ForkJoinTask<Integer> fork2 = task2.fork();
sum = fork1.join() + fork2.join();
return sum;
}
private int split(int startNum, int endNum) {
return (startNum + endNum)/2;
}
}
}
关闭
和ThreadPoolExecutor一样,ForkJoinPool使用完也要关闭,依然是使用shutdown和shutdownNow方法,shutdown只拒绝新提交的任务;shutdownNow会取消现有的全局队列和局部队列中的任务,同时唤醒所有空闲的线程,让这些线程自动退出。
public void shutdown();
public List<Runnable> shutdownNow();
ForkJoinPool pl=new ForkJoinPool();
try {
boolean flag;
do {
flag = pl.awaitTermination(500,TimeUnit.MILLISECONDS);
} while (!flag);
} catch (Exception e){
e.printStackTrace();
}