我一直在尝试使用 Java 中的 ForkJoin 进行快速排序,而在大部分情况下,当数组的每个元素相等时,我会出现堆栈溢出。
import java.util.Comparator;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
public class QuickSortMultiThreading extends RecursiveAction {
private final int threshold = 50000;
private static Apartment[] arr;
private int start, end;
private static final Comparator<Apartment> comparator = new Apartment.ApartmentComparator();
public QuickSortMultiThreading(Apartment[] arr, int start, int end) {
this.arr = arr;
this.start = start;
this.end = end;
}
@Override
protected void compute() {
if(arr.length == 0){
return;
}
if (end - start <= threshold) {
quickSort(start, end);
} else {
if(start < end)
{
int p = partition(start, end);
QuickSortMultiThreading left = new QuickSortMultiThreading(arr, start, p - 1);
QuickSortMultiThreading right = new QuickSortMultiThreading(arr, p + 1, end);
left.fork();
right.compute();
left.join();
}
}
}
static int partition(int low, int high)
{
Apartment pivot = arr[high];
int i = (low - 1);
for (int j = low; j <= high - 1; j++) {
if (comparator.compare(arr[j], pivot) < 0) {
i++;
swap(i, j);
}
}
swap(i + 1, high);
return (i + 1);
}
private static void swap(int i, int j) {
Apartment temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
private void quickSort(int low, int high) {
while (low < high) {
int pi = partition(low, high);
if (pi - low < high - pi) {
quickSort(low, pi - 1);
low = pi + 1;
} else {
quickSort(pi + 1, high);
high = pi - 1;
}
}
}
public static void main(String[] args) {
int size = 100000;
Apartment[] arr = new Apartment[size];
for (int j = 0; j < size; j++) {
int var1 = 1;
int var2 = 2;
int var3 = 3;
arr[j] = new Apartment(var1, var2, var3);
}
ForkJoinPool pool = new ForkJoinPool(12);
pool.invoke(new QuickSortMultiThreading(arr, 0, size - 1));
}
}
我也有一个不存在相同问题的顺序实现
import java.io.*;
import java.util.Comparator;
import java.util.Random;
class SequentialQuicksort {
static void swap(Apartment[] arr, int i, int j)
{
Apartment temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
static int partition(Apartment[] arr, int low, int high, Comparator<Apartment> comparator) {
Apartment pivot = arr[high];
int i = (low - 1);
for (int j = low; j <= high - 1; j++) {
if (comparator.compare(arr[j], pivot) < 0) {
i++;
swap(arr, i, j);
}
}
swap(arr, i + 1, high);
return (i + 1);
}
static void quickSort(Apartment[] arr, int low, int high, Comparator<Apartment> comparator)
{
while (low < high) {
int pi = partition(arr, low, high, comparator);
if (pi - low < high - pi) {
quickSort(arr, low, pi - 1, comparator);
low = pi + 1;
} else {
quickSort(arr, pi + 1, high, comparator);
high = pi - 1;
}
}
}
public static void main(String[] args)
{
int size = 100000;
Apartment[] arr = new Apartment[size];
for (int j = 0; j < size; j++) {
int var1 = 1;
int var2 = 2;
int var3 = 3;
arr[j] = new Apartment(var1, var2, var3);
}
quickSort(arr, 0, arr.length - 1, new Apartment.ApartmentComparator());
}
}
据我了解,问题可能来自无限递归,但我不确定如何避免并行实现
问题很可能出在这里:
if(start < end)
{
int p = partition(start, end);
QuickSortMultiThreading left = new QuickSortMultiThreading(arr, start, p - 1);
QuickSortMultiThreading right = new QuickSortMultiThreading(arr, p + 1, end);
left.fork();
right.compute();
left.join();
}
代码使用Lomuto分区方案,并且在所有相等元素的情况下,拆分为1个元素和n-2个元素,这将导致堆栈溢出。代码需要做一些类似于quicksort()的事情,在较小的分区上递归,在较大的分区上循环以避免堆栈溢出,将堆栈使用限制为O(log(n))。最坏情况时间复杂度仍为 O(n^2)。
while(start < end){