• 死磕 java线程系列之ForkJoinPool深入解析


    forkjoinpool

    (手机横屏看源码更方便)


    注:java源码分析部分如无特殊说明均基于 java8 版本。

    注:本文基于ForkJoinPool分治线程池类。

    简介

    随着在硬件上多核处理器的发展和广泛使用,并发编程成为程序员必须掌握的一门技术,在面试中也经常考查面试者并发相关的知识。

    今天,我们就来看一道面试题:

    如何充分利用多核CPU,计算很大数组中所有整数的和?

    剖析

    • 单线程相加?

    我们最容易想到就是单线程相加,一个for循环搞定。

    • 线程池相加?

    如果进一步优化,我们会自然而然地想到使用线程池来分段相加,最后再把每个段的结果相加。

    • 其它?

    Yes,就是我们今天的主角——ForkJoinPool,但是它要怎么实现呢?似乎没怎么用过哈^^

    三种实现

    OK,剖析完了,我们直接来看三种实现,不墨迹,直接上菜。

    /**
     * 计算1亿个整数的和
     */
    public class ForkJoinPoolTest01 {
        public static void main(String[] args) throws ExecutionException, InterruptedException {
            // 构造数据
            int length = 100000000;
            long[] arr = new long[length];
            for (int i = 0; i < length; i++) {
                arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
            }
            // 单线程
            singleThreadSum(arr);
            // ThreadPoolExecutor线程池
            multiThreadSum(arr);
            // ForkJoinPool线程池
            forkJoinSum(arr);
    
        }
    
        private static void singleThreadSum(long[] arr) {
            long start = System.currentTimeMillis();
    
            long sum = 0;
            for (int i = 0; i < arr.length; i++) {
                // 模拟耗时,本文由公从号“彤哥读源码”原创
                sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
            }
    
            System.out.println("sum: " + sum);
            System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));
    
        }
    
        private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
            long start = System.currentTimeMillis();
    
            int count = 8;
            ExecutorService threadPool = Executors.newFixedThreadPool(count);
            List<Future<Long>> list = new ArrayList<>();
            for (int i = 0; i < count; i++) {
                int num = i;
                // 分段提交任务
                Future<Long> future = threadPool.submit(() -> {
                    long sum = 0;
                    for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
                        try {
                            // 模拟耗时
                            sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                    return sum;
                });
                list.add(future);
            }
    
            // 每个段结果相加
            long sum = 0;
            for (Future<Long> future : list) {
                sum += future.get();
            }
    
            System.out.println("sum: " + sum);
            System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
        }
    
        private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
            long start = System.currentTimeMillis();
    
            ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
            // 提交任务
            ForkJoinTask<Long> forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
            // 获取结果
            Long sum = forkJoinTask.get();
    
            forkJoinPool.shutdown();
    
            System.out.println("sum: " + sum);
            System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
        }
    
        private static class SumTask extends RecursiveTask<Long> {
            private long[] arr;
            private int from;
            private int to;
    
            public SumTask(long[] arr, int from, int to) {
                this.arr = arr;
                this.from = from;
                this.to = to;
            }
    
            @Override
            protected Long compute() {
                // 小于1000的时候直接相加,可灵活调整
                if (to - from <= 1000) {
                    long sum = 0;
                    for (int i = from; i < to; i++) {
                        // 模拟耗时
                        sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
                    }
                    return sum;
                }
    
                // 分成两段任务,本文由公从号“彤哥读源码”原创
                int middle = (from + to) / 2;
                SumTask left = new SumTask(arr, from, middle);
                SumTask right = new SumTask(arr, middle, to);
    
                // 提交左边的任务
                left.fork();
                // 右边的任务直接利用当前线程计算,节约开销
                Long rightResult = right.compute();
                // 等待左边计算完毕
                Long leftResult = left.join();
                // 返回结果
                return leftResult + rightResult;
            }
        }
    }
    

    彤哥偷偷地告诉你,实际上计算1亿个整数相加,单线程是最快的,我的电脑大概是100ms左右,使用线程池反而会变慢。

    所以,为了演示ForkJoinPool的牛逼之处,我把每个数都/3*3/3*3/3*3/3*3/3*3了一顿操作,用来模拟计算耗时。

    来看结果:

    sum: 107352457433800662
    single thread elapse: 789
    sum: 107352457433800662
    multi thread elapse: 228
    sum: 107352457433800662
    fork join elapse: 189
    

    可以看到,ForkJoinPool相对普通线程池还是有很大提升的。

    问题:普通线程池能否实现ForkJoinPool这种计算方式呢,即大任务拆中任务,中任务拆小任务,最后再汇总?

    forkjoinpool

    你可以试试看(-᷅_-᷄)

    OK,下面我们正式进入ForkJoinPool的解析。

    分治法

    • 基本思想

    把一个规模大的问题划分为规模较小的子问题,然后分而治之,最后合并子问题的解得到原问题的解。

    • 步骤

    (1)分割原问题:

    (2)求解子问题:

    (3)合并子问题的解为原问题的解。

    在分治法中,子问题一般是相互独立的,因此,经常通过递归调用算法来求解子问题。

    • 典型应用场景

    (1)二分搜索

    (2)大整数乘法

    (3)Strassen矩阵乘法

    (4)棋盘覆盖

    (5)归并排序

    (6)快速排序

    (7)线性时间选择

    (8)汉诺塔

    ForkJoinPool继承体系

    ForkJoinPool是 java 7 中新增的线程池类,它的继承体系如下:

    forkjoinpool

    ForkJoinPool和ThreadPoolExecutor都是继承自AbstractExecutorService抽象类,所以它和ThreadPoolExecutor的使用几乎没有多少区别,除了任务变成了ForkJoinTask以外。

    这里又运用到了一种很重要的设计原则——开闭原则——对修改关闭,对扩展开放。

    可见整个线程池体系一开始的接口设计就很好,新增一个线程池类,不会对原有的代码造成干扰,还能利用原有的特性。

    ForkJoinTask

    两个主要方法

    • fork()

    fork()方法类似于线程的Thread.start()方法,但是它不是真的启动一个线程,而是将任务放入到工作队列中。

    • join()

    join()方法类似于线程的Thread.join()方法,但是它不是简单地阻塞线程,而是利用工作线程运行其它任务。当一个工作线程中调用了join()方法,它将处理其它任务,直到注意到目标子任务已经完成了。

    三个子类

    • RecursiveAction

    无返回值任务。

    • RecursiveTask

    有返回值任务。

    • CountedCompleter

    无返回值任务,完成任务后可以触发回调。

    ForkJoinPool内部原理

    ForkJoinPool内部使用的是“工作窃取”算法实现的。

    forkjoinpool

    (1)每个工作线程都有自己的工作队列WorkQueue;

    (2)这是一个双端队列,它是线程私有的;

    (3)ForkJoinTask中fork的子任务,将放入运行该任务的工作线程的队头,工作线程将以LIFO的顺序来处理工作队列中的任务;

    (4)为了最大化地利用CPU,空闲的线程将从其它线程的队列中“窃取”任务来执行;

    (5)从工作队列的尾部窃取任务,以减少竞争;

    (6)双端队列的操作:push()/pop()仅在其所有者工作线程中调用,poll()是由其它线程窃取任务时调用的;

    (7)当只剩下最后一个任务时,还是会存在竞争,是通过CAS来实现的;

    forkjoinpool

    ForkJoinPool最佳实践

    (1)最适合的是计算密集型任务,本文由公从号“彤哥读源码”原创;

    (2)在需要阻塞工作线程时,可以使用ManagedBlocker;

    (3)不应该在RecursiveTask的内部使用ForkJoinPool.invoke()/invokeAll();

    总结

    (1)ForkJoinPool特别适合于“分而治之”算法的实现;

    (2)ForkJoinPool和ThreadPoolExecutor是互补的,不是谁替代谁的关系,二者适用的场景不同;

    (3)ForkJoinTask有两个核心方法——fork()和join(),有三个重要子类——RecursiveAction、RecursiveTask和CountedCompleter;

    (4)ForkjoinPool内部基于“工作窃取”算法实现;

    (5)每个线程有自己的工作队列,它是一个双端队列,自己从队列头存取任务,其它线程从尾部窃取任务;

    (6)ForkJoinPool最适合于计算密集型任务,但也可以使用ManagedBlocker以便用于阻塞型任务;

    (7)RecursiveTask内部可以少调用一次fork(),利用当前线程处理,这是一种技巧;

    彩蛋

    ManagedBlocker怎么使用?

    答:ManagedBlocker相当于明确告诉ForkJoinPool框架要阻塞了,ForkJoinPool就会启另一个线程来运行任务,以最大化地利用CPU。

    请看下面的例子,自己琢磨哈^^。

    /**
     * 斐波那契数列
     * 一个数是它前面两个数之和
     * 1,1,2,3,5,8,13,21
     */
    public class Fibonacci {
    
    	public static void main(String[] args) {
    		long time = System.currentTimeMillis();
            Fibonacci fib = new Fibonacci();
    		int result = fib.f(1_000).bitCount();
    		time = System.currentTimeMillis() - time;
    		System.out.println("result,本文由公从号“彤哥读源码”原创 = " + result);
            System.out.println("test1_000() time = " + time);
    	}
    
    	public BigInteger f(int n) {
    		Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();
    		cache.put(0, BigInteger.ZERO);
    		cache.put(1, BigInteger.ONE);
    		return f(n, cache);
    	}
    
    	private final BigInteger RESERVED = BigInteger.valueOf(-1000);
    
    	public BigInteger f(int n, Map<Integer, BigInteger> cache) {
    		BigInteger result = cache.putIfAbsent(n, RESERVED);
    		if (result == null) {
    
    			int half = (n + 1) / 2;
    
    			RecursiveTask<BigInteger> f0_task = new RecursiveTask<BigInteger>() {
    				@Override
    				protected BigInteger compute() {
    					return f(half - 1, cache);
    				}
    			};
    			f0_task.fork();
    
    			BigInteger f1 = f(half, cache);
    			BigInteger f0 = f0_task.join();
    
    			long time = n > 10_000 ? System.currentTimeMillis() : 0;
    			try {
    
    				if (n % 2 == 1) {
    					result = f0.multiply(f0).add(f1.multiply(f1));
    				} else {
    					result = f0.shiftLeft(1).add(f1).multiply(f1);
    				}
    				synchronized (RESERVED) {
    					cache.put(n, result);
    					RESERVED.notifyAll();
    				}
    			} finally {
    				time = n > 10_000 ? System.currentTimeMillis() - time : 0;
    				if (time > 50)
    					System.out.printf("f(%d) took %d%n", n, time);
    			}
    		} else if (result == RESERVED) {
    			try {
    				ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache);
    				ForkJoinPool.managedBlock(blocker);
    				result = blocker.result;
    			} catch (InterruptedException e) {
    				throw new CancellationException("interrupted");
    			}
    
    		}
    		return result;
    		// return f(n - 1).add(f(n - 2));
    	}
    
    	private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker {
    		private BigInteger result;
    		private final int n;
    		private final Map<Integer, BigInteger> cache;
    
    		public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {
    			this.n = n;
    			this.cache = cache;
    		}
    
    		@Override
    		public boolean block() throws InterruptedException {
    			synchronized (RESERVED) {
    				while (!isReleasable()) {
    					RESERVED.wait();
    				}
    			}
    			return true;
    		}
    
    		@Override
    		public boolean isReleasable() {
    			return (result = cache.get(n)) != RESERVED;
    		}
    	}
    }
    

    欢迎关注我的公众号“彤哥读源码”,查看更多源码系列文章, 与彤哥一起畅游源码的海洋。

    qrcode

  • 相关阅读:
    excel 2003系列
    DataTab转换XML XML转换DataTable 的类[转]
    全角转半角与半角转全角
    Day2
    Day6 && Day7图论
    Day1
    Android为何以及如何保存Fragment实例
    Android应用的本地化及知识拓展之配置修饰符
    Leetcode NO.136 Single Number 只出现一次的数字
    经典排序算法(四) —— Quick Sort 快速排序
  • 原文地址:https://www.cnblogs.com/tong-yuan/p/11824018.html
Copyright © 2020-2023  润新知