使用多个线程对数组进行快速排序

思路分析

为了实现数组的快速排序,可以采用分而治之的思想,将数组拆分为多个小的数组,然后再进行排序,最后将所有小数组的排序结果汇总在一起,如下图所示:

代码案例

这里主要使用到 Java 多线程编程中的 Fork/Join 框架。在下面的代码中,QuickSortTask 类继承了 RecursiveAction 类,并实现了快速排序算法。在 compute() 方法中,首先会选择一个枢轴值,然后对数组进行分区,并创建两个子任务来处理左右两部分,最后使用 invokeAll() 方法来执行这两个子任务。另外,parallelQuickSort() 方法创建了一个 ForkJoinPool 实例,并通过 invoke() 方法提交了一个初始的 QuickSortTask 任务。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class ParallelQuickSort {

private static class QuickSortTask extends RecursiveAction {
private final int[] array;
private final int left;
private final int right;

public QuickSortTask(int[] array, int left, int right) {
this.array = array;
this.left = left;
this.right = right;
}

@Override
protected void compute() {
if (left < right) {
int pivotIndex = partition(array, left, right);
QuickSortTask leftTask = new QuickSortTask(array, left, pivotIndex - 1);
QuickSortTask rightTask = new QuickSortTask(array, pivotIndex + 1, right);
invokeAll(leftTask, rightTask);
}
}

private int partition(int[] array, int left, int right) {
int pivot = array[right];
int i = left - 1;
for (int j = left; j < right; j++) {
if (array[j] <= pivot) {
i++;
swap(array, i, j);
}
}
swap(array, i + 1, right);
return i + 1;
}

private void swap(int[] array, int i, int j) {
int temp = array[i];
array[i] = array[j];
array[j] = temp;
}
}

/**
* 快速排序
*/
public static void parallelQuickSort(int[] array) {
ForkJoinPool pool = new ForkJoinPool();
pool.invoke(new QuickSortTask(array, 0, array.length - 1));
}

public static void main(String[] args) {
int[] array = {9, 5, 3, 7, 2, 8, 6, 1, 4};
System.out.println("Original array: " + Arrays.toString(array) + ", timestamp: " + System.currentTimeMillis());
parallelQuickSort(array);
System.out.println("Sorted array: " + Arrays.toString(array) + ", timestamp: " + System.currentTimeMillis());
}

}

程序运行输出的结果:

1
2
Original array: [9, 5, 3, 7, 2, 8, 6, 1, 4], timestamp: 1716809594701
Sorted array: [1, 2, 3, 4, 5, 6, 7, 8, 9], timestamp: 1716809594757

参考资料