Skip to content

并发编程之 ForkJoin 分而治之

分而治之

forkjoin 在处理某一类问题时非常的有用,哪一类问题?分而治之的问题。

十大计算机经典算法:快速排序、堆排序、归并排序、二分查找、线性查找、 深度优先、广度优先、Dijkstra、动态规划、朴素贝叶斯分类,有几个属于分 而治之?3 个,快速排序、归并排序、二分查找,还有大数据中 M/R 都是。

分治法的设计思想是:将一个难以直接解决的大问题,分割成一些规模较小的相同问题,以便各个击破,分而治之。

分治策略是:对于一个规模为 n 的问题,若该问题可以容易地解决(比如说 规模 n 较小)则直接解决,否则将其分解为 k 个规模较小的子问题,这些子问题互相独立且与原问题形式相同(子问题相互之间有联系就会变为动态规范算法),递归地解这些子问题,然后将各子问题的解合并得到原问题的解。这种算法设计策略叫做分治法。

Fork/Join 使用的标准范式

我们要使用 ForkJoin 框架,必须首先创建一个 ForkJoin 任务。它提供在任务 中执行 fork 和 join 的操作机制,通常我们不直接继承 ForkjoinTask 类,只需要直接继承其子类。

  1. RecursiveAction,用于没有返回结果的任务
  2. RecursiveTask,用于有返回值的任务

task 要通过 ForkJoinPool 来执行,使用 submit 或 invoke 提交,两者的区别是:invoke 是同步执行,调用之后需要等待任务完成,才能执行后面的代码; submit 是异步执行。

join() 和 get 方法当任务完成的时候返回计算结果。

在我们自己实现的 compute 方法里,首先需要判断任务是否足够小,如果足够小就直接执行任务。如果不足够小,就必须分割成两个子任务,每个子任务在调用 invokeAll 方法时,又会进入 compute 方法,看看当前子任务是否需要继续分割成孙任务,如果不需要继续分割,则执行当前子任务并返回结果。使用 join 方法会等待子任务执行完并得到其结果。

代码示例

java
package com.mengweijin.learning.basic.thread;

import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import java.util.stream.LongStream;

/**
 * 生成一个计算任务,计算1+2+3+4+......
 *
 * 16:15:22.317 [main] DEBUG com.mengweijin.learning.basic.thread.ForkJoinDemo - sum=500000500000, time=51
 * 16:15:31.537 [main] DEBUG com.mengweijin.learning.basic.thread.ForkJoinDemo - sum=500000500000, time=9201
 * 16:15:31.834 [main] DEBUG com.mengweijin.learning.basic.thread.ForkJoinDemo - sum=500000500000, time=296
 *
 * 测试发现,使用 ForkJoin 必须要有优秀的拆分算法,否则性能不高。一般多线程情况下,可以优先使用 java 8 中的并行流,性能更高。
 * @author mengweijin
 */
@Slf4j
public class ForkJoinDemo {

    private static final long MAX_NUMBER = 1_000_000L;

    public static void main(String[] args) {
        forSum();
        forkJoin();
        lambda();
    }

    public static void forSum(){
        long start = System.currentTimeMillis();

        Long sum = 0L;
        for (Long i = 1L; i <= MAX_NUMBER; i++) {
            sum += i;
        }

        long end = System.currentTimeMillis();
        log.debug("sum={}, time={}", sum, end-start);
    }

    @SneakyThrows
    public static void forkJoin(){
        long start = System.currentTimeMillis();

        ForkJoinPool forkjoinPool = new ForkJoinPool();
        TaskDemo task = new TaskDemo(1, MAX_NUMBER);
        Future<Long> future = forkjoinPool.submit(task);
        Long sum = future.get();

        long end = System.currentTimeMillis();
        log.debug("sum={}, time={}", sum, end-start);
    }

    public static void lambda(){
        long start = System.currentTimeMillis();

        //jdk8 stream流式计算
        long sum = LongStream.rangeClosed(0L, MAX_NUMBER).parallel()
                .reduce(0, Long::sum);

        long end = System.currentTimeMillis();
        log.debug("sum={}, time={}", sum, end-start);
    }

    /**
     * ForkJoinTask 有两个子类:
     * 1. 有返回值的 RecursiveTask
     * 2. 无返回值的 RecursiveAction
     * 可以根据实际需要使用不同的类。
     */
    static class TaskDemo extends RecursiveTask<Long> {
        // 阈值,设置为 2 就是二分法
        public static final int threshold = 2;
        private long start;
        private long end;

        public TaskDemo(long start, long end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected Long compute() {
            //如果任务足够小就计算任务
            if ((end - start) <= threshold) {
                long sum = 0;
                for (long i = start; i <= end; i++) {
                    sum += i;
                }
                return sum;
            } else {
                // 如果任务大于阈值,就分裂成两个子任务计算
                long middle = (start + end) / 2;
                TaskDemo leftTask = new TaskDemo(start, middle);
                TaskDemo rightTask = new TaskDemo(middle + 1, end);
                // 执行子任务
                leftTask.fork();
                rightTask.fork();
                // 等待任务执行结束合并其结果
                long leftResult = leftTask.join();
                long rightResult = rightTask.join();
                // 合并子任务
                return leftResult + rightResult;
            }
        }
    }
}