package example.test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

/**
 * Micro-benchmark comparing two ways of summing an array of boxed {@code Long} values in
 * parallel: an explicit fork/join {@link SumTask} and a parallel-stream reduction.
 *
 * <p>Sums of random {@code long} values will usually overflow. Both approaches use plain
 * {@code long} addition, which wraps in two's complement, so they still print the same result.
 */
public class TSKK {

    /** Builds a 40,000,000-element array of random values and runs both benchmarks on it. */
    public static void main(String[] args) throws Exception {
        // Create an array filled with random numbers.
        Long[] array = new Long[40000000];
        fillRandom(array);
        test01(array);
        test02(array);
    }

    /**
     * Sums {@code array} with a {@link SumTask} on a dedicated pool of parallelism 4 and prints
     * the result and the elapsed time in milliseconds.
     *
     * @param array values to sum; every element must be non-null
     */
    public static void test01(Long[] array) {
        ForkJoinPool fjp = new ForkJoinPool(4); // parallelism (max concurrent workers) of 4
        try {
            ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
            // nanoTime is monotonic and meant for elapsed time; currentTimeMillis can jump.
            long startNanos = System.nanoTime();
            Long result = fjp.invoke(task);
            long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
            System.out.println("Fork/join sum: " + result + " in " + elapsedMillis + " ms.");
        } finally {
            // The pool is private to this call; release its worker threads when done.
            fjp.shutdown();
        }
    }

    /**
     * Sums {@code array} with a parallel stream reduction and prints the result and the elapsed
     * time in milliseconds.
     *
     * <p>NOTE(review): the stream is not submitted to an explicit pool, so its parallelism is not
     * capped at 4 as in {@link #test01}. The two timings are therefore not strictly comparable.
     *
     * @param array values to sum; every element must be non-null
     */
    public static void test02(Long[] array) {
        // A fixed-size view over the array; copying every element into a new list is unnecessary
        // just to obtain a stream.
        List<Long> values = Arrays.asList(array);
        long startNanos = System.nanoTime();
        Long result = values.stream().parallel().reduce(0L, Long::sum);
        long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
        System.out.println("stream/join sum: " + result + " in " + elapsedMillis + " ms.");
    }

    /**
     * Overwrites every slot of {@code array} with a pseudo-random {@code long}.
     *
     * @param array the array to fill in place
     */
    public static void fillRandom(Long[] array) {
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextLong();
        }
    }
}

/**
 * Fork/join task that sums the elements of {@code array} in the half-open range
 * {@code [start, end)}.
 *
 * <p>Ranges of at most {@link #THRESHOLD} elements are summed sequentially. Larger ranges are
 * split in half, and both halves are run through {@link #invokeAll(ForkJoinTask, ForkJoinTask)}.
 */
@SuppressWarnings("serial")
class SumTask extends RecursiveTask<Long> {

    /** Largest range length that is summed directly instead of being split further. */
    static final int THRESHOLD = 100000;

    final Long[] array;
    final int start; // inclusive
    final int end; // exclusive

    SumTask(Long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // Small enough: sum directly into a primitive accumulator.
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        // Too large: split in two. The unsigned shift cannot overflow the way (start + end) / 2 can.
        int middle = (start + end) >>> 1;
        SumTask left = new SumTask(array, start, middle);
        SumTask right = new SumTask(array, middle, end);
        invokeAll(left, right);
        return left.join() + right.join();
    }
}