I want to implement a dynamic programming algorithm for this problem:
Input: a sequence S of non-negative numbers {S1, ..., Sn}.
We want to partition S into 2 subsets S1 and S2 so that |sum(S1)-sum(S2)| is minimized, then partition each of the 2 subsets in the same way. We stop when we reach a subset with 2 or 1 elements. (We must preserve the order of the elements of S.)
Example: S = {1,2,2,3,4}, Output: { { {1,2}{2} } {3,4} }
With the help of this article, this is my implementation:
static String partition(int s[], int db,int fn)
{
int n = (fn-db) +1;
String res ="";
if (n<=2){
res +="[";
for(int l =db ;l<=fn;l++) res+=s[l];
res +="]";
return res;
}
int[][] m= new int [n+1][3]; /* DP table for values */
int[][] d= new int [n+1][3]; /* DP table for dividers */
int [] p = new int [n+1]; /* prefix sums array */
int cost; /* test split cost */
int i,x = 0; /* counters */
p[0] = 0; /* construct prefix sums */
for (i=1; i<=n; i++)
p[i]=p[i-1]+s[(db-1)+i];
for (i=1; i<=n; i++)
m[i][1] = p[i]; /* initialize boundaries */
m[1][2] = s[db];
for (i=2; i<=n; i++){ /* evaluate main recurrence */
m[i][2] = Integer.MAX_VALUE;
for (x=1; x<=(i-1); x++) {
cost = Math.max(m[x][1], p[i]-p[x]);
if (m[i][2] > cost) {
m[i][2] = cost;
d[i][2] = db+(x-1);
}
}
}
return res +="["+partition(s,db,d[n][2])+partition(s,d[n][2]+1,fn)+"]";
}
public static void main(String[] args) {
int []set ={2,1,1,1,5};
System.out.print(partition(set,0,set.length-1));
}
1. Is my implementation a good one, or is there another dynamic programming solution without recursive calls?
2. I cannot calculate the complexity of this algorithm. I tried to use the Master theorem, T(n) = aT(n/b) + f(n), but I don't know n/b, the size of each subproblem, for the 2 recursive calls.
3. How can we do the same partition if we are allowed to change the order of the elements?
Think about it this way: in the worst case, you have an array where each index i contains the value 2^i. In this case, each split only decreases the length by one, meaning your recursion depth is linear in n. At each level, you do O(n) work, so the total complexity is O(n^2). Fortunately, such arrays will be very short in practice, because the values grow exponentially and we generally do not consider such huge numbers, so the real-world performance will generally be better. When the weights are somewhat balanced, you should get O(n log n) performance.
As for the first question, I'm not sure what you mean with "the good one".
EDIT: Your algorithm could be made more efficient by performing a binary search for the location at which you want to split the array. This is possible because the order must be preserved and the numbers are non-negative, so the prefix sums never decrease.
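Here is a rough sketch of what the binary search from the edit could look like. It assumes the goal is exactly the one stated in the question (minimize |sum(S1)-sum(S2)| over contiguous splits) and that all values are non-negative. The class and method names (BalancedSplit, prefixSums, bestSplit) are made up for this illustration, and the output uses commas between elements, unlike the code in the question.

```java
class BalancedSplit {
    // prefix[i] is the sum of s[0..i-1]; all values are assumed non-negative.
    static long[] prefixSums(int[] s) {
        long[] prefix = new long[s.length + 1];
        for (int i = 0; i < s.length; i++) {
            prefix[i + 1] = prefix[i] + s[i];
        }
        return prefix;
    }

    // Returns x with lo <= x < hi so that splitting s[lo..hi] into
    // s[lo..x] and s[x+1..hi] minimizes |sum(left) - sum(right)|.
    // Requires at least two elements in the range.
    static int bestSplit(long[] prefix, int lo, int hi) {
        long total = prefix[hi + 1] - prefix[lo];
        int left = lo, right = hi - 1;
        // Find the first split whose left part reaches half of the total
        // (or the last candidate if none does).
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (2 * (prefix[mid + 1] - prefix[lo]) >= total) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        // The optimum is either that split or the one just before it.
        if (left > lo) {
            long here = Math.abs(2 * (prefix[left + 1] - prefix[lo]) - total);
            long before = Math.abs(2 * (prefix[left] - prefix[lo]) - total);
            if (before <= here) {
                return left - 1;
            }
        }
        return left;
    }

    static String partition(int[] s, long[] prefix, int lo, int hi) {
        if (hi - lo + 1 <= 2) {
            StringBuilder sb = new StringBuilder("[");
            for (int i = lo; i <= hi; i++) {
                if (i > lo) sb.append(',');
                sb.append(s[i]);
            }
            return sb.append(']').toString();
        }
        int x = bestSplit(prefix, lo, hi);
        return "[" + partition(s, prefix, lo, x) + partition(s, prefix, x + 1, hi) + "]";
    }

    static String partition(int[] s) {
        return partition(s, prefixSums(s), 0, s.length - 1);
    }

    public static void main(String[] args) {
        System.out.println(partition(new int[] {1, 2, 2, 3, 4})); // prints [[[1,2][2]][3,4]]
    }
}
```

With the prefix sums computed once in O(n), each split costs O(log n), and there are at most n - 1 splits, so under these assumptions the whole recursion should stay within O(n log n) even for the unbalanced worst case.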
I am trying to write a program that will iterate through all possible permutations of a String array, and return a two dimensional array with all the permutations. Specifically, I am trying to use a String array of length 4 to return a 2D array with 24 rows and 4 columns.
I have only found ways to print the Strings iteratively but not use them in an array. I have also found recursive ways of doing it, but they do not work, as I am using this code with others, and the recursive function is much more difficult.
For what I want the code to do, I know the header should be:
public class Permutation
{
public String[][] arrayPermutation(String[] str)
{
//code to return 2D array
}
}
I tried using a recursive method with Heap's algorithm, but it is very complex with its parameters.
I am very new to programming and any help would be greatly appreciated.
Your permutation-problem is basically just an index-permutation problem.
If you can order the numbers from 0 to n - 1 in all possible variations, you can use them as indexes of your input array, and simply copy the Strings. The following algorithm is not optimal, but it is graphic enough to explain and implement iteratively.
public static String[][] getAllPermutations(String[] str) {
LinkedList<Integer> current = new LinkedList<>();
LinkedList<Integer[]> permutations = new LinkedList<>();
int length = str.length;
current.add(-1);
while (!current.isEmpty()) {
// increment from the last position.
int position = Integer.MAX_VALUE;
position = getNextUnused(current, current.pop() + 1);
while (position >= length && !current.isEmpty()) {
position = getNextUnused(current, current.pop() + 1);
}
if (position < length) {
current.push(position);
} else {
break;
}
// fill with all available indexes.
while (current.size() < length) {
// find first unused index.
int unused = getNextUnused(current, 0);
current.push(unused);
}
// record result row.
permutations.add(current.toArray(new Integer[0]));
}
// select the right String, based on the index-permutation done before.
int numPermutations = permutations.size();
String[][] result = new String[numPermutations][length];
for (int i = 0; i < numPermutations; ++i) {
Integer[] indexes = permutations.get(i);
String[] row = new String[length];
for (int d = 0; d < length; ++d) {
row[d] = str[indexes[d]];
}
result[i] = row;
}
return result;
}
public static int getNextUnused(LinkedList<Integer> used, Integer current) {
int unused = current != null ? current : 0;
while (used.contains(unused)) {
++unused;
}
return unused;
}
The getAllPermutations method is organized into an initialization part, a loop collecting all (numeric) permutations, and finally a conversion of the index permutations found into String permutations.
As the conversion from int to String is trivial, I'll just explain the collection part. The loop iterates as long as the representation is not completely depleted, or until it is terminated from within.
First, we increment the representation (current). For that, we take the last 'digit' and increment it to the next free value. If that value is at or above length, we pop and look at the next digit (and increment it). We continue this until we hit a legal value (one below length).
After that, we fill the remainder of the digits with all still remaining digits. That done, we store the current representation to the list of arrays.
This algorithm is not optimal in terms of runtime! Heap is faster. But implementing Heap's iteratively requires a non-trivial stack which is tiresome to implement/explain.
I need to write an algorithm for my course to find the middle value of 4 sorted arrays of different sizes in O(n), and I'm not allowed to create an array to store the data.
How should I approach the problem? I thought about running a loop with 4 indexes, as if I were merging the arrays into one big array, but without storing the data. The loop would stop at n/2 and should give me the middle value.
Writing it seems complex and very messy (I need to check for each of the 4 arrays whether I'm out of bounds). Is there a better way to approach this?
I think you're on to the key idea: the median of the four arrays is just the median of all their values taken together, so if we walk through the values in sorted order and get halfway, whatever is next is the median. I would suggest structuring it as follows:
int firstIndex = 0;
int secondIndex = 0;
int thirdIndex = 0;
int fourthIndex = 0;
double current;
for (int i = 0; i < n/2; i++) {
// 1.) Find the value out of the four at firstIndex, secondIndex, ...
// which is smallest, and assign it to current
// 2.) Increment whichever of the four indices belongs to that element
}
// whatever is in current at the end of the loop is the middle element
You probably want a function findMin(int index1, int index2, int index3, int index4). This method could also be responsible for the out-of-bounds checks, so the main loop could just rely on it to be pointed in the right direction, and not care if it's run out of elements in any given array.
Does this make sense? I've tried to leave enough ambiguity to let you handle most of the real implementation work :)
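If you want something to compare your own attempt against, here is one possible sketch of the loop described above, with the out-of-bounds checks folded into the search for the smallest head element. The class name FourWayMedian is made up, and I'm assuming "middle value" means the element at position n/2 (0-based) of the merged order, which is why the loop runs n/2 + 1 times. The small pos array only holds the four read positions, not the data; four separate int variables would work just as well.

```java
class FourWayMedian {
    // Returns the element at 0-based position total/2 of the merged order.
    // Each input array is assumed to be sorted in ascending order.
    static int median(int[]... arrays) {
        int[] pos = new int[arrays.length]; // next unread position in each array
        int total = 0;
        for (int[] a : arrays) {
            total += a.length;
        }
        if (total == 0) {
            throw new IllegalArgumentException("no elements");
        }
        int current = 0;
        for (int step = 0; step <= total / 2; step++) {
            // Pick the array whose next unread element is smallest,
            // skipping arrays that are already exhausted.
            int best = -1;
            for (int k = 0; k < arrays.length; k++) {
                if (pos[k] < arrays[k].length
                        && (best == -1 || arrays[k][pos[k]] < arrays[best][pos[best]])) {
                    best = k;
                }
            }
            current = arrays[best][pos[best]];
            pos[best]++;
        }
        return current;
    }

    public static void main(String[] args) {
        int m = median(new int[] {1, 5, 9}, new int[] {2, 3}, new int[] {4}, new int[] {6, 7, 8});
        System.out.println(m); // prints 5
    }
}
```

Each step looks at no more than four candidates, so the total work is proportional to n.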
Think of a single unsorted array
Consider thinking about the 4 arrays as a single unsorted array, broken into 4 parts. If you can modify the arrays, you can sort all 4 arrays as if they were one by swapping values between them (some optimizations can be made since you know each of the 4 arrays is already sorted). Once you've sorted the arrays up to n/2 (where n is the aggregate length of the 4 arrays), just return the middle value of all 4.
Some Code
The implementation below begins to make multiple arrays function like a single one. I've implemented get, set, and length methods, the basis for any array. All that needs to happen now is for the class' data to be sorted (possibly up to n/2) using get(int), set(int,int), and length(), and a method which returns the median value median().
There is also room for further optimization by sorting only up to n/2 within the median method, and by caching the (i,j) pair for each element when doing so.
int median( int[] a1, int[] a2, int[] a3, int[] a4 ) {
MultiIntArray array = new MultiIntArray( a1, a2, a3, a4 );
array.sort();
return array.get( array.length() / 2 );
}
public class MultiIntArray {
private int[][] data;
public MultiIntArray( int[]... data ) {
this.data = data;
}
public void sort() {
// FOR YOU TO IMPLEMENT
}
public int length() {
int length = 0;
for ( int[] array : data ) {
length += array.length;
}
return length;
}
public int get( int index ) {
int i = 0;
while ( index >= data[i].length ) {
index -= data[i].length;
i += 1;
}
return data[i][index];
}
public void set( int index, int value ) {
int i = 0;
while ( index >= data[i].length ) {
index -= data[i].length;
i += 1;
}
data[i][index] = value;
}
}
I am working on a problem. Out of 17 test cases, 10 work fine and give the result in less than a second, but 7 cases take 2 seconds, which is beyond the time limit. Following is the code:
import java.util.*;
import java.io.*;
class TestClass
{
static PrintWriter wr = new PrintWriter(System.out);
public static void func1(int arr[], int n)
{
int temp = arr[0];
for (int jj = 0; jj < n; jj++)
{
if (jj == (n - 1))
arr[jj] = temp;
else
arr[jj] = arr[jj + 1];
}
}
public static void func2(int arr[], int n, int rt)
{
int count = 0;
for (int a = 0; a < n; a++)
{
for (int b = a; b < n; b++)
{
if (arr[a] > arr[b])
count++;
}
}
if (rt == (n - 1))
wr.print(count);
else
wr.print(count + " ");
}
public static void main(String args[]) throws Exception
{
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String str = br.readLine().trim();
StringTokenizer st = new StringTokenizer(str);
int t = Integer.parseInt(st.nextToken());
for (int i = 0; i < t; i++) //for test cases
{
str = br.readLine().trim();
st = new StringTokenizer(str);
int n = Integer.parseInt(st.nextToken());
int arr[] = new int[n];
str = br.readLine().trim();
st = new StringTokenizer(str);
for (int j = 0; j < n; j++) //to take input of array for each test case
{
arr[j] = Integer.parseInt(st.nextToken());
}
for (int rt = 0; rt < n; rt++) //for number of times circular shifting of array is done
{
func1(arr, n); //circularly shifts the array by one position
func2(arr, n, rt); //prints the number of inversion counts
}
if (i != (t - 1))
wr.println();
}
wr.close();
br.close();
}
}
Can someone suggest how to optimize the code so that it takes less time to execute?
I know BufferedReader and PrintWriter take less time than Scanner and System.out.print. I was using Scanner and System.out.print earlier but changed later in the hope of reducing the time, and it didn't help. I also did it earlier without func1 and func2, with all the operations in main. The time in both cases remains the same.
I am getting the correct output in all the cases, so the code is correct. I just need help in optimizing it.
The website you are using acquires questions from past programming competitions. I recognize this as a familiar problem.
Like most optimization questions, the preferred steps are:
Do less.
Do the same in fewer instructions.
Don't use functions.
Use faster instructions.
In your case, you have an array, and you wish to rotate it a number of times, and then to process it from the rotated position.
Rotating an array is an incredibly expensive operation, because you typically need to copy every element in the array into a new location. What is worse for you is that you are doing it the simplest way, you are rotating the array one step for every step needing rotation.
So, if you have a 100-element array that needs to be rotated 45 steps, you would then have 100 * 45 * 3 copies (3 copies per element swap) to perform your rotation.
In the above example, a better approach would be to figure out a routine that rotates an array 45 elements at a time. There are a number of ways to do this. The easiest is to double the RAM requirements and just have two arrays
b[x] = a[(x + 45) % a.length]
An even faster "do less" would be to never rotate the array, but to perform the calculation in reverse. This is conceptually the function of the desired index in the rotated array to the actual index in the pre-rotated array. This avoids all copying, and the index numbers (by virtue of being heavily manipulated in the math processing unit) will already be stored in the CPU registers, which is the fastest RAM a computer has.
Note that once you have the starting index in the original array, you can then calculate the next index without going through the calculation again.
I might have read this problem a bit wrong; because, it is not written to highlight the problem being solved. However, the core principles above apply, and it will be up to you to apply them to the exact specifics of your programming challenge.
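As a small illustration of the "never rotate" idea, here is a sketch that reads the array through an index mapping instead of moving any elements. The names RotatedView, get and inversions are made up for this example, shift is assumed to be non-negative, and the inversion count is still the naive quadratic one, so this only removes the copying, not the counting cost.

```java
class RotatedView {
    // Element at logical position i of arr after a left rotation by shift,
    // computed from the unrotated array. Assumes shift >= 0.
    static int get(int[] arr, int shift, int i) {
        return arr[(i + shift) % arr.length];
    }

    // Naive inversion count of the rotated view, without copying anything.
    static long inversions(int[] arr, int shift) {
        long count = 0;
        for (int a = 0; a < arr.length; a++) {
            for (int b = a + 1; b < arr.length; b++) {
                if (get(arr, shift, a) > get(arr, shift, b)) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5};
        // Rotated left by 2, the view is {3, 4, 5, 1, 2}.
        System.out.println(inversions(arr, 2)); // prints 6
    }
}
```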
An example of a faster rotate that does less
public static void func1(int arr[], int shift) {
    int offset = shift % arr.length;
    int[] rotated = new int[arr.length];
    // elements from offset onward move to the front
    for (int index = 0; index < arr.length - offset; index++) {
        rotated[index] = arr[index + offset];
    }
    // the first offset elements wrap around to the end
    for (int index = arr.length - offset; index < arr.length; index++) {
        rotated[index] = arr[index + offset - arr.length];
    }
    System.arraycopy(rotated, 0, arr, 0, arr.length);
}
This copies the array into its rotated position in one pass, instead of doing a pass per index to be rotated.
The first rule of optimisation (having decided it is necessary) is to use a profiler. This counts how many times methods are invoked, and measures the accumulated time within each method, and gives you a report.
It doesn't matter if a method is slow if you only run it a few times. If you run it hundreds of thousands of times, you need to either make it faster, or run it fewer times.
If you're using a mainstream IDE, you already have a profiler. Read its documentation and use it.
The other first rule of optimisation is, if there's already literature about the problem you're trying to solve, read it. Most of us might have invented bubble-sort independently. Fewer of us would have come up with QuickSort, but it's a better solution.
It looks as if you're counting inversions in the array. Your implementation is about as efficient as you can get, given that naive approach.
int count = 0;
for (int i = 0; i < array.length; i++) {
    int n1 = array[i];
    for (int j = i + 1; j < array.length; j++) {
        int n2 = array[j];
        if (n1 > n2) {
            count++;
        }
    }
}
For an array of length l this will take (l - 1) + (l - 2) + ... + 1 comparisons. That's a triangular number, l(l - 1)/2, and it grows proportionally to the square of l.
So for l=1000 you're doing ~500,000 comparisons. Then since you're repeating the count for all 1000 rotations of the array, that would be 500,000,000 comparisons, which is definitely the sort of number where things start taking a noticeable amount of time.
Googling for inversion count reveals a more sophisticated approach, which is to perform a merge sort, counting inversions as they are encountered.
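For reference, a minimal sketch of the merge-sort approach could look like the following. The class and method names are made up here; the idea is that whenever an element from the right half is merged before the remaining elements of the left half, it forms an inversion with each of them.

```java
class InversionCount {
    // Counts pairs (i, j) with i < j and input[i] > input[j] in O(n log n).
    // Works on a copy, so the caller's array is left untouched.
    static long countInversions(int[] input) {
        int[] a = input.clone();
        return sortAndCount(a, new int[a.length], 0, a.length);
    }

    // Sorts a[lo, hi) and returns the number of inversions inside that range.
    private static long sortAndCount(int[] a, int[] tmp, int lo, int hi) {
        if (hi - lo < 2) {
            return 0;
        }
        int mid = (lo + hi) >>> 1;
        long count = sortAndCount(a, tmp, lo, mid) + sortAndCount(a, tmp, mid, hi);
        int i = lo, j = mid, k = lo;
        while (i < mid && j < hi) {
            if (a[i] <= a[j]) {
                tmp[k++] = a[i++];
            } else {
                // a[j] is smaller than every remaining element of the left half.
                count += mid - i;
                tmp[k++] = a[j++];
            }
        }
        while (i < mid) tmp[k++] = a[i++];
        while (j < hi) tmp[k++] = a[j++];
        System.arraycopy(tmp, lo, a, lo, hi - lo);
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countInversions(new int[] {3, 4, 5, 1, 2})); // prints 6
    }
}
```

Used once per rotation, this would bring each count down from quadratic to O(n log n), though there may well be a smarter way to update the count from one rotation to the next.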
Otherwise, we need to look for opportunities for huge numbers of loop iterations. A loop inside a loop makes for big numbers. A loop inside a loop inside another loop makes for even bigger numbers.
You have:
for (int i = 0; i < t; i++) {
// stuff removed
for (int rt = 0; rt < n; rt++) {
// snip
func2(arr, n, rt); //prints the number of inversion counts
}
// snip
}
public static void func2(int arr[], int n, int rt) {
// snip
for (int a = 0; a < n; a++) {
for (int b = a; b < n; b++) {
// stuff
}
}
// snip
}
That's four levels of looping. Look at the input values for your slow tests, and work out what n * n * n * t is -- that's an indicator of how many times it'll do the work in the inner block.
We don't know what your algorithm is supposed to achieve. But think about whether you're doing the same thing twice in any of these loops.
It looks as if func1() is supposed to rotate an array. Have a look at System.arraycopy() for moving whole chunks of an array at a time. It is usually much faster than copying element by element in a loop.
Problem 10 from Project Euler:
The program runs for smaller numbers and slows to a crawl in the hundred thousands.
At 2 million, an answer fails to show up even though the program seems like it is still running.
I'm trying to implement the Sieve of Eratosthenes. It is supposed to be very fast. What's wrong with my approach?
import java.util.ArrayList;
public class p010
{
/**
* The sum of the primes below 10 is 2 + 3 + 5 + 7 = 17
* Find the sum of all the primes below two million.
* @param args
*/
public static void main(String[] args)
{
ArrayList<Integer> primes = new ArrayList<Integer>();
int upper = 2000000;
for (int i = 2; i < upper; i++)
{
primes.add(i);
}
int sum = 0;
for (int i = 0; i < primes.size(); i++)
{
if (isPrime(primes.get(i)))
{
for (int k = 2; k*primes.get(i) < upper; k++)
{
if (primes.contains(k*primes.get(i)))
{
primes.remove(primes.indexOf(k*primes.get(i)));
}
}
}
}
for (int i = 0; i < primes.size(); i++)
{
sum += primes.get(i);
}
System.out.println(sum);
}
public static boolean isPrime(int number)
{
boolean returnVal = true;
for (int i = 2; i <= Math.sqrt(number); i ++)
{
if (number % i == 0)
{
returnVal = false;
}
}
return returnVal;
}
}
You appear to be trying to implement the Sieve of Eratosthenes, which should perform better than O(N^2) (in fact, Wikipedia says it is O(N log(log N)) ...).
The fundamental problem is your choice of data structure. You've chosen to represent the set of remaining prime candidates as an ArrayList of primes. This means that your test to see if a number is still in the set takes O(N) comparisons ... where N is the number of remaining primes. Then you are using ArrayList.remove(int) to remove the non-primes ... which is O(N) also.
That all adds up to making your Sieve implementation worse than O(N^2).
The solution is to replace the ArrayList<Integer> with a boolean[] where the positions (indexes) in the boolean array represent the numbers, and the value of the boolean says whether the number is prime / possibly prime, or not prime.
(There were other problems too that I didn't notice ... see the other answers.)
There are a few issues here. First, let's talk about the algorithm. Your isPrime method is actually the very thing that the sieve is designed to avoid. When you get to a number in the sieve, you already know it's prime, so you don't need to test it. If it weren't prime, it would already have been eliminated as a multiple of a smaller prime.
So, point 1:
You can eliminate the isPrime method altogether. It should never return false.
Then, there are implementation issues. primes.contains and primes.remove are problems. They run in linear time on an ArrayList, because they require checking each element or rewriting a large portion of the backing array.
Point 2:
Either mark values in place (use boolean[]), or use some other more appropriate data structure.
I typically use something like boolean[] primes = new boolean[upper+1], and define n to be included if !(primes[n]). (I just ignore elements 0 and 1 so I don't have to subtract indices.) To "remove" an element, I set it to true. You could also use something like TreeSet<Integer>, I suppose. Using boolean[], the method is near-instantaneous.
Point 3:
sum needs to be a long. The answer (roughly 1.429e11) is larger than the maximum value of an int (2^31-1).
I can post working code if you like, but here's a test output, without spoilers:
public static void main(String[] args) {
long value;
long start;
long finish;
start = System.nanoTime();
value = arrayMethod(2000000);
finish = System.nanoTime();
System.out.printf("Value: %.3e, time: %4d ms\n", (double)value, (finish-start)/1000000);
start = System.nanoTime();
value = treeMethod(2000000);
finish = System.nanoTime();
System.out.printf("Value: %.3e, time: %4d ms\n", (double)value, (finish-start)/1000000);
}
output:
Using boolean[]
Value: 1.429e+11, time: 17 ms
Using TreeSet<Integer>
Value: 1.429e+11, time: 4869 ms
Edit:
Since spoilers are posted, here's my code:
public static long arrayMethod(int upper) {
boolean[] primes = new boolean[upper+1];
long sum = 0;
for (int i = 2; i <=upper; i++) {
if (!primes[i]) {
sum += i;
for (int k = 2*i; k <= upper; k+=i) {
primes[k] = true;
}
}
}
return sum;
}
public static long treeMethod(int upper) {
TreeSet<Integer> primes = new TreeSet<Integer>();
for (int i = 2; i <= upper; i++) {
primes.add(i);
}
long sum = 0;
for (Integer i = 2; i != null; i=primes.higher(i)) {
sum += i;
for (int k = 2*i; k <= upper; k+=i) {
primes.remove(k);
}
}
return sum;
}
Two things:
Your code is hard to follow. You have a list called "primes", that contains non prime numbers!
Also, you should strongly consider whether or not an array list is appropriate. In this case, a LinkedList would be much more efficient.
Why is this? An array list must constantly resize an array by: asking for new memory to create an array, then copying the old memory over in the newly created array. A Linked list would just resize the memory by changing a pointer. This is a lot quicker! However, I do not think that by making this change you can salvage your algorithm.
You should use an array list if you need to access the items non-sequentially, here, (with a suitable algorithm) you need to access the items sequentially.
Also, your algorithm is slow. Take the advice of SJuan76 (or gyrogearless). Thanks, SJuan76.
The key to the efficiency of classic implementation of the sieve of Eratosthenes on modern CPUs is the direct (i.e. non-sequential) memory access. Fortunately, ArrayList<E> does implement RandomAccess.
Another key to the sieve's efficiency is its conflation of index and value, just like in integer sorting. Actually removing any number from the sequence destroys this ability to directly address without any computations. We must mark, not remove, any composite as we find them, so any numbers greater than it will remain in their places in the sequence.
ArrayList<Integer> can be used for that (except taking more memory than is strictly necessary, but for 2 million this is inconsequential).
So your code with a minimal edit fix (also changing sum to be long as others point out too), becomes
import java.util.ArrayList;
public class Main
{
/**
* The sum of the primes below 10 is 2 + 3 + 5 + 7 = 17
* Find the sum of all the primes below two million.
* @param args
*/
public static void main(String[] args)
{
ArrayList<Integer> primes = new ArrayList<Integer>();
int upper = 5000;
primes.ensureCapacity(upper);
for (int i = 0; i < upper; i++) {
primes.add(i);
}
long sum = 0;
for (int i = 2; i <= upper / i; i++) {
if ( primes.get(i) > 0 ) {
for (int k = i*i; k < upper ; k+=i) {
primes.set(k, 0);
}
}
}
for (int i = 2; i < upper; i++) {
sum += primes.get(i);
}
System.out.println(sum);
}
}
With upper set to 2000000, this finds the result in half a second on Ideone. The projected run time for your original code there: between 10 and 400 hours (!).
To find rough estimates for the run time when faced with a slow code, you should always try to find out its empirical orders of growth: run it for some small size n1, then a bigger size n2, record the run times t1 and t2. If t ~ n^a, then a = log(t2/t1) / log(n2/n1).
For your original code the empirical orders of growth measured on 10k .. 20k .. 40k range of upper limit value N, are ~ N^1.7 .. N^1.9 .. N^2.1. For the fixed code it's faster than ~ N (in fact, it's ~ N^0.9 in the tested range 0.5 mln .. 1 mln .. 2 mln). The theoretical complexity is O(N log (log N)).
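The formula above is easy to turn into a helper. This is just a sketch with a made-up name (GrowthOrder.exponent), and the timings in main are hypothetical numbers chosen to make the arithmetic obvious, not measurements.

```java
class GrowthOrder {
    // Estimates a in t ~ n^a from two measurements (n1, t1) and (n2, t2):
    // a = log(t2/t1) / log(n2/n1).
    static double exponent(double n1, double t1, double n2, double t2) {
        return Math.log(t2 / t1) / Math.log(n2 / n1);
    }

    public static void main(String[] args) {
        // Hypothetical timings: doubling n from 10k to 20k quadruples the run time,
        // which corresponds to an exponent of about 2 (quadratic growth).
        System.out.println(exponent(10_000, 1.0, 20_000, 4.0));
    }
}
```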
Your program is not the Sieve of Eratosthenes; the modulo operator gives it away. Your program will be O(n^2), where a proper Sieve of Eratosthenes is O(n log log n), which is essentially n. Here's my version; I'll leave it to you to translate to Java with appropriate numeric datatypes:
function sumPrimes(n)
    sum := 0
    sieve := makeArray(2..n, True)
    for p from 2 to n step 1
        if sieve[p]
            sum := sum + p
            for i from p * p to n step p
                sieve[i] := False
    return sum
If you're interested in programming with prime numbers, I modestly recommend this essay at my blog.
Given an unsorted set of n integers, return all subsets of size k (i.e. each set has k unique elements) that sum to 0.
So I gave the interviewer the following solution ( which I studied on GeekViewpoint). No extra space used, everything is done in place, etc. But of course the cost is a high time complexity of O(n^k) where k=tuple in the solution.
public void zeroSumTripplets(int[] A, int tuple, int sum) {
int[] index = new int[tuple];
for (int i = 0; i < tuple; i++)
index[i] = i;
int total = combinationSize(A.length, tuple);
for (int i = 0; i < total; i++) {
if (0 != i)
nextCombination(index, A.length, tuple);
printMatch(A, Arrays.copyOf(index, tuple), sum);
}// for
}// zeroSumTripplets(int[], int, int)
private void printMatch(int[] A, int[] ndx, int sum) {
int calc = 0;
for (int i = 0; i < ndx.length; i++)
calc += A[ndx[i]];
if (calc == sum) {
Integer[] t = new Integer[ndx.length];
for (int i = 0; i < ndx.length; i++)
t[i] = A[ndx[i]];
System.out.println(Arrays.toString(t));
}// if
}// printMatch(int[], int[], int)
But then she imposed the following requirements:
Must use a hashmap in the answer so as to reduce the time complexity
Must absolutely -- ABSOLUTELY -- provide the time complexity for the general case
Hint: when k=6, O(n^3)
She was more interested in time complexity than in anything else.
Does anyone know a solution that would satisfy the new constraints?
EDIT:
Supposedly, in the correct solution, the map is to store the elements of the input and the map is then to be used as a look up table just as in the case for k=2.
When the size of the subset is 2 (i.e. k=2), the answer is trivial: loop through and load all the elements into a map. Then loop through the inputs again this time searching the map for sum - input[i] where i is the index from 0 to n-1, which would then be the answers. Supposedly this trivial case can be extended to where k is anything.
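To make the k=2 case concrete, here is a minimal sketch. It differs slightly from the description above in that it fills the lookup structure while scanning, so an element is never paired with itself and each pair is reported once; it also uses a HashSet rather than a map, since only membership is needed. The names PairSum and pairsWithSum are made up, and the input is assumed to contain distinct values.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class PairSum {
    // Returns every pair {x, y} of distinct input values with x + y == sum.
    // Expected O(n) time thanks to constant-time hash lookups.
    static List<int[]> pairsWithSum(int[] input, int sum) {
        Set<Integer> seen = new HashSet<>();
        List<int[]> result = new ArrayList<>();
        for (int x : input) {
            // The partner of x must have been seen earlier in the scan.
            if (seen.contains(sum - x)) {
                result.add(new int[] {sum - x, x});
            }
            seen.add(x);
        }
        return result;
    }

    public static void main(String[] args) {
        for (int[] pair : pairsWithSum(new int[] {9, 1, -3, -7, 5, -11}, -6)) {
            System.out.println(pair[0] + " " + pair[1]);
        }
        // prints:
        // 1 -7
        // 5 -11
    }
}
```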
Since no-one else has made an attempt, I might as well throw in at least a partial solution. As I pointed out in an earlier comment, this problem is a variant of the subset sum problem and I have relied heavily on documented approaches to that problem in developing this solution.
We're trying to write a function subsetsWithSum(A, k, s) that computes all the k-length subsets of A that sum to s. This problem lends itself to a recursive solution in two ways:
The solution of subsetsWithSum(x1 ... xn, k, s) can be found by computing subsetsWithSum(x2 ... xn, k, s) and adding all the valid subsets (if any) that include x1; and
All the valid subsets that include element xi can be found by computing subsetsWithSum(A - xi, k-1, s-xi) and adding xi to each subset (if any) that results.
The base case for the recursion occurs when k is 1, in which case the solution to subsetsWithSum(A, 1, s) is the set of all single element subsets where that element is equal to s.
So a first stab at a solution would be
/**
* Return all k-length subsets of A starting at offset o that sum to s.
* @param A - an unordered list of integers.
* @param k - the length of the subsets to find.
* @param s - the sum of the subsets to find.
* @param o - the offset in A at which to search.
* @return A list of k-length subsets of A that sum to s.
*/
public static List<List<Integer>> subsetsWithSum(
List<Integer> A,
int k,
int s,
int o)
{
List<List<Integer>> results = new LinkedList<List<Integer>>();
if (k == 1)
{
if (A.get(o) == s)
results.add(Arrays.asList(o));
}
else
{
for (List<Integer> sub : subsetsWithSum(A, k-1, s-A.get(o), o+1))
{
List<Integer> newSub = new LinkedList<Integer>(sub);
newSub.add(0, o);
results.add(0, newSub);
}
}
if (o < A.size() - k)
results.addAll(subsetsWithSum(A, k, s, o+1));
return results;
}
Now, notice that this solution will often call subsetsWithSum(...) with the same set of arguments that it has been called with before. Hence, subsetsWithSum is just begging to be memoized.
To memoize the function, I've put the arguments k, s and o into a three element list which will be the key to a map from these arguments to a result computed earlier (if there is one):
public static List<List<Integer>> subsetsWithSum(
List<Integer> A,
List<Integer> args,
Map<List<Integer>, List<List<Integer>>> cache)
{
if (cache.containsKey(args))
return cache.get(args);
int k = args.get(0), s = args.get(1), o = args.get(2);
List<List<Integer>> results = new LinkedList<List<Integer>>();
if (k == 1)
{
if (A.get(o) == s)
results.add(Arrays.asList(o));
}
else
{
List<Integer> newArgs = Arrays.asList(k-1, s-A.get(o), o+1);
for (List<Integer> sub : subsetsWithSum(A, newArgs, cache))
{
List<Integer> newSub = new LinkedList<Integer>(sub);
newSub.add(0, o);
results.add(0, newSub);
}
}
if (o < A.size() - k)
results.addAll(subsetsWithSum(A, Arrays.asList(k, s, o+1), cache));
cache.put(args, results);
return results;
}
To use the subsetsWithSum function to compute all the k-length subsets that sum to zero, one can use the following function:
public static List<List<Integer>> subsetsWithZeroSum(List<Integer> A, int k)
{
Map<List<Integer>, List<List<Integer>>> cache =
new HashMap<List<Integer>, List<List<Integer>>> ();
return subsetsWithSum(A, Arrays.asList(k, 0, 0), cache);
}
Regrettably my complexity calculating skills are a bit (read: very) rusty, so hopefully someone else can help us compute the time complexity of this solution, but it should be an improvement on the brute-force approach.
Edit: Just for clarity, note that the first solution above should be equivalent in time complexity to a brute-force approach. Memoizing the function should help in many cases, but in the worst case the cache will never contain a useful result and the time complexity will then be the same as the first solution. Note also that the general subset-sum problem is NP-complete, meaning that no polynomial-time algorithm is known for it when k is part of the input. End Edit.
Just for completeness, I tested this with:
public static void main(String[] args) {
List<Integer> data = Arrays.asList(9, 1, -3, -7, 5, -11);
for (List<Integer> sub : subsetsWithZeroSum(data, 4))
{
for (int i : sub)
{
System.out.print(data.get(i));
System.out.print(" ");
}
System.out.println();
}
}
and it printed:
9 -3 5 -11
9 1 -3 -7
I think your answer was very close to what they were looking for, but you can improve the complexity by noticing that any subset of size k can be thought of as two subsets of size k/2. So instead of finding all subsets of size k (which takes O(n^k) assuming k is small), use your code to find all subsets of size k/2, and put each subset in a hashtable, with its sum as the key.
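Then iterate through each subset of size k/2 (call its sum S) and check the hashtable for a subset whose sum is -S. If there is one, and the two subsets share no elements, then the combination of the two subsets of size k/2 is a subset of size k whose sum is zero. The same subset of size k can be reached through different pairs of halves, so duplicates need to be filtered out.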
So in the case of k=6 that they gave, you would find all subsets of size 3 and compute their sums (this will take O(n^3) time). Then checking the hashtable will take O(1) time for each subset, so the total time is O(n^3). In general this approach will take O(n^(k/2)) assuming k is small, and you can generalize it for odd values of k by taking subsets of size floor(k/2) and floor(k/2)+1.
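Here is a rough sketch of the half-subset idea for even k, to show how the hashtable lookup fits in. All names (HalfSubsetSum, combos, zeroSum, sumOf) are made up for this example. One choice of mine that is not in the description above: a right half is only accepted if all of its indices come after those of the left half, which guarantees the halves are disjoint and that each k-subset is reported exactly once. Filtering the candidates in a bucket is not O(1), so the worst case can be worse than O(n^(k/2)) when many halves share the same sum.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class HalfSubsetSum {
    // Collects every strictly increasing index combination of size h into out.
    static void combos(int n, int h, int start, Deque<Integer> cur, List<int[]> out) {
        if (cur.size() == h) {
            int[] c = new int[h];
            int i = 0;
            for (int v : cur) c[i++] = v;
            out.add(c);
            return;
        }
        for (int i = start; i < n; i++) {
            cur.addLast(i);
            combos(n, h, i + 1, cur, out);
            cur.removeLast();
        }
    }

    static long sumOf(int[] a, int[] idx) {
        long s = 0;
        for (int i : idx) s += a[i];
        return s;
    }

    // Returns the values of every k-subset of a that sums to zero.
    // Assumes k is even and k >= 2.
    static List<List<Integer>> zeroSum(int[] a, int k) {
        int h = k / 2;
        List<int[]> halves = new ArrayList<>();
        combos(a.length, h, 0, new ArrayDeque<>(), halves);

        // Hashtable from sum to all half-subsets with that sum.
        Map<Long, List<int[]>> bySum = new HashMap<>();
        for (int[] half : halves) {
            bySum.computeIfAbsent(sumOf(a, half), key -> new ArrayList<>()).add(half);
        }

        List<List<Integer>> result = new ArrayList<>();
        for (int[] left : halves) {
            List<int[]> partners =
                    bySum.getOrDefault(-sumOf(a, left), Collections.<int[]>emptyList());
            for (int[] right : partners) {
                // Only accept right halves that start after the left half ends.
                if (right[0] > left[h - 1]) {
                    List<Integer> subset = new ArrayList<>();
                    for (int i : left) subset.add(a[i]);
                    for (int i : right) subset.add(a[i]);
                    result.add(subset);
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(zeroSum(new int[] {9, 1, -3, -7, 5, -11}, 4));
        // prints [[9, 1, -3, -7], [9, -3, 5, -11]]
    }
}
```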
@kasavbere -
Recently a friend had one of those harrowing all-day interviews for a C++ programming job with Google. His experience was similar to yours.
It inspired him to write this article - I think you might enjoy it:
The Pragmatic Defense