Quick sort method is giving me a stackoverflowerror? - java

I have a quick sort method which sorts elements in ascending order, but I keep getting a StackOverflowError.
For some reason it is showing the error on the while loop, when the logic makes sense to me.
Here is the code for the quick sort class:
public T[] sort(T[] arr, int left, int right)
{
int l = left;
int r = right;
if (right <= left)
return null;
//Find the pivot in the middle
T pivot = arr[(left + (right - left)) / 2];
T temp;
while (l <= r)
{
// check values on left are bigger than the pivot
while (arr[l].compareTo(pivot) < 0)
{
l++;
}
// check if values are smaller than the pivot
while (arr[r].compareTo(pivot) > 0)
{
r--;
}
// l and r have gone past each other swap them
if (l <= r)
{
//swap process
temp = arr[l];
arr[l] = arr[r];
arr[r] = temp;
// left pointer goes up 1
// right pointer goes down 1
l++;
r--;
}
}
if (left < r)
sort(arr, left, r);
if (l < right)
sort(arr, l, right);
return arr;
}
The error seems to be pointing to
//Find the pivot in the middle
T pivot = arr[(left + (right - left)) / 2];
I then get many recurring errors.

I believe you are calculating the pivot incorrectly.
You should use T pivot = arr[left + (right - left) / 2];
Below is working quick sort program using middle element as pivot:
public void quickSort(int[] arr, int left, int right) {
int low = left, high = right;
int pivot = arr[left + (right - left) / 2];
// partition: move smaller values to the left of the pivot, larger values to the right
while (low <= high) {
while (arr[low] < pivot) {
low++;
}
while (arr[high] > pivot) {
high--;
}
if (low <= high) {
int temp = arr[low];
arr[low] = arr[high];
arr[high] = temp;
low++;
high--;
}
}
// recurse only after the partition loop has finished
if (left < high) {
quickSort(arr, left, high);
}
if (low < right) {
quickSort(arr, low, right);
}
}
Hope it helps !!

Stream API way:
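String[] result = Arrays.stream(a) // assumes a is a String[]
.skip(left)
.limit(right - left + 1) // right is inclusive in the question's code
.sorted((o1, o2) -> o1.compareTo(o2)) // or your own comparison logic
.toArray(String[]::new);
The difference is that you get back only the sorted part of the array, so you need to concatenate it with the untouched parts afterwards if necessary.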
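If the goal is only to sort the slice between left and right, Arrays.sort has a range overload that does this in place, so nothing has to be concatenated afterwards. A small sketch, assuming right is inclusive as in the question (the toIndex argument of Arrays.sort is exclusive, hence the + 1); RangeSortDemo and sortRange are illustrative names, not part of the question:

```java
import java.util.Arrays;

class RangeSortDemo {
    // Sorts arr[left..right] (both inclusive) in place using natural ordering.
    static <T extends Comparable<? super T>> void sortRange(T[] arr, int left, int right) {
        Arrays.sort(arr, left, right + 1); // toIndex is exclusive
    }

    public static void main(String[] args) {
        Integer[] numbers = {9, 3, 2, 5, 4, 1, 0};
        sortRange(numbers, 1, 5);
        System.out.println(Arrays.toString(numbers)); // [9, 1, 2, 3, 4, 5, 0]
    }
}
```

This does not help with writing your own quick sort, but it is a handy reference result to test a hand-written sort against.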

Yes, you have a typo in that line: T pivot = arr[(left + (right - left)) / 2];
Due to the misplaced parentheses, the whole sum is divided by 2, so the index is simply right / 2, which can lie outside the range [left, right]. It should be
T pivot = arr[left + (right - left) / 2];
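To see why the parentheses matter, the two index expressions can be compared directly. A minimal sketch; the names PivotIndexDemo, brokenIndex and fixedIndex are made up for illustration:

```java
class PivotIndexDemo {
    // Index as written in the question: the whole sum is halved, i.e. right / 2.
    static int brokenIndex(int left, int right) {
        return (left + (right - left)) / 2;
    }

    // Only the offset is halved, so the result always lies within [left, right].
    static int fixedIndex(int left, int right) {
        return left + (right - left) / 2;
    }

    public static void main(String[] args) {
        int left = 3, right = 4;
        System.out.println(brokenIndex(left, right)); // 2, outside [3, 4]
        System.out.println(fixedIndex(left, right));  // 3
    }
}
```

With the broken index the pivot is read from outside the subarray being partitioned, so the partition step can fail to shrink the range and the same recursive call can repeat until the stack overflows.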
A couple of stylistic remarks:
as this is an in-place sort, returning T[] is not really necessary
T temp could be moved into the swap-block
Putting it together:
import java.util.Arrays;
public class QuickSort<T extends Comparable<? super T>> {
public void sort (T[] arr)
{
if (arr == null || arr.length <= 1)
return;
sort(arr, 0, arr.length - 1);
}
public void sort(T[] arr, int left, int right)
{
int l = left;
int r = right;
if (right <= left)
return;
//Find the pivot in the middle
T pivot = arr[(left + (right - left)/2)];
while (l <= r)
{
// skip values on the left that are smaller than the pivot
while (arr[l].compareTo(pivot) < 0)
{
l++;
}
// skip values on the right that are larger than the pivot
while (arr[r].compareTo(pivot) > 0)
{
r--;
}
// l and r have not crossed yet: swap the out-of-place pair
if (l <= r)
{
//swap process
T temp = arr[l];
arr[l] = arr[r];
arr[r] = temp;
// left pointer goes up 1
// right pointer goes down 1
l++;
r--;
}
}
if (left < r)
sort(arr, left, r);
if (l < right)
sort(arr, l, right);
}
public static void main(String args[])
{
Integer[] numbers=new Integer[] {3,2,5,4,1};
System.out.println(Arrays.asList(numbers));
new QuickSort<Integer>().sort(numbers);
System.out.println(Arrays.asList(numbers));
}
}
Output:
[3, 2, 5, 4, 1]
[1, 2, 3, 4, 5]

Related

How to implement the medians of medians algorithm in Java

I am trying to implement the median of medians algorithm in Java. The algorithm should determine the median of a set of numbers. I tried to implement the pseudocode on Wikipedia:
https://en.wikipedia.org/wiki/Median_of_medians
I am getting a buffer overflow and don't know why. Because of the recursion, I find the code quite difficult to follow.
import java.util.Arrays;
public class MedianSelector {
private static final int CHUNK = 5;
public static void main(String[] args) {
int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
lowerMedian(test);
System.out.print(Arrays.toString(test));
}
/**
* Computes and retrieves the lower median of the given array of
* numbers using the Median algorithm presented in the lecture.
*
* @param numbers the input numbers.
* @return the lower median.
* @throws IllegalArgumentException if the array is {@code null} or empty.
*/
public static int lowerMedian(int[] numbers) {
if(numbers == null || numbers.length == 0) {
throw new IllegalArgumentException();
}
return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
}
private static int select(int[] numbers, int left, int right, int i) {
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
return select(numbers, left, pivotIndex - 1, i);
}else {
return select(numbers, left, pivotIndex + 1, i);
}
}
private static int pivot(int numbers[], int left, int right) {
if(right - left < CHUNK) {
return partition5(numbers, left, right);
}
for(int i=left; i<=right; i=i+CHUNK) {
int subRight = i + (CHUNK-1);
if(subRight > right) {
subRight = right;
}
int medChunk = partition5(numbers, i, subRight);
int tmp = numbers[medChunk];
numbers[medChunk] = numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))];
numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))] = tmp;
}
int mid = (right - left) / 10 + left +1;
return select(numbers, left, (int) (left + Math.floor((right - left) / CHUNK)), mid);
}
private static int partition(int[] numbers, int left, int right, int idx, int k) {
int pivotVal = numbers[idx];
int storeIndex = left;
int storeIndexEq = 0;
int tmp = 0;
tmp = numbers[idx];
numbers[idx] = numbers[right];
numbers[right] = tmp;
for(int i=left; i<right; i++) {
if(numbers[i] < pivotVal) {
tmp = numbers[i];
numbers[i] = numbers[storeIndex];
numbers[storeIndex] = tmp;
storeIndex++;
}
}
storeIndexEq = storeIndex;
for(int i=storeIndex; i<right; i++) {
if(numbers[i] == pivotVal) {
tmp = numbers[i];
numbers[i] = numbers[storeIndexEq];
numbers[storeIndexEq] = tmp;
storeIndexEq++;
}
}
tmp = numbers[right];
numbers[right] = numbers[storeIndexEq];
numbers[storeIndexEq] = tmp;
if(k < storeIndex) {
return storeIndex;
}
if(k <= storeIndexEq) {
return k;
}
return storeIndexEq;
}
//Insertion sort
private static int partition5(int[] numbers, int left, int right) {
int i = left + 1;
int j = 0;
while(i<=right) {
j= i;
while(j>left && numbers[j-1] > numbers[j]) {
int tmp = numbers[j-1];
numbers[j-1] = numbers[j];
numbers[j] = tmp;
j=j-1;
}
i++;
}
return left + (right - left) / 2;
}
}
Can you confirm that n (in the pseudocode) or i (in my code) stands for the position of the median? So let's assume our array is numbers = {9,8,7,6,5,4,3,2,1,0}. I would call select(numbers, 0, 9, 4), correct?
I don't understand the calculation of mid in pivot? Why is there a division by 10? Maybe there is a mistake in the pseudo code?
Thanks for your help.
EDIT: It turns out the switch from iteration to recursion was a red herring. The actual issue, identified by the OP, was in the arguments to the 2nd recursive select call.
This line:
return select(numbers, left, pivotIndex + 1, i);
should be
return select(numbers, pivotIndex + 1, right, i);
I'll leave the original answer below as I don't want to appear to be cleverer than I actually was.
I think you may have misinterpreted the pseudocode for the select method - it uses iteration rather than recursion.
Here's your current implementation:
private static int select(int[] numbers, int left, int right, int i) {
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
return select(numbers, left, pivotIndex - 1, i);
}else {
return select(numbers, left, pivotIndex + 1, i);
}
}
And the pseudocode
function select(list, left, right, n)
loop
if left = right then
return left
pivotIndex := pivot(list, left, right)
pivotIndex := partition(list, left, right, pivotIndex, n)
if n = pivotIndex then
return n
else if n < pivotIndex then
right := pivotIndex - 1
else
left := pivotIndex + 1
This would typically be implemented using a while loop:
private static int select(int[] numbers, int left, int right, int i) {
while(true)
{
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
right = pivotIndex - 1;
}else {
left = pivotIndex + 1;
}
}
}
With this change your code appears to work, though obviously you'll need to test to confirm.
int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
System.out.println("Lower Median: " + lowerMedian(test));
int[] check = test.clone();
Arrays.sort(check);
System.out.println(Arrays.toString(check));
Output:
Lower Median: 6
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13]

Why "int mid = (left - right)/2 + right" will cause stack overflow?

I wrote a merge sort today, and I get a StackOverflowError when I use int mid = (left - right)/2 + right;. But when I use int mid = left + (right - left) / 2; everything is fine. I'm very confused, because mathematically the two expressions mean the same thing (and if they don't, why not?).
I have seen this question but I'm not sure it answers this problem.
why left+(right-left)/2 will not overflow?
Here's my code. I am sorry that I did not upload it earlier.
import java.util.Arrays;
public class Solution {
public static void main(String[] args) {
int[] array = {3,5,1,2,4,8};
int[] result = mergeSort(array);
System.out.println(Arrays.toString(result));
}
public static int[] mergeSort(int[] array) {
if (array == null || array.length <= 1) {
return array;
}
int[] helper = new int[array.length];
sort(array, helper, 0, array.length - 1);
return array;
}
public static void sort(int[] array, int[] helper, int left, int right) {
if (left >= right) {
return;
}
int mid = (left - right)/2 + right;
//int mid = left + (right - left) / 2;
sort(array, helper, left, mid);
sort(array, helper, mid + 1, right);
combine(array, helper, left, mid, right);
}
public static void combine(int[] array, int[] helper, int left, int mid, int right) {
for (int i = left; i <= right; i++) {
helper[i] = array[i];
}
int leftIndex = left;
int rightIndex = mid + 1;
while (leftIndex <= mid && rightIndex <= right) {
if (helper[leftIndex] <= helper[rightIndex]) {
array[left] = helper[leftIndex];
left++;
leftIndex++;
} else {
array[left] = helper[rightIndex];
left++;
rightIndex++;
}
}
while (leftIndex <= mid) {
array[left] = helper[leftIndex];
left++;
leftIndex++;
}
}
}
As you run this, sooner or later, it's likely that left will be one less than right - for example, maybe left = 3 and right = 4.
In that case, (left - right) / 2 and (right - left) / 2 both evaluate to 0, since integer division rounds towards zero. So, if you have
int mid = (left - right)/2 + right;
sort(array, helper, left, mid);
within your call to sort(array, helper, left, right), then you're repeating a call with the exact same values. This causes infinite recursion - a StackOverflowError.
But when you have
int mid = left + (right - left)/2;
sort(array, helper, left, mid);
then you're repeating a call with different values, so the recursion has a chance to end.
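The rounding difference is easy to check by evaluating both formulas for left = 3 and right = 4. A minimal sketch; MidpointDemo, midFromRight and midFromLeft are illustrative names:

```java
class MidpointDemo {
    // Truncation toward zero rounds this midpoint up; it equals right when right - left == 1.
    static int midFromRight(int left, int right) {
        return (left - right) / 2 + right;
    }

    // Rounds down, so the result is strictly less than right whenever left < right.
    static int midFromLeft(int left, int right) {
        return left + (right - left) / 2;
    }

    public static void main(String[] args) {
        System.out.println(midFromRight(3, 4)); // 4: sort(left, mid) repeats sort(3, 4)
        System.out.println(midFromLeft(3, 4));  // 3: sort(3, 3) and sort(4, 4) both return at once
    }
}
```

The first formula would be fine if the recursive calls were split as sort(left, mid - 1) and sort(mid, right) instead, since the rounding direction has to match the side that receives mid.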
Initially, you have left < right. And, of course, right - left > 0, and furthermore left + (right - left) = right (follows from basic algebra).
And consequently left + (right - left) / 2 <= right. So no overflow can happen since every step of the operation is bounded by the value of right.
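The bound can be contrasted with the more common formula (left + right) / 2, which is not used in the question but is the usual reason the overflow concern comes up. A small sketch with deliberately large indices; OverflowDemo, naiveMid and safeMid are illustrative names:

```java
class OverflowDemo {
    // left + right can exceed Integer.MAX_VALUE and wrap around to a negative value.
    static int naiveMid(int left, int right) {
        return (left + right) / 2;
    }

    // For 0 <= left <= right, right - left cannot overflow and every
    // intermediate value stays at or below right.
    static int safeMid(int left, int right) {
        return left + (right - left) / 2;
    }

    public static void main(String[] args) {
        int left = 1_500_000_000, right = 2_000_000_000;
        System.out.println(naiveMid(left, right)); // negative: the sum wrapped around
        System.out.println(safeMid(left, right));  // 1750000000
    }
}
```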

How to improve the speed of my class?

I'm running this and I am being told it would not run fast enough. What is a good way to increase the speed of this running class? I am guessing I would need to change my nested while loops. That is the only thing I can think of. The if statements should all be linear...
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
public class QSortLab {
static int findpivot(Comparable[] A, int i, int j) {
return (i + j) / 2;
}
static <E> void swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
swap(A, pivotindex, j); // Stick pivot at end
int k = partition(A, i, j-1, A[j]);
swap(A, k, j); // Put pivot in place
if ((k-i) > 1) quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) quicksort(A, k+1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) swap(A, left, right); // Swap out-of-place values
}
return left; // Return first position in right partition
}
}
What do you mean by needing to change your nested while loops? Quick sort is defined by them; without them the algorithm would not work.
As for optimization, keep in mind that primitives and objects behave differently: an int[] holds its values directly, whereas an Integer[] holds references to objects on the heap, so every comparison has to follow a reference first.
So let's test some stuff
primitive quick sort (from here)
Integer quick sort (same code as above, but with Integer class)
Your original posted code
Your original posted code (w/ several edits)
Here's the entire code I used.
import java.util.Random;
public class App {
public static final int ARR_SIZE = 1000;
public static final int TEST_ITERS = 10000;
public static Random RANDOM = new Random();
public static void main(String[] args) {
int[] a = new int[ARR_SIZE];
Integer[] b = new Integer[ARR_SIZE];
Integer[] c = new Integer[ARR_SIZE];
Integer[] d = new Integer[ARR_SIZE];
long sum = 0, start = 0, end = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
a[j] = RANDOM.nextInt();
start = System.nanoTime();
quickSort(a, 0, a.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'int'");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
b[j] = RANDOM.nextInt();
start = System.nanoTime();
quickSort(b, 0, b.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Integer'");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
c[j] = RANDOM.nextInt();
start = System.nanoTime();
quicksort(c, 0, c.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Comparable' (SO user code)");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
d[j] = RANDOM.nextInt();
start = System.nanoTime();
qs_quicksort(d, 0, d.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Comparable' (SO user code - edit)");
for (int i = 0; i < ARR_SIZE; ++i) {
final int n = RANDOM.nextInt();
a[i] = n;
b[i] = n;
c[i] = n;
d[i] = n;
}
quickSort(a, 0, a.length - 1);
Integer[] aConv = new Integer[ARR_SIZE];
for (int i = 0; i < ARR_SIZE; ++i)
aConv[i] = a[i];
quickSort(b, 0, b.length - 1);
quicksort(c, 0, c.length - 1);
qs_quicksort(d, 0, d.length - 1);
isSorted(new Integer[][] { aConv, b, c, d });
System.out.println("All properly sorted");
}
public static void isSorted(Integer[][] arrays) {
if (arrays.length != 4) {
System.out.println("error sorting, input arr len");
return;
}
for (int i = 0; i < ARR_SIZE; ++i) {
int val1 = arrays[0][i].compareTo(arrays[1][i]);
int val2 = arrays[1][i].compareTo(arrays[2][i]);
int val3 = arrays[2][i].compareTo(arrays[3][i]);
if (val1 != 0 || val2 != 0 || val3 != 0) {
System.out.printf("Error [i = %d]: a = %d, b = %d, c = %d, d = %d", i, arrays[0][i], arrays[1][i], arrays[2][i], arrays[3][i]);
break;
}
}
}
public static int partition(int arr[], int left, int right) {
int i = left, j = right;
int tmp;
int pivot = arr[(left + right) / 2];
while (i <= j) {
while (arr[i] < pivot)
i++;
while (arr[j] > pivot)
j--;
if (i <= j) {
tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
i++;
j--;
}
}
return i;
}
public static void quickSort(int arr[], int left, int right) {
int index = partition(arr, left, right);
if (left < index - 1)
quickSort(arr, left, index - 1);
if (index < right)
quickSort(arr, index, right);
}
public static int partition(Integer[] arr, int left, int right) {
int i = left, j = right;
Integer pivot = arr[(left + right) / 2];
while (i <= j) {
while (arr[i].compareTo(pivot) < 0)
i++;
while (arr[j].compareTo(pivot) > 0)
j--;
if (i <= j) {
Integer temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i++;
j--;
}
}
return i;
}
public static void quickSort(Integer[] arr, int left, int right) {
int index = partition(arr, left, right);
if (left < index - 1)
quickSort(arr, left, index - 1);
if (index < right)
quickSort(arr, index, right);
}
static int findpivot(Comparable[] A, int i, int j)
{
return (i+j)/2;
}
static <E> void swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
swap(A, pivotindex, j); // Stick pivot at end
int k = partition(A, i, j-1, A[j]);
swap(A, k, j); // Put pivot in place
if ((k-i) > 1) quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) quicksort(A, k+1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) swap(A, left, right); // Swap out-of-place values
}
return left; // Return first position in right partition
}
static <E> void qs_swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void qs_quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = (i+j)/2;
qs_swap(A, pivotindex, j); // Stick pivot at end
int k = qs_partition(A, i, j-1, A[j]);
qs_swap(A, k, j); // Put pivot in place
if ((k-i) > 1) qs_quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) qs_quicksort(A, k+1, j); // Sort right partition
}
static int qs_partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) { qs_swap(A, left, right); // Swap out-of-place values
left++; right--;}
}
return left; // Return first position in right partition
}
}
This produces the output:
56910 nano, qs avg - 'int'
69498 nano, qs avg - 'Integer'
76762 nano, qs avg - 'Comparable' (SO user code)
71846 nano, qs avg - 'Comparable' (SO user code - edit)
All properly sorted
Now, breaking down the results
The 'int' vs 'Integer' comparison shows a clear difference (roughly 20% here) between primitives and non-primitives (I'm sure there may be boxing at some points in the code, but hopefully not in critical spots ;) - please edit this if so). Both runs use the same code apart from the element type. See the following four method signatures that are used in this comparison, 'int'
public static int partition(int arr[], int left, int right)
public static void quickSort(int arr[], int left, int right)
and 'Integer'
public static int partition(Integer[] arr, int left, int right)
public static void quickSort(Integer[] arr, int left, int right)
respectively.
Then there are the method signatures related to the original code you posted,
static int findpivot(Comparable[] A, int i, int j)
static <E> void swap(E[] A, int p1, int p2)
static void quicksort(Comparable[] A, int i, int j)
static int partition(Comparable[] A, int left, int right, Comparable pivot)
and the modified ones,
static <E> void qs_swap(E[] A, int p1, int p2)
static void qs_quicksort(Comparable[] A, int i, int j)
static int qs_partition(Comparable[] A, int left, int right, Comparable pivot)
As you can see, in the modified code findpivot was removed and inlined at its call site in quicksort. The partition method also advances both bounds after each swap: left++; right--;
And finally, to make sure these 4 variations of quicksort actually do their one job, sorting, I added a method, isSorted(), which checks that all four sorts produce the same result for the same generated input.
In conclusion, I think my edits may have saved some time, but I wasn't able to match the time of the Integer test. Hopefully I've not missed anything obvious; edits are welcome if need be. Cheers
Well, I couldn't tell from testing whether this makes any difference at all, because the timer on my machine is terrible: the nano timer returned results that varied by +/- 20% each time I ran the program. My thinking is that most of the work in this algorithm is done in the swap function. The function call and return may themselves consume cycles, and so may creating the temp variable on every call, so the code might be more efficient if the swap work was done inline:
public class QSort2 {
static int findpivot(Comparable[] A, int i, int j) {
return (i + j) / 2;
}
static Comparable temp;
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
// swap(A, pivotindex, j); // Stick pivot at end
temp = A[pivotindex];
A[pivotindex] = A[j];
A[j] = temp;
int k = partition(A, i, j - 1, A[j]);
//swap(A, k, j); // Put pivot in place
temp = A[k];
A[k] = A[j];
A[j] = temp;
if ((k - i) > 1) quicksort(A, i, k - 1); // Sort left partition
if ((j - k) > 1) quicksort(A, k + 1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) {
//swap(A, left, right);} // Swap out-of-place values
temp = A[left];
A[left] = A[right];
A[right] = temp;
}
}
return left; // Return first position in right partition
}
}

CompareTo() method in Java

I know that compareTo() returns an int value:
returns 0 if a equals b
returns -1 if a < b
returns 1 if a > b
But how does the sort method make use of these 0, 1, -1 values, and how exactly does it arrange the list?
To sort an array with compareTo(),
public void sort(Comparable[] lst)
{
for(int i=0; i<lst.length; i++)
{
for(int j=i; j<lst.length; j++)
{
if(lst[i].compareTo(lst[j]) > 0)
{
//lst[i] is larger than lst[j]: swap so the smaller element comes first (ascending order)
Comparable tmp = lst[i];
lst[i] = lst[j];
lst[j] = tmp;
}
}
}
}
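Note that the compareTo() contract only promises a negative, zero, or positive result, not exactly -1, 0, or 1, so sorting code should test the sign rather than compare against specific values. A small sketch; CompareToDemo and isAscending are illustrative names:

```java
class CompareToDemo {
    // True if every element is <= its successor according to compareTo().
    static <T extends Comparable<? super T>> boolean isAscending(T[] items) {
        for (int i = 1; i < items.length; i++) {
            if (items[i - 1].compareTo(items[i]) > 0) { // only the sign matters
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println("a".compareTo("c"));              // -2, not -1
        System.out.println(Integer.valueOf(5).compareTo(3)); // 1
        System.out.println(isAscending(new String[] {"a", "b", "c"})); // true
    }
}
```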
Well, here's the code, not sure what more you want. Collections.sort seems to ultimately call this.
private static <T> void binarySort(T[] a, int lo, int hi, int start, Comparator<? super T> c) {
assert lo <= start && start <= hi;
if (start == lo)
start++;
for ( ; start < hi; start++) {
T pivot = a[start];
// Set left (and right) to the index where a[start] (pivot) belongs
int left = lo;
int right = start;
assert left <= right;
/*
* Invariants:
* pivot >= all in [lo, left).
* pivot < all in [right, start).
*/
while (left < right) {
int mid = (left + right) >>> 1;
if (c.compare(pivot, a[mid]) < 0)
right = mid;
else
left = mid + 1;
}
assert left == right;
/*
* The invariants still hold: pivot >= all in [lo, left) and
* pivot < all in [left, start), so pivot belongs at left. Note
* that if there are elements equal to pivot, left points to the
* first slot after them -- that's why this sort is stable.
* Slide elements over to make room for pivot.
*/
int n = start - left; // The number of elements to move
// Switch is just an optimization for arraycopy in default case
switch (n) {
case 2: a[left + 2] = a[left + 1];
case 1: a[left + 1] = a[left];
break;
default: System.arraycopy(a, left, a, left + 1, n);
}
a[left] = pivot;
}
}

Find the Kth smallest element. looping

I am trying to find the k-th smallest element with my code, but I can't fix an error in it.
When it tries to partition [0, 0, 2] with pivot = 0, it loops forever.
import java.util.Arrays;
public class OrderStat {
public static void main(String[] args) {
int[] uA = {13, 32, 28, 17, 2, 0, 14, 34, 35, 0};
System.out.println("Initial array: " + Arrays.toString(uA));
int kth = 3; // We will try to find 3rd smallest element(or 2nd if we will count from 0).
int result = getKthSmallestElement(uA, 0, uA.length - 1, kth - 1);
System.out.println(String.format("The %d smallest element is %d", kth, result));
System.out.println("-------------------------------------");
Arrays.sort(uA);
System.out.println("Sorted array for check: " + Arrays.toString(uA));
}
private static int getKthSmallestElement(int[] uA, int start, int end, int kth) {
int l = start;
int r = end;
int pivot = uA[start];
System.out.println("===================");
System.out.println(String.format("start=%d end=%d", start, end));
System.out.println("pivot = " + pivot);
//ERROR HERE: When we will work with [0, 0, 2] part of array with pivot = 0 it will give us infinite loop;
while (l < r) {
while (uA[l] < pivot) {
l++;
}
while (uA[r] > pivot) {
r--;
}
if (l < r) {
int tmp = uA[l];
uA[l] = uA[r];
uA[r] = tmp;
}
}
System.out.println("After partitioning: " + Arrays.toString(uA) + "\n");
if (l < kth)
return getKthSmallestElement(uA, l + 1, end, kth);
else if (l > kth)
return getKthSmallestElement(uA, start, l - 1, kth);
return uA[l];
}
}
Please explain how to fix this problem.
After swapping
if (l < r) {
int tmp = uA[l];
uA[l] = uA[r];
uA[r] = tmp;
}
you need to move l and r (or at least one of them, to make any progress) to the next position (++l; --r;). Otherwise, if both values are equal to the pivot, you loop infinitely.
A correct partitioning that is also usable in a quicksort would be
// make sure to call it only with valid indices, 0 <= start <= end < array.length
private int partition(int[] array, int start, int end) {
// trivial case, single element array - garbage if end < start
if(end <= start) return start;
int pivot = array[start]; // not a good choice of pivot in general, but meh
int left = start+1, right = end;
while(left < right) {
// move left index to first entry larger than pivot or right
while(left < right && array[left] <= pivot) ++left;
// move right index to last entry not larger than pivot or left
while(right > left && array[right] > pivot) --right;
// Now, either
// left == right, or
// left < right and array[right] <= pivot < array[left]
if (left < right) {
int tmp = array[left];
array[left] = array[right];
array[right] = tmp;
// move on
++left;
--right;
}
}
// Now left >= right.
// If left == right, we don't know whether array[left] is larger than the pivot or not,
// but array[left-1] certainly is not larger than the pivot.
// If left > right, we just swapped and incremented before exiting the loop,
// so then left == right+1 and array[right] <= pivot < array[left].
if (left > right || array[left] > pivot) {
--left;
}
// Now array[i] <= pivot for start <= i <= left, and array[j] > pivot for left < j <= end
// swap pivot in its proper place in the sorted array
array[start] = array[left];
array[left] = pivot;
// return pivot position
return left;
}
Then you can find the k-th smallest element in an array
int findKthSmallest(int[] array, int k) {
if (k < 1) throw new IllegalArgumentException("k must be positive");
if (array.length < k) throw new IllegalArgumentException("Array too short");
int left = 0, right = array.length-1, p;
--k; // 0-based indices
while(true) {
p = partition(array, left, right);
if (p == k) return array[p];
if (p < k) {
// p and k are both absolute indices, so k needs no adjustment
left = p+1;
} else {
right = p-1;
}
}
// no return needed here: the loop can only be left via return
}
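For comparison, a compact, self-contained selection routine can also be built on the simpler Lomuto partition scheme. This is a different scheme from the partition in the answer above, swapped in here only because it is easier to check by hand. A sketch; QuickSelectDemo and its method names are illustrative:

```java
class QuickSelectDemo {
    // Lomuto partition: uses array[end] as the pivot and returns its final index.
    static int partition(int[] array, int start, int end) {
        int pivot = array[end];
        int store = start; // next slot for an element smaller than the pivot
        for (int i = start; i < end; i++) {
            if (array[i] < pivot) {
                int tmp = array[i];
                array[i] = array[store];
                array[store] = tmp;
                store++;
            }
        }
        array[end] = array[store];
        array[store] = pivot;
        return store;
    }

    // Returns the k-th smallest element (k is 1-based); reorders the array.
    static int findKthSmallest(int[] array, int k) {
        if (k < 1 || k > array.length) {
            throw new IllegalArgumentException("k out of range");
        }
        int left = 0, right = array.length - 1, target = k - 1;
        while (true) {
            int p = partition(array, left, right);
            if (p == target) return array[p];
            if (p < target) left = p + 1; // target stays an absolute index
            else right = p - 1;
        }
    }

    public static void main(String[] args) {
        int[] uA = {13, 32, 28, 17, 2, 0, 14, 34, 35, 0};
        System.out.println(findKthSmallest(uA.clone(), 3)); // 2
    }
}
```

Because the pivot always ends up in its final position and is excluded from the next range, the range shrinks on every iteration, even when the array contains duplicates such as [0, 0, 2].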
