I am trying to implement the median of medians algorithm in Java. The algorithm should determine the median of a set of numbers. I tried to implement the pseudocode on Wikipedia:
https://en.wikipedia.org/wiki/Median_of_medians
I am getting a stack overflow and don't know why. Because of the recursion, I find it quite difficult to keep track of the code.
import java.util.Arrays;
public class MedianSelector {
private static final int CHUNK = 5;
public static void main(String[] args) {
int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
lowerMedian(test);
System.out.print(Arrays.toString(test));
}
/**
* Computes and retrieves the lower median of the given array of
* numbers using the Median algorithm presented in the lecture.
*
* @param numbers the input numbers.
* @return the lower median.
* @throws IllegalArgumentException if the array is {@code null} or empty.
*/
public static int lowerMedian(int[] numbers) {
if(numbers == null || numbers.length == 0) {
throw new IllegalArgumentException();
}
return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
}
private static int select(int[] numbers, int left, int right, int i) {
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
return select(numbers, left, pivotIndex - 1, i);
}else {
return select(numbers, left, pivotIndex + 1, i);
}
}
private static int pivot(int numbers[], int left, int right) {
if(right - left < CHUNK) {
return partition5(numbers, left, right);
}
for(int i=left; i<=right; i=i+CHUNK) {
int subRight = i + (CHUNK-1);
if(subRight > right) {
subRight = right;
}
int medChunk = partition5(numbers, i, subRight);
int tmp = numbers[medChunk];
numbers[medChunk] = numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))];
numbers[(int) (left + Math.floor((double) (i-left)/CHUNK))] = tmp;
}
int mid = (right - left) / 10 + left +1;
return select(numbers, left, (int) (left + Math.floor((right - left) / CHUNK)), mid);
}
private static int partition(int[] numbers, int left, int right, int idx, int k) {
int pivotVal = numbers[idx];
int storeIndex = left;
int storeIndexEq = 0;
int tmp = 0;
tmp = numbers[idx];
numbers[idx] = numbers[right];
numbers[right] = tmp;
for(int i=left; i<right; i++) {
if(numbers[i] < pivotVal) {
tmp = numbers[i];
numbers[i] = numbers[storeIndex];
numbers[storeIndex] = tmp;
storeIndex++;
}
}
storeIndexEq = storeIndex;
for(int i=storeIndex; i<right; i++) {
if(numbers[i] == pivotVal) {
tmp = numbers[i];
numbers[i] = numbers[storeIndexEq];
numbers[storeIndexEq] = tmp;
storeIndexEq++;
}
}
tmp = numbers[right];
numbers[right] = numbers[storeIndexEq];
numbers[storeIndexEq] = tmp;
if(k < storeIndex) {
return storeIndex;
}
if(k <= storeIndexEq) {
return k;
}
return storeIndexEq;
}
//Insertion sort
private static int partition5(int[] numbers, int left, int right) {
int i = left + 1;
int j = 0;
while(i<=right) {
j= i;
while(j>left && numbers[j-1] > numbers[j]) {
int tmp = numbers[j-1];
numbers[j-1] = numbers[j];
numbers[j] = tmp;
j=j-1;
}
i++;
}
return left + (right - left) / 2;
}
}
Can you confirm that n (in the pseudocode) or i (in my code) stands for the position of the median? So let's assume our array is numbers = {9,8,7,6,5,4,3,2,1,0}. I would call select(numbers, 0, 9, 4), correct?
I don't understand the calculation of mid in pivot. Why is there a division by 10? Maybe there is a mistake in the pseudocode?
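One way to read the division by 10: pivot() first packs one median per group of 5 into positions left .. left + (right - left) / 5, and the lower middle of that span is (right - left) / 5 / 2 positions past left. For non-negative integers, (d / 5) / 2 equals d / 10, so "/ 10" is "divide by the group size, then halve". The sketch below (MidIndexDemo and its methods are hypothetical names, not part of the question's code) computes both values side by side.

```java
public class MidIndexDemo {
    static final int CHUNK = 5;

    // Group medians end up in left .. left + (right - left) / 5; this is the last slot.
    static int lastMedianIndex(int left, int right) {
        return left + (right - left) / CHUNK;
    }

    // Lower middle of those medians: divide by 5 (groups), then by 2 (middle).
    static int middleOfMedians(int left, int right) {
        return left + (right - left) / CHUNK / 2;
    }

    // The expression from the Wikipedia pseudocode, as used in the question's pivot().
    static int pseudocodeMid(int left, int right) {
        return (right - left) / 10 + left + 1;
    }

    public static void main(String[] args) {
        int[][] ranges = {{0, 12}, {0, 49}, {10, 59}};
        for (int[] r : ranges) {
            int left = r[0], right = r[1];
            System.out.printf("[%d, %d]: medians in [%d, %d], lower middle %d, pseudocode mid %d%n",
                    left, right, left, lastMedianIndex(left, right),
                    middleOfMedians(left, right), pseudocodeMid(left, right));
        }
    }
}
```

For [0, 49] this reports medians in [0, 9], a lower middle of 4 and a pseudocode mid of 5, so the extra + 1 targets the index one past the lower middle. Since select returns the correct element for any pivot, that offset only affects how evenly each step splits the range, not the result.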
Thanks for your help.
EDIT: It turns out the switch from iteration to recursion was a red herring. The actual issue, identified by the OP, was in the arguments to the second recursive select call.
This line:
return select(numbers, left, pivotIndex + 1, i);
should be
return select(numbers, pivotIndex + 1, right, i);
I'll leave the original answer below, as I don't want to appear cleverer than I actually was.
I think you may have misinterpreted the pseudocode for the select method - it uses iteration rather than recursion.
Here's your current implementation:
private static int select(int[] numbers, int left, int right, int i) {
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
return select(numbers, left, pivotIndex - 1, i);
}else {
return select(numbers, left, pivotIndex + 1, i);
}
}
And the pseudocode
function select(list, left, right, n)
    loop
        if left = right then
            return left
        pivotIndex := pivot(list, left, right)
        pivotIndex := partition(list, left, right, pivotIndex, n)
        if n = pivotIndex then
            return n
        else if n < pivotIndex then
            right := pivotIndex - 1
        else
            left := pivotIndex + 1
This would typically be implemented using a while loop:
private static int select(int[] numbers, int left, int right, int i) {
while(true)
{
if(left == right) {
return left;
}
int pivotIndex = pivot(numbers, left, right);
pivotIndex = partition(numbers, left, right, pivotIndex, i);
if(i == pivotIndex) {
return i;
}else if(i < pivotIndex) {
right = pivotIndex - 1;
}else {
left = pivotIndex + 1;
}
}
}
With this change your code appears to work, though obviously you'll need to test to confirm.
int[] test = {9,8,7,6,5,4,3,2,1,0,13,11,10};
System.out.println("Lower Median: " + lowerMedian(test));
int[] check = test.clone();
Arrays.sort(check);
System.out.println(Arrays.toString(check));
Output:
Lower Median: 6
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13]
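To do that testing, one option is a randomized comparison against Arrays.sort. The sketch below is a hypothetical single-file consolidation (the class name MedianSelectorCheck is not from the question) of the question's pivot, partition and partition5 helpers with the iterative select from this answer, plus a main that checks lowerMedian against a sorted copy for many small random arrays with lots of repeated values.

```java
import java.util.Arrays;
import java.util.Random;

public class MedianSelectorCheck {
    private static final int CHUNK = 5;

    // Lower median: the element at index (n - 1) / 2 of the sorted array.
    static int lowerMedian(int[] numbers) {
        if (numbers == null || numbers.length == 0) {
            throw new IllegalArgumentException("numbers must not be null or empty");
        }
        return numbers[select(numbers, 0, numbers.length - 1, (numbers.length - 1) / 2)];
    }

    // Iterative select, as in the answer: shrink [left, right] until index n is final.
    static int select(int[] a, int left, int right, int n) {
        while (left != right) {
            int pivotIndex = partition(a, left, right, pivot(a, left, right), n);
            if (n == pivotIndex) {
                return n;
            } else if (n < pivotIndex) {
                right = pivotIndex - 1;
            } else {
                left = pivotIndex + 1;
            }
        }
        return left;
    }

    // Median of medians: move each group's median to the front, then select among them.
    static int pivot(int[] a, int left, int right) {
        if (right - left < CHUNK) {
            return partition5(a, left, right);
        }
        for (int i = left; i <= right; i += CHUNK) {
            int groupMedian = partition5(a, i, Math.min(i + CHUNK - 1, right));
            swap(a, groupMedian, left + (i - left) / CHUNK);
        }
        int mid = (right - left) / 10 + left + 1;
        return select(a, left, left + (right - left) / CHUNK, mid);
    }

    // Three-way partition: values < pivot, then == pivot, then > pivot.
    static int partition(int[] a, int left, int right, int pivotIndex, int n) {
        int pivotValue = a[pivotIndex];
        swap(a, pivotIndex, right);
        int storeIndex = left;
        for (int i = left; i < right; i++) {
            if (a[i] < pivotValue) swap(a, storeIndex++, i);
        }
        int storeIndexEq = storeIndex;
        for (int i = storeIndex; i < right; i++) {
            if (a[i] == pivotValue) swap(a, storeIndexEq++, i);
        }
        swap(a, right, storeIndexEq);
        if (n < storeIndex) return storeIndex;
        if (n <= storeIndexEq) return n;
        return storeIndexEq;
    }

    // Insertion sort of a[left..right]; returns the index of its (lower) middle.
    static int partition5(int[] a, int left, int right) {
        for (int i = left + 1; i <= right; i++) {
            for (int j = i; j > left && a[j - 1] > a[j]; j--) {
                swap(a, j - 1, j);
            }
        }
        return left + (right - left) / 2;
    }

    static void swap(int[] a, int i, int j) {
        int tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }

    // Randomized check against Arrays.sort, using many repeated values.
    public static void main(String[] args) {
        Random random = new Random(1);
        for (int trial = 0; trial < 10_000; trial++) {
            int[] numbers = new int[1 + random.nextInt(60)];
            for (int k = 0; k < numbers.length; k++) {
                numbers[k] = random.nextInt(20);
            }
            int[] sorted = numbers.clone();
            Arrays.sort(sorted);
            int expected = sorted[(sorted.length - 1) / 2];
            if (lowerMedian(numbers) != expected) {
                throw new AssertionError("wrong median for " + Arrays.toString(sorted));
            }
        }
        System.out.println("All random checks passed");
    }
}
```

Small arrays with few distinct values are deliberate here: they exercise the equal-to-pivot branch of partition and the "fewer than 5 elements" branch of pivot far more often than large random inputs would.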
Related
I have a quicksort method which sorts elements in ascending order, but I keep getting a StackOverflowError.
For some reason the error points to the while loop, although the logic makes sense to me.
Here is the code for the quick sort class:
public T[] sort(T[] arr, int left, int right)
{
int l = left;
int r = right;
if (right <= left)
return null;
//Find the pivot in the middle
T pivot = arr[(left + (right - left)) / 2];
T temp;
while (l <= r)
{
// check values on left are bigger than the pivot
while (arr[l].compareTo(pivot) < 0)
{
l++;
}
// check if values are smaller than the pivot
while (arr[r].compareTo(pivot) > 0)
{
r--;
}
// l and r have gone past each other swap them
if (l <= r)
{
//swap process
temp = arr[l];
arr[l] = arr[r];
arr[r] = temp;
// left pointer goes up 1
// right pointer goes down 1
l++;
r--;
}
}
if (left < r)
sort(arr, left, r);
if (l < right)
sort(arr, l, right);
return arr;
}
The error seems to be pointing to
//Find the pivot in the middle
T pivot = arr[(left + (right - left)) / 2];
After that, the same error repeats many times.
I believe you are calculating the pivot incorrectly.
You should use T pivot = arr[left + (right - left) / 2];
Below is a working quicksort for int arrays that uses the middle element as the pivot:
public void quickSort(int[] arr, int left, int right) {
int low = left, high = right;
// left + (right - left) / 2 always lies inside [left, right]
int pivot = arr[left + (right - left) / 2];
while (low <= high) {
while (arr[low] < pivot) {
low++;
}
while (arr[high] > pivot) {
high--;
}
if (low <= high) {
int temp = arr[low];
arr[low] = arr[high];
arr[high] = temp;
low++;
high--;
}
}
// Recurse only after the partitioning loop has finished
if (left < high) {
quickSort(arr, left, high);
}
if (low < right) {
quickSort(arr, low, right);
}
}
Hope it helps!
Stream API way:
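T[] result = Arrays.stream(a)
.skip(left)
.limit(right - left + 1) // right is inclusive
.sorted((o1, o2) -> o1.compareTo(o2)) // replace with your own comparison logic if needed
.toArray(size -> Arrays.copyOf(a, size)); // new array with the same element type as a
The difference is that you only get the sorted part of the array, as a new array, so you need to copy or concatenate it back afterwards if necessary.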
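A minimal sketch of the "copy it back" step, assuming the elements are Comparable; StreamRangeSort and sortRange are hypothetical names. Arrays.stream(array, from, to) streams just the slice (the end index is exclusive), and System.arraycopy writes the sorted slice back in place.

```java
import java.util.Arrays;

public class StreamRangeSort {
    // Sorts arr[left..right] (inclusive) by streaming that slice, sorting it,
    // and copying the sorted result back over the same positions.
    static <T extends Comparable<? super T>> void sortRange(T[] arr, int left, int right) {
        T[] sorted = Arrays.stream(arr, left, right + 1)
                .sorted()
                .toArray(size -> Arrays.copyOf(arr, size));
        System.arraycopy(sorted, 0, arr, left, sorted.length);
    }

    public static void main(String[] args) {
        Integer[] numbers = {9, 7, 5, 3, 1, 8, 6};
        sortRange(numbers, 1, 4);
        System.out.println(Arrays.toString(numbers)); // [9, 1, 3, 5, 7, 8, 6]
    }
}
```

This allocates a temporary array per call, so it is convenient for one-off range sorts rather than as a replacement for an in-place quicksort.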
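You have a typo in that line, yes: T pivot = arr[(left + (right - left)) / 2];
Because of the extra parentheses, the whole sum is divided by 2. Since left + (right - left) is just right, the index is right / 2, which can fall outside [left, right]. It should be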
T pivot = arr[left + (right - left) / 2];
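A quick arithmetic check of the two expressions (PivotIndexDemo and its methods are hypothetical names). For the sub-range [6, 9], the buggy form picks index 4, which is not even in the range being sorted, so the pivot value can come from a different part of the array.

```java
public class PivotIndexDemo {
    // Expression from the question: the whole sum is halved, so this is right / 2.
    static int buggyMiddle(int left, int right) {
        return (left + (right - left)) / 2;
    }

    // Corrected expression: the midpoint of [left, right].
    static int middle(int left, int right) {
        return left + (right - left) / 2;
    }

    public static void main(String[] args) {
        int left = 6, right = 9; // e.g. the right half of a 10-element array
        System.out.println("buggy: " + buggyMiddle(left, right)); // 4, outside [6, 9]
        System.out.println("fixed: " + middle(left, right));      // 7
    }
}
```

Writing the midpoint as left + (right - left) / 2 rather than (left + right) / 2 also avoids int overflow when both indices are very large.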
A couple of stylistic remarks:
as this is an in-place sort, returning T[] is not really necessary
T temp could be moved into the swap block
Putting it together:
import java.util.Arrays;
public class QuickSort<T extends Comparable<? super T>> {
public void sort (T[] arr)
{
if (arr == null || arr.length <= 1)
return;
sort(arr, 0, arr.length - 1);
}
public void sort(T[] arr, int left, int right)
{
int l = left;
int r = right;
if (right <= left)
return;
//Find the pivot in the middle
T pivot = arr[(left + (right - left)/2)];
while (l <= r)
{
// skip values on the left that are smaller than the pivot
while (arr[l].compareTo(pivot) < 0)
{
l++;
}
// skip values on the right that are bigger than the pivot
while (arr[r].compareTo(pivot) > 0)
{
r--;
}
// l and r have not crossed yet: swap the out-of-place pair
if (l <= r)
{
//swap process
T temp = arr[l];
arr[l] = arr[r];
arr[r] = temp;
// left pointer goes up 1
// right pointer goes down 1
l++;
r--;
}
}
if (left < r)
sort(arr, left, r);
if (l < right)
sort(arr, l, right);
}
public static void main(String args[])
{
Integer[] numbers=new Integer[] {3,2,5,4,1};
System.out.println(Arrays.asList(numbers));
new QuickSort<Integer>().sort(numbers);
System.out.println(Arrays.asList(numbers));
}
}
Output:
[3, 2, 5, 4, 1]
[1, 2, 3, 4, 5]
I have quicksort code in Java which I want to improve. The improved code should take less time than the original quicksort. However, my improved code, which implements median-of-3 partitioning, takes 400 milliseconds longer. Can anybody help me sort this out? If possible, can you also suggest other ways to improve my code? I have from 10,000 to 10 million integers to sort.
Quicksort
public void quickSort(int arr[], int begin, int end) {
if (begin < end) {
int partitionIndex = partition(arr, begin, end);
quickSort(arr, begin, partitionIndex-1);
quickSort(arr, partitionIndex+1, end);
}
}
private int partition(int arr[], int begin, int end) {
int pivot = arr[end];
int i = (begin-1);
for (int j = begin; j < end; j++) {
if (arr[j] <= pivot) {
i++;
int swapTemp = arr[i];
arr[i] = arr[j];
arr[j] = swapTemp;
}
}
int swapTemp = arr[i+1];
arr[i+1] = arr[end];
arr[end] = swapTemp;
return i+1;
}
}
Improved code
package sorting;
public class improvement {
Clock c = new Clock();
public void quickSort(int[] intArray) {
recQuickSort(intArray, 0, intArray.length - 1);
}
public static void recQuickSort(int[] intArray, int left, int right) {
int size = right - left + 1;
if (size <= 3)
manualSort(intArray, left, right);
else {
double median = medianOf3(intArray, left, right);
int partition = partitionIt(intArray, left, right, median);
recQuickSort(intArray, left, partition - 1);
recQuickSort(intArray, partition + 1, right);
}
}
public static int medianOf3(int[] intArray, int left, int right) {
int center = (left + right) / 2;
if (intArray[left] > intArray[center])
swap(intArray, left, center);
if (intArray[left] > intArray[right])
swap(intArray, left, right);
if (intArray[center] > intArray[right])
swap(intArray, center, right);
swap(intArray, center, right - 1);
return intArray[right - 1];
}
public static void swap(int[] intArray, int dex1, int dex2) {
int temp = intArray[dex1];
intArray[dex1] = intArray[dex2];
intArray[dex2] = temp;
}
public static int partitionIt(int[] intArray, int left, int right, double
pivot) {
int leftPtr = left;
int rightPtr = right - 1;
while (true) {
while (intArray[++leftPtr] < pivot)
;
while (intArray[--rightPtr] > pivot)
;
if (leftPtr >= rightPtr)
break;
else
swap(intArray, leftPtr, rightPtr);
}
swap(intArray, leftPtr, right - 1);
return leftPtr;
}
public static void manualSort(int[] intArray, int left, int right) {
int size = right - left + 1;
if (size <= 1)
return;
if (size == 2) {
if (intArray[left] > intArray[right])
swap(intArray, left, right);
return;
} else {
if (intArray[left] > intArray[right - 1])
swap(intArray, left, right - 1);
if (intArray[left] > intArray[right])
swap(intArray, left, right);
if (intArray[right - 1] > intArray[right])
swap(intArray, right - 1, right);
}
}
}
I'm implementing a recursive quicksort, but I'm getting a StackOverflowError and I'm not sure where the bug lies :(
I'm sorting 1 million ints with values from 10 to 50.
It works for sizes smaller than 1 million, such as 100 thousand.
public Quicksort(int NUM_TESTS, int NUM_ELEMENTS){
num_tests = NUM_TESTS;
num_elements = NUM_ELEMENTS;
}
private void start(){
for (int i = 0; i < num_tests; i++){
int[] d1 = dataGeneration(num_elements);
qSortRecursive(d1,0,d1.length-1);
}
}
public static void main(String args[]){
Quicksort q = new Quicksort(1,1000000);
q.start();
}
private int[] dataGeneration(int n) {
int[] d1 = new int[n];
for (int i = 0; i < n; i++){
d1[i] = (int)(Math.random() * ((50 - 10) + 1) + 10);
}
return d1;
}
private void qSortRecursive(int[] data, int left, int right){
if(left < right){
int pivot = partition(data,left,right);
qSortRecursive(data,left,pivot-1);
qSortRecursive(data,pivot+1,right);
}
}
private int partition(int[] data, int left, int right){
int pivot = left ;
left++;
while (left <= right){
while (left <= right && data[left] <= data[pivot]) {
left++;
}
while (left <= right && data[right] >= data[pivot]){
right--;
}
if (left < right){
swap(data,left,right);
left++;
right--;
}
}
if (data[right] <= data[pivot]){
if (data[right] != data[pivot]){
swap(data,right,pivot);
}
pivot = right;
}
return pivot;
}
private void swap(int[] data, int i, int j){
int temp = data[i];
data[i] = data[j];
data[j] = temp;
}
Rewritten version that recurses only into the smaller partition:
private void qSortRecursive(int[] data, int left, int right){
while (left < right){
int pivot = partition(data,left,right);
if (pivot - left < right - pivot){
qSortRecursive(data, left, pivot - 1);
left = pivot + 1;
} else {
qSortRecursive(data, pivot + 1, right);
right = pivot - 1;
}
}
}
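Performing a manual tail-call optimization (recursing only into the smaller partition and looping over the larger one, which keeps the recursion depth at O(log n)) solved my problem. Thanks for the help, everyone :)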
You can try rewriting the algorithm without recursion. You remove the recursion by managing your own stack; that stack lives on the heap, so it can use all available memory instead of being limited to the thread's stack size.
Something like: http://alienryderflex.com/quicksort/
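A sketch of that idea for int arrays. It assumes a Hoare-style partition around the middle element (similar to the int partition used in the benchmark answer further down) rather than the question's partition, since a partition that spreads equal keys across both sides tends to cope better with input that has only 41 distinct values. IterativeQuickSort is a hypothetical name.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Random;

public class IterativeQuickSort {

    // Hoare-style partition around the middle element. Afterwards
    // data[left..i-1] <= pivot <= data[i..right], and left < i <= right when left < right.
    static int partition(int[] data, int left, int right) {
        int pivot = data[left + (right - left) / 2];
        int i = left, j = right;
        while (i <= j) {
            while (data[i] < pivot) i++;
            while (data[j] > pivot) j--;
            if (i <= j) {
                int tmp = data[i];
                data[i] = data[j];
                data[j] = tmp;
                i++;
                j--;
            }
        }
        return i;
    }

    // Quicksort without recursion: pending [left, right] ranges are kept in a
    // heap-allocated deque, so unbalanced splits cannot overflow the thread stack.
    static void sort(int[] data) {
        Deque<int[]> pending = new ArrayDeque<>();
        pending.push(new int[] {0, data.length - 1});
        while (!pending.isEmpty()) {
            int[] range = pending.pop();
            int left = range[0], right = range[1];
            if (left >= right) continue; // 0 or 1 element: nothing to do
            int split = partition(data, left, right);
            int leftSize = split - left;
            int rightSize = right - split + 1;
            // Push the larger part first so the smaller part is popped next,
            // which keeps the number of pending ranges small.
            if (leftSize > rightSize) {
                pending.push(new int[] {left, split - 1});
                pending.push(new int[] {split, right});
            } else {
                pending.push(new int[] {split, right});
                pending.push(new int[] {left, split - 1});
            }
        }
    }

    public static void main(String[] args) {
        // Input shaped like the question's: 1,000,000 ints between 10 and 50.
        int[] data = new int[1_000_000];
        Random random = new Random(42);
        for (int k = 0; k < data.length; k++) {
            data[k] = 10 + random.nextInt(41);
        }
        sort(data);
        for (int k = 1; k < data.length; k++) {
            if (data[k - 1] > data[k]) throw new AssertionError("not sorted at " + k);
        }
        System.out.println("sorted " + data.length + " elements");
    }
}
```

Each int[] pair in the deque is one pending sub-range; the memory it uses grows with the number of pending ranges, not with the recursion depth the recursive version would have needed.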
I'm running this and I'm being told it doesn't run fast enough. What is a good way to speed up this class? I'm guessing I need to change my nested while loops; that is the only thing I can think of. The if statements should all be constant time...
import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
public class QSortLab {
static int findpivot(Comparable[] A, int i, int j) {
return (i + j) / 2;
}
static <E> void swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
swap(A, pivotindex, j); // Stick pivot at end
int k = partition(A, i, j-1, A[j]);
swap(A, k, j); // Put pivot in place
if ((k-i) > 1) quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) quicksort(A, k+1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) swap(A, left, right); // Swap out-of-place values
}
return left; // Return first position in right partition
}
}
What do you mean by changing your nested while loops? Quicksort is defined by those loops; removing them would stop it from working properly.
As for optimization, keep in mind that primitives and objects behave differently: an int[] stores the values directly, whereas an Integer[] or Comparable[] stores references to objects on the heap, so every comparison has to follow a reference (and, for Comparable, make a compareTo call).
So let's test a few variants:
primitive quick sort (from here)
Integer quick sort (same code as above, but with Integer class)
Your original posted code
Your original posted code (w/ several edits)
Here's the entire code I used.
import java.util.Random;
public class App {
public static final int ARR_SIZE = 1000;
public static final int TEST_ITERS = 10000;
public static Random RANDOM = new Random();
public static void main(String[] args) {
int[] a = new int[ARR_SIZE];
Integer[] b = new Integer[ARR_SIZE];
Integer[] c = new Integer[ARR_SIZE];
Integer[] d = new Integer[ARR_SIZE];
long sum = 0, start = 0, end = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
a[j] = RANDOM.nextInt();
start = System.nanoTime();
quickSort(a, 0, a.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'int'");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
b[j] = RANDOM.nextInt();
start = System.nanoTime();
quickSort(b, 0, b.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Integer'");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
c[j] = RANDOM.nextInt();
start = System.nanoTime();
quicksort(c, 0, c.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Comparable' (SO user code)");
sum = 0;
for (int i = 0; i < TEST_ITERS; ++i) {
for (int j = 0; j < ARR_SIZE; ++j)
d[j] = RANDOM.nextInt();
start = System.nanoTime();
qs_quicksort(d, 0, d.length - 1);
end = System.nanoTime();
sum += (end - start);
}
System.out.println((sum / TEST_ITERS) + " nano, qs avg - 'Comparable' (SO user code - edit)");
for (int i = 0; i < ARR_SIZE; ++i) {
final int n = RANDOM.nextInt();
a[i] = n;
b[i] = n;
c[i] = n;
d[i] = n;
}
quickSort(a, 0, a.length - 1);
Integer[] aConv = new Integer[ARR_SIZE];
for (int i = 0; i < ARR_SIZE; ++i)
aConv[i] = a[i];
quickSort(b, 0, b.length - 1);
quicksort(c, 0, c.length - 1);
qs_quicksort(d, 0, d.length - 1);
isSorted(new Integer[][] { aConv, b, c, d });
System.out.println("All properly sorted");
}
public static void isSorted(Integer[][] arrays) {
if (arrays.length != 4) {
System.out.println("error sorting, input arr len");
return;
}
for (int i = 0; i < ARR_SIZE; ++i) {
int val1 = arrays[0][i].compareTo(arrays[1][i]);
int val2 = arrays[1][i].compareTo(arrays[2][i]);
int val3 = arrays[2][i].compareTo(arrays[3][i]);
if (val1 != 0 || val2 != 0 || val3 != 0) {
System.out.printf("Error [i = %d]: a = %d, b = %d, c = %d, d = %d%n", i, arrays[0][i], arrays[1][i], arrays[2][i], arrays[3][i]);
break;
}
}
}
public static int partition(int arr[], int left, int right) {
int i = left, j = right;
int tmp;
int pivot = arr[(left + right) / 2];
while (i <= j) {
while (arr[i] < pivot)
i++;
while (arr[j] > pivot)
j--;
if (i <= j) {
tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
i++;
j--;
}
}
return i;
}
public static void quickSort(int arr[], int left, int right) {
int index = partition(arr, left, right);
if (left < index - 1)
quickSort(arr, left, index - 1);
if (index < right)
quickSort(arr, index, right);
}
public static int partition(Integer[] arr, int left, int right) {
int i = left, j = right;
Integer pivot = arr[(left + right) / 2];
while (i <= j) {
while (arr[i].compareTo(pivot) < 0)
i++;
while (arr[j].compareTo(pivot) > 0)
j--;
if (i <= j) {
Integer temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i++;
j--;
}
}
return i;
}
public static void quickSort(Integer[] arr, int left, int right) {
int index = partition(arr, left, right);
if (left < index - 1)
quickSort(arr, left, index - 1);
if (index < right)
quickSort(arr, index, right);
}
static int findpivot(Comparable[] A, int i, int j)
{
return (i+j)/2;
}
static <E> void swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
swap(A, pivotindex, j); // Stick pivot at end
int k = partition(A, i, j-1, A[j]);
swap(A, k, j); // Put pivot in place
if ((k-i) > 1) quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) quicksort(A, k+1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) swap(A, left, right); // Swap out-of-place values
}
return left; // Return first position in right partition
}
static <E> void qs_swap(E[] A, int p1, int p2) {
E temp = A[p1];
A[p1] = A[p2];
A[p2] = temp;
}
static void qs_quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = (i+j)/2;
qs_swap(A, pivotindex, j); // Stick pivot at end
int k = qs_partition(A, i, j-1, A[j]);
qs_swap(A, k, j); // Put pivot in place
if ((k-i) > 1) qs_quicksort(A, i, k-1); // Sort left partition
if ((j-k) > 1) qs_quicksort(A, k+1, j); // Sort right partition
}
static int qs_partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) { qs_swap(A, left, right); // Swap out-of-place values
left++; right--;}
}
return left; // Return first position in right partition
}
}
This produces the output:
56910 nano, qs avg - 'int'
69498 nano, qs avg - 'Integer'
76762 nano, qs avg - 'Comparable' (SO user code)
71846 nano, qs avg - 'Comparable' (SO user code - edit)
All properly sorted
Now, breaking down the results:
The 'int' vs 'Integer' comparison shows a clear difference between primitives and boxed objects (there may be boxing at some points in the code, but hopefully not in critical spots ;) please edit this if so). The 'int' and 'Integer' versions use the same code except for the element type. These are the four method signatures used in this comparison, for 'int'
public static int partition(int arr[], int left, int right)
public static void quickSort(int arr[], int left, int right)
and 'Integer'
public static int partition(Integer[] arr, int left, int right)
public static void quickSort(Integer[] arr, int left, int right)
respectively.
Then there are the method signatures related to the original code you posted,
static int findpivot(Comparable[] A, int i, int j)
static <E> void swap(E[] A, int p1, int p2)
static void quicksort(Comparable[] A, int i, int j)
static int partition(Comparable[] A, int left, int right, Comparable pivot)
and the modified ones,
static <E> void qs_swap(E[] A, int p1, int p2)
static void qs_quicksort(Comparable[] A, int i, int j)
static int qs_partition(Comparable[] A, int left, int right, Comparable pivot)
As you can see, in the modified code findpivot was removed and its expression inlined at the call site in quicksort. Also, after each swap the partition method now advances both bounds: left++; right--;
Finally, to make sure these 4 variants actually do their one job, sorting, I sort the same generated data with each of them and added a method, isSorted(), that checks that all 4 results are identical.
In conclusion, I think my edits may have saved some time (71846 vs 76762 ns on average), but I wasn't able to match the Integer test. Hopefully I've not missed anything obvious; edits are welcome if need be. Cheers
Well, I couldn't tell from testing whether this makes any difference at all, because the timer on my machine is unreliable. But I think most of the work in this algorithm is done by the swap function, so it is worth thinking about how to make that in particular more efficient: perhaps the function call and return themselves consume cycles, and perhaps creating the temp variable on every call does too, so the code might be more efficient if the swap were done inline. It was not obvious when I tested on my machine, though, as the nanosecond timer returned results that varied by +/- 20% from run to run.
public class QSort2 {
static int findpivot(Comparable[] A, int i, int j) {
return (i + j) / 2;
}
static Comparable temp;
static void quicksort(Comparable[] A, int i, int j) { // Quicksort
int pivotindex = findpivot(A, i, j); // Pick a pivot
// swap(A, pivotindex, j); // Stick pivot at end
temp = A[pivotindex];
A[pivotindex] = A[j];
A[j] = temp;
int k = partition(A, i, j - 1, A[j]);
//swap(A, k, j); // Put pivot in place
temp = A[k];
A[k] = A[j];
A[j] = temp;
if ((k - i) > 1) quicksort(A, i, k - 1); // Sort left partition
if ((j - k) > 1) quicksort(A, k + 1, j); // Sort right partition
}
static int partition(Comparable[] A, int left, int right, Comparable pivot) {
while (left <= right) { // Move bounds inward until they meet
while (A[left].compareTo(pivot) < 0) left++;
while ((right >= left) && (A[right].compareTo(pivot) >= 0)) right--;
if (right > left) {
//swap(A, left, right);} // Swap out-of-place values
temp = A[left];
A[left] = A[right];
A[right] = temp;
}
}
return left; // Return first position in right partition
}
}
I'm trying to use the "randomized pivot" method to find the kth smallest element of a given array.
The code:
public class FindKthMin {
// Find the Kth min elem by randomized pivot.
private static void exchange (int[] givenArray, int firstIndex, int secondIndex) {
int tempElem = givenArray[firstIndex];
givenArray[firstIndex] = givenArray[secondIndex];
givenArray[secondIndex] = tempElem;
}
private static int partition (int[] givenArray, int start, int end, int pivotIndex) {
// Debug:
//System.out.println("debug: start = " + start);
//System.out.println(">> end = " + end);
//System.out.println(">> pivotIndex = " + pivotIndex);
int pivot = givenArray[pivotIndex];
int left = start - 1;
int right = end;
boolean hasDone = false;
while (!hasDone) {
while (!hasDone) {
left ++;
if (left == right) {
hasDone = true;
break;
}
if (givenArray[left] >= pivot) {
// Exchange givenArray[left] and the givenArray[right].
exchange(givenArray, left, right);
break;
}
}
while (!hasDone) {
right --;
if (left == right) {
hasDone = true;
break;
}
if (givenArray[right] < pivot) {
// Exchange the givenArray[right] and the givenArray[left].
exchange(givenArray, right, left);
break;
}
}
}
givenArray[right] = pivot;
// Debug:
//System.out.println(">> split = " + right);
//System.out.println();
return right;
}
private static int findKthMin_RanP_Helper (int[] givenArray, int start, int end, int k) {
if (start > end) return -1;
// Generate a random num in the range[start, end].
int rand = (int)(start + Math.random() * (end - start + 1));
// Using this random num as the pivot index to partition the array in the current scope.
int split = partition(givenArray, start, end, rand);
if (k == split + 1) return givenArray[split];
else if (k < split + 1) return findKthMin_RanP_Helper(givenArray, start, split - 1, k);
else return findKthMin_RanP_Helper(givenArray, split + 1, end, k);
}
public static int findKthMin_RanP (int[] givenArray, int k) {
int size = givenArray.length;
if (k < 1 || k > size) return -1;
return findKthMin_RanP_Helper(givenArray, 0, size - 1, k);
}
// Main method to test.
public static void main (String[] args) {
// Test data: {8, 9, 5, 2, 8, 4}.
int[] givenArray = {8, 9, 5, 2, 8, 4};
// Test finding the Kth min elem by randomized pivot method.
System.out.println("Test finding the Kth min elem by randomized pivot method, rest = " + findKthMin_RanP(givenArray, 1));
}
}
But the result is inconsistent: sometimes right and sometimes wrong.
Please look at the fifth line of the findKthMin_RanP_Helper method:
If I change int split = partition(givenArray, start, end, rand); to int split = partition(givenArray, start, end, end);, the result is always correct. I really cannot find what's wrong with this.
EDIT:
The problem comes from partition; the new partition should look like this:
private static int partition_second_version (int[] givenArray, int start, int end, int pivotIndex) {
int pivot = givenArray[pivotIndex];
int left = start;
int right = end;
while (left <= right) {
while (givenArray[left] < pivot) left ++;
while (givenArray[right] > pivot) right --;
if (left <= right) {
// Exchange givenArray[left] and givenArray[right].
exchange(givenArray, left, right);
left ++;
right --;
}
}
return left;
}
And the findKthMin_RanP_Helper should be changed like this:
private static int findKthMin_RanP_Helper (int[] givenArray, int start, int end, int k) {
// A single remaining element must be the kth smallest.
if (start == end) return givenArray[start];
// Generate a random num in the range [start, end].
int rand = start + (int)(Math.random() * ((end - start) + 1));
// Using this random num as the pivot index to partition the array in the current scope.
int split = partition_second_version (givenArray, start, end, rand);
// Now givenArray[start..split-1] <= pivot <= givenArray[split..end], but givenArray[split - 1]
// is not necessarily in its final place, so recurse into the side that holds rank k (1-based).
if (k <= split) return findKthMin_RanP_Helper(givenArray, start, split - 1, k);
else return findKthMin_RanP_Helper(givenArray, split, end, k);
}
Your partition routine could be simplified...
private static int partition(int[] givenArray, int start, int end, int pivotIndex) {
final int pivot = givenArray[pivotIndex];
int left = start;
int right = end;
while (left < right) {
while (left <= end && givenArray[left] <= pivot) {
left++;
}
while (right >= start && givenArray[right] > pivot) {
right--;
}
if (left >= right) {
break;
}
exchange(givenArray, right, left);
}
return right;
}
The one bug I see in your code is your partition routine. In the first exchange call, it is not guaranteed that the right index will always point to a value which is < pivot.
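Bugs that only appear with a random pivot, like the one in this question, are easiest to pin down with a randomized comparison against Arrays.sort. The sketch below (KthSmallestCheck and kthSmallest are hypothetical names) uses a Hoare-style partition of the same shape as partition_second_version. Because that partition does not leave the pivot at a known index, it narrows to whichever side holds rank k instead of returning early.

```java
import java.util.Arrays;
import java.util.Random;

public class KthSmallestCheck {
    private static final Random RANDOM = new Random(7);

    // Hoare-style partition: afterwards a[start..split-1] <= pivot <= a[split..end],
    // and start < split <= end whenever start < end.
    static int partition(int[] a, int start, int end, int pivotIndex) {
        int pivot = a[pivotIndex];
        int left = start, right = end;
        while (left <= right) {
            while (a[left] < pivot) left++;
            while (a[right] > pivot) right--;
            if (left <= right) {
                int tmp = a[left];
                a[left] = a[right];
                a[right] = tmp;
                left++;
                right--;
            }
        }
        return left;
    }

    // k is 1-based; positions start..split-1 hold ranks start+1..split.
    static int kthSmallest(int[] a, int start, int end, int k) {
        while (start < end) {
            int split = partition(a, start, end, start + RANDOM.nextInt(end - start + 1));
            if (k <= split) {
                end = split - 1;
            } else {
                start = split;
            }
        }
        return a[start];
    }

    // Compares every rank of many small random arrays against Arrays.sort.
    public static void main(String[] args) {
        Random random = new Random(1);
        for (int trial = 0; trial < 2_000; trial++) {
            int[] data = new int[1 + random.nextInt(30)];
            for (int i = 0; i < data.length; i++) data[i] = random.nextInt(10);
            int[] sorted = data.clone();
            Arrays.sort(sorted);
            for (int k = 1; k <= data.length; k++) {
                int got = kthSmallest(data.clone(), 0, data.length - 1, k);
                if (got != sorted[k - 1]) {
                    throw new AssertionError("k=" + k + " for " + Arrays.toString(data));
                }
            }
        }
        System.out.println("All random checks passed");
    }
}
```

Checking every k on arrays with only 10 distinct values hits the duplicate-heavy cases that a single hand-picked test such as {8, 9, 5, 2, 8, 4} with k = 1 tends to miss.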