Dynamic programming approach to TSP in Java - java

I'm a beginner, and I'm trying to write a working solution to the travelling salesman problem using a dynamic programming approach.
This is the code for my compute function:
public static int compute(int[] unvisitedSet, int dest) {
    if (unvisitedSet.length == 1)
        return distMtx[dest][unvisitedSet[0]];

    int[] newSet = new int[unvisitedSet.length-1];
    int distMin = Integer.MAX_VALUE;

    for (int i = 0; i < unvisitedSet.length; i++) {
        for (int j = 0; j < newSet.length; j++) {
            if (j < i) newSet[j] = unvisitedSet[j];
            else newSet[j] = unvisitedSet[j+1];
        }

        int distCur;
        if (distMtx[dest][unvisitedSet[i]] != -1) {
            distCur = compute(newSet, unvisitedSet[i]) + distMtx[unvisitedSet[i]][dest];
            if (distMin > distCur)
                distMin = distCur;
        }
        else {
            System.out.println("No path between " + dest + " and " + unvisitedSet[i]);
        }
    }
    return distMin;
}
The code is not giving me the correct answers, and I'm trying to figure out where the error is occurring. I think my error occurs when I add:
distCur = compute(newSet, unvisitedSet[i]) + distMtx[unvisitedSet[i]][dest];
So I've been messing around with that part, moving the addition to the very end right before I return distMin and so on... But I couldn't figure it out.
Although I'm sure it can be inferred from the code, I will state the following facts to clarify.
distMtx stores all the intercity distances, and distances are symmetric, meaning if distance from city A to city B is 3, then the distance from city B to city A is also 3. Also, if two cities don't have any direct paths, the distance value is -1.
Any help would be very much appreciated!
The main function reads the intercity distances from a text file. Because I'm assuming the number of cities will always be less than 100, the global int array distMtx is [100][100].
Once the matrix is filled with the necessary information, an array of all the cities is created. The names of the cities are basically numbers. So if I have 4 cities, set[4] = {0, 1, 2, 3}.
In the main function, after distMtx and set are created, the first call to compute() is made:
int optRoute = compute(set, 0);
Sample input:
-1 3 2 7
3 -1 10 1
2 10 -1 4
7 1 4 -1
Expected output:

Here's a working iterative solution to the TSP with dynamic programming. What would make your life easier is to store the current state as a bitmask instead of in an array. This has the advantage that the state representation is compact and can be cached easily.
I made a video detailing the solution to this problem on YouTube, please enjoy! The code below was taken from my GitHub repo:
/**
 * An implementation of the traveling salesman problem in Java using dynamic
 * programming to improve the time complexity from O(n!) to O(n^2 * 2^n).
 *
 * Time Complexity: O(n^2 * 2^n)
 * Space Complexity: O(n * 2^n)
 */
import java.util.List;
import java.util.ArrayList;
import java.util.Collections;

public class TspDynamicProgrammingIterative {

    private final int N, start;
    private final double[][] distance;
    private List<Integer> tour = new ArrayList<>();
    private double minTourCost = Double.POSITIVE_INFINITY;
    private boolean ranSolver = false;

    public TspDynamicProgrammingIterative(double[][] distance) {
        this(0, distance);
    }

    public TspDynamicProgrammingIterative(int start, double[][] distance) {
        N = distance.length;
        if (N <= 2) throw new IllegalStateException("N <= 2 not yet supported.");
        if (N != distance[0].length) throw new IllegalStateException("Matrix must be square (n x n)");
        if (start < 0 || start >= N) throw new IllegalArgumentException("Invalid start node.");
        this.start = start;
        this.distance = distance;
    }

    // Returns the optimal tour for the traveling salesman problem.
    public List<Integer> getTour() {
        if (!ranSolver) solve();
        return tour;
    }

    // Returns the minimal tour cost.
    public double getTourCost() {
        if (!ranSolver) solve();
        return minTourCost;
    }

    // Solves the traveling salesman problem and caches solution.
    public void solve() {
        if (ranSolver) return;

        final int END_STATE = (1 << N) - 1;
        // memo[i][mask]: cheapest path that starts at 'start', visits exactly
        // the nodes in 'mask' and ends at node i.
        Double[][] memo = new Double[N][1 << N];

        // Add all outgoing edges from the starting node to memo table.
        for (int end = 0; end < N; end++) {
            if (end == start) continue;
            memo[end][(1 << start) | (1 << end)] = distance[start][end];
        }

        for (int r = 3; r <= N; r++) {
            for (int subset : combinations(r, N)) {
                if (notIn(start, subset)) continue;
                for (int next = 0; next < N; next++) {
                    if (next == start || notIn(next, subset)) continue;
                    int subsetWithoutNext = subset ^ (1 << next);
                    double minDist = Double.POSITIVE_INFINITY;
                    for (int end = 0; end < N; end++) {
                        if (end == start || end == next || notIn(end, subset)) continue;
                        double newDistance = memo[end][subsetWithoutNext] + distance[end][next];
                        if (newDistance < minDist) {
                            minDist = newDistance;
                        }
                    }
                    memo[next][subset] = minDist;
                }
            }
        }

        // Connect tour back to starting node and minimize cost.
        for (int i = 0; i < N; i++) {
            if (i == start) continue;
            double tourCost = memo[i][END_STATE] + distance[i][start];
            if (tourCost < minTourCost) {
                minTourCost = tourCost;
            }
        }

        int lastIndex = start;
        int state = END_STATE;
        tour.add(start);

        // Reconstruct TSP path from memo table.
        for (int i = 1; i < N; i++) {
            int index = -1;
            for (int j = 0; j < N; j++) {
                if (j == start || notIn(j, state)) continue;
                if (index == -1) index = j;
                double prevDist = memo[index][state] + distance[index][lastIndex];
                double newDist = memo[j][state] + distance[j][lastIndex];
                if (newDist < prevDist) {
                    index = j;
                }
            }
            tour.add(index);
            state = state ^ (1 << index);
            lastIndex = index;
        }

        // The path was collected from the end backwards, so close it and reverse it.
        tour.add(start);
        Collections.reverse(tour);

        ranSolver = true;
    }

    private static boolean notIn(int elem, int subset) {
        return ((1 << elem) & subset) == 0;
    }

    // This method generates all bit sets of size n where r bits
    // are set to one. The result is returned as a list of integer masks.
    public static List<Integer> combinations(int r, int n) {
        List<Integer> subsets = new ArrayList<>();
        combinations(0, 0, r, n, subsets);
        return subsets;
    }

    // To find all the combinations of size r we need to recurse until we have
    // selected r elements (aka r = 0), otherwise if r != 0 then we still need to select
    // an element which is found after the position of our last selected element
    private static void combinations(int set, int at, int r, int n, List<Integer> subsets) {
        // Return early if there are more elements left to select than what is available.
        int elementsLeftToPick = n - at;
        if (elementsLeftToPick < r) return;

        // We selected 'r' elements so we found a valid subset!
        if (r == 0) {
            subsets.add(set);
        } else {
            for (int i = at; i < n; i++) {
                // Try including this element
                set |= 1 << i;
                combinations(set, i + 1, r - 1, n, subsets);
                // Backtrack and try the instance where we did not include this element
                set &= ~(1 << i);
            }
        }
    }

    public static void main(String[] args) {
        // Create adjacency matrix
        int n = 6;
        double[][] distanceMatrix = new double[n][n];
        for (double[] row : distanceMatrix) java.util.Arrays.fill(row, 10000);
        distanceMatrix[5][0] = 10;
        distanceMatrix[1][5] = 12;
        distanceMatrix[4][1] = 2;
        distanceMatrix[2][4] = 4;
        distanceMatrix[3][2] = 6;
        distanceMatrix[0][3] = 8;

        int startNode = 0;
        TspDynamicProgrammingIterative solver = new TspDynamicProgrammingIterative(startNode, distanceMatrix);

        // Prints: [0, 3, 2, 4, 1, 5, 0]
        System.out.println("Tour: " + solver.getTour());

        // Print: 42.0
        System.out.println("Tour cost: " + solver.getTourCost());
    }
}
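For comparison with the recursive compute function in the question, the same bitmask idea can also be sketched top-down. This is only a minimal sketch under some assumptions: the class and method names (TspBitmaskSketch, shortestTour, visit) are made up for illustration, the tour starts and ends at city 0, a distance of -1 means "no direct path" as in the question, and distances are small enough not to overflow an int.

```java
import java.util.Arrays;

final class TspBitmaskSketch {
    // Large "unreachable" marker; halved so that adding one edge cannot overflow
    // (assumes individual distances are far below Integer.MAX_VALUE / 2).
    static final int INF = Integer.MAX_VALUE / 2;

    // Cost of the cheapest tour that starts and ends at city 0, or INF if none exists.
    static int shortestTour(int[][] dist) {
        int n = dist.length;
        int[][] memo = new int[1 << n][n];
        for (int[] row : memo) Arrays.fill(row, -1); // -1 = not computed yet
        return visit(1, 0, dist, memo);              // only city 0 visited so far
    }

    // visited: bit i is set if city i has been visited; last: the city we are in now.
    private static int visit(int visited, int last, int[][] dist, int[][] memo) {
        int n = dist.length;
        if (visited == (1 << n) - 1) {
            // Every city visited: close the tour by returning to city 0.
            return dist[last][0] == -1 ? INF : dist[last][0];
        }
        if (memo[visited][last] != -1) return memo[visited][last];
        int best = INF;
        for (int next = 0; next < n; next++) {
            boolean alreadyVisited = (visited & (1 << next)) != 0;
            if (alreadyVisited || dist[last][next] == -1) continue;
            int rest = visit(visited | (1 << next), next, dist, memo);
            best = Math.min(best, dist[last][next] + rest);
        }
        memo[visited][last] = best;
        return best;
    }
}
```

Called on the 4 x 4 sample matrix from the question, shortestTour returns 10, which corresponds to the tour 0, 1, 3, 2, 0 (3 + 1 + 4 + 2). The memo table has 2^n * n entries, so this is only practical for roughly 20 cities or fewer.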

I know this is a pretty old question, but it might help somebody in the future.
Here is a very well written paper on TSP with a dynamic programming approach

I think you have to make some changes in your program.
Here there is an implementation


Find consecutive range in a given list with limit

Find the maximum consecutive elements matching the given condition.
I have a list of numbers called A, another list called B and a limit called Limit.
The task is to find the maximum k such that some k consecutive elements in A satisfy the condition below.
Max(B[i],B[i+1],...B[i+k]) + Sum(A[i], A[i+1], ..., A[i+k]) * k ≤ Limit
A = [2,1,3,4,5]
B = [3,6,1,3,4]
Limit = 25
Take 2 consecutive elements:
Highest sum occurs with elements in A = 4,5. The corresponding max in B is Max(3,4) = 4.
So value = 4 + (4+5) * 2 = 22. Here 22 ≤ 25, so 2 consecutive is possible
Take 3 consecutive elements:
Taking sum for 1st 3 elements of A = 2,1,3. The corresponding max in B is Max(3,6,1) = 6.
So value = 6 + (2+1+3) * 3 = 24. Here 24 ≤ 25, so 3 consecutive is possible
Take 4 consecutive elements:
Taking sum for 1st 4 elements of A = 2,1,3,4. The corresponding max in B is Max(3,6,1,3) = 6.
So value = 6 + (2+1+3+4) * 4 = 46. Here 46 > 25, so 4 consecutive is not possible
So correct answer to this input is 3.
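The arithmetic in the worked example above can be checked with a small helper. This is just a sketch: WindowValue and value are hypothetical names, and the window is described by a start index and a length, multiplying the sum by the number of elements as the example does.

```java
import java.util.List;

final class WindowValue {
    // Value of the window of len consecutive elements starting at index start:
    // (maximum of the B elements) + (sum of the A elements) * len.
    static long value(List<Integer> a, List<Integer> b, int start, int len) {
        long sum = 0;
        long max = Long.MIN_VALUE;
        for (int i = start; i < start + len; i++) {
            sum += a.get(i);
            max = Math.max(max, b.get(i));
        }
        return max + sum * len;
    }
}
```

With A = [2,1,3,4,5] and B = [3,6,1,3,4], value(A, B, 3, 2) is 22, value(A, B, 0, 3) is 24 and value(A, B, 0, 4) is 46, matching the three cases above.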
n (Size of A) is up to 10⁵, A elements up to 10¹⁴, B elements up to 10⁹, Limit up to 10¹⁴.
Here is my code:
public int getMax(List<Integer> A, List<Integer> B, long limit) {
    int result = 0;
    int n = A.size();
    for(int len=1; len<=n; len++) {
        for(int i=0; i<=n-len; i++) {
            int j=i+len-1;
            int max = B.get(i);
            long total = 0;
            for(int k=i; k<=j; k++) {
                total += A.get(k);
                max = Math.max(max, B.get(k));
            }
            total = max + total * len;
            if(total < limit) {
                result = len;
            }
        }
    }
    return result;
}
This code works for smaller range of inputs.
But fails with a time out for larger inputs. How can I reduce time complexity of this code?
Updated code based on dratenik answer, but the sample test case mentioned in my post itself is failing. The program is returning 4 instead of 3.
public int getMax(List<Integer> A, List<Integer> B, long limit) {
    int from = 0, to = 0, max = -1;
    int n = A.size();
    for (; from < n;) {
        int total = 0;
        int m = B.get(from); // updated here
        for (int i = from; i < to; i++) {
            total += A.get(i); // updated here
            m = Math.max(m, B.get(i)); // updated here
        }
        total = m + total * (to - from); // updated here
        if (total <= limit && to - from + 1 > max) {
            max = to - from + 1;
        }
        if (total < limit && to < n) { // below target, extend window
            to++;
        } else { // otherwise contract window
            from++;
        }
        if (from > to) {
            to = from;
        }
    }
    return max;
}
Since all the elements of A and B are positive, you can solve this with the usual two-pointer approach to finding a maximum length subarray:
Initialize two pointers s and e to the start of the arrays, and then advance e as far as possible without violating the limit. This finds the longest valid subarray that starts at s.
While e isn't at the end of the arrays, advance s by one position, and then again advance e as far as possible without violating the limit. This finds the longest valid subarray that starts at every position. This leads to an O(n) algorithm, because e can advance monotonically.
Your answer is the longest valid sequence you see.
In order to determine in O(1) whether or not a particular range from s to e is valid, you need to track the cumulative sum of A elements and the current maximum of B elements.
The sum is easy -- just add elements that e passes and subtract elements that s passes.
To track the current maximum of elements in B, you can use the standard sliding-window-maximum algorithm described here: Sliding window maximum in O(n) time. It works just fine with expanding and contracting windows, maintaining O(1) amortized cost per operation.
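As a standalone illustration of the sliding-window-maximum technique mentioned above, here is a minimal sketch for the simpler fixed-size case. The names SlidingWindowMax and maxOfWindows are made up for this example, and it assumes 1 <= k <= b.length.

```java
import java.util.ArrayDeque;
import java.util.Deque;

final class SlidingWindowMax {
    // Returns the maximum of every window of size k, in O(n) total time.
    static int[] maxOfWindows(int[] b, int k) {
        int n = b.length;
        int[] result = new int[n - k + 1];
        // Indexes of candidates for the maximum; their values decrease from front to back.
        Deque<Integer> candidates = new ArrayDeque<>();
        for (int e = 0; e < n; e++) {
            // A new element makes every smaller-or-equal element behind it useless.
            while (!candidates.isEmpty() && b[candidates.peekLast()] <= b[e]) {
                candidates.pollLast();
            }
            candidates.addLast(e);
            // Drop the front if it has slid out of the window [e - k + 1, e].
            if (candidates.peekFirst() <= e - k) {
                candidates.pollFirst();
            }
            if (e >= k - 1) {
                result[e - k + 1] = b[candidates.peekFirst()];
            }
        }
        return result;
    }
}
```

Each index is added to and removed from the deque at most once, which is where the O(1) amortized cost per step comes from. The same deque works when the window grows and shrinks by moving the two pointers independently.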
Here's an O(n) solution in Java. Note that I multiplied the sum of A elements by the length of the sequence, because it's what you seem to intend, even though the formula you wrote multiplies by length-1:
public static int getMax(List<Integer> A, List<Integer> B, long limit) {
    final int size = A.size();
    // a Queue containing indexes of elements that may become max in the window
    // they must be monotonically decreasing
    final int maxQ[] = new int[size];
    int maxQstart = 0, maxQend = 0;
    // current valid window start and end
    int s=0, e = 0;
    int bestLen = 0;
    long windowSum = 0;
    while (s < size && e < size) {
        // calculate longer window max
        long nextMax = maxQstart < maxQend ? B.get(maxQ[maxQstart]) : 0;
        nextMax = Math.max(nextMax, B.get(e));
        long sumPart = (windowSum + A.get(e)) * (e+1-s);
        if (nextMax + sumPart <= limit) {
            // extending the window is valid
            int lastB = B.get(e);
            while (maxQstart < maxQend && B.get(maxQ[maxQend-1]) <= lastB) {
                --maxQend;
            }
            maxQ[maxQend++] = e;
            windowSum += A.get(e);
            ++e;
            if (e-s > bestLen) {
                bestLen = e-s;
            }
        } else if (e > s) {
            // extending the window is invalid.
            // move up the start instead
            windowSum -= A.get(s);
            ++s;
            while(maxQstart < maxQend && maxQ[maxQstart] < s) {
                ++maxQstart;
            }
        } else {
            // we need to move the start up, but the window is empty, so move them both
            ++s;
            ++e;
        }
    }
    return bestLen;
}
Sliding window approach? Slightly pseudocodey version:
int from=0, to=0, max = -1;
for(;from<n;) {
    total = (target expression on elements between from-to inclusive)
    if (total<=target && to-from+1 > max) {max = to-from+1;}
    if (total<target && to<n) { // below target, extend window
        to++;
    } else { // otherwise contract window
        from++;
    }
    if (from>to) {to=from;}
}
return max;
The sum could be updated incrementally, but I don't know how to sensibly update the max(B[i],B[i+1],...B[i+k]) part when contracting the window, so let's recompute the whole thing at each step.
I tried to use meaningful names to make the code readable. Don't hesitate to ask where it is not clear:
public int getMax(List<Integer> a, List<Integer> b, long limit) {
    int max = -1;
    int numberOfElements = 2;
    boolean found;
    do {
        found = false;
        for ( int index = 0; index <= a.size() - numberOfElements; index++) {
            int totalA = 0;
            int maxB = b.get(index);
            for (int i = index; i < index + numberOfElements; i++) {
                totalA += a.get(i);
                maxB = Math.max(maxB,b.get(i)); // updated here
            }
            int total = maxB + totalA * numberOfElements;
            if (total <= limit && numberOfElements >= max) {
                max = numberOfElements;
                found = true;
            }
        }
        numberOfElements++; // try the next window size
    } while(found && numberOfElements <= a.size());
    return max;
}
(more test cases can be helpful for further debugging)
I think the main obstacle there is how to efficiently track maximum over sliding window.
Easy optimization in this respect without diving into dynamic programming is to make use of MaxHeap.
In java it is implemented as PriorityQueue.
Please consider following code.
private int findMaxRange(List<Long> listA, List<Long> listB, long limit) {
    int maxRange = 0;
    while (maxRange < listA.size() && isRangePossible(listA, listB, limit, maxRange+1)) {
        maxRange++;
    }
    return maxRange;
}

private boolean isRangePossible(List<Long> listA, List<Long> listB, long limit, int rangeSize) {
    //calculate initial values of max and sum
    PriorityQueue<Long> maxHeap = new PriorityQueue<>(rangeSize, Comparator.reverseOrder());
    // fill the heap with the first window of B (needed so that peek() is not null)
    listB.stream().limit(rangeSize).forEach(maxHeap::add);
    Long max = maxHeap.peek();
    Long sum = listA.stream().limit(rangeSize).mapToLong(i->i).sum();
    //iterate with sliding window
    for (int i = 0; i < listA.size() - rangeSize; i++) {
        if (isConditionMet(max, sum, rangeSize, limit)) {
            return true;
        }
        sum = sum + listA.get(i+rangeSize) - listA.get(i);
        // slide the window in the heap as well: drop the element leaving, add the one entering
        maxHeap.remove(listB.get(i));
        maxHeap.add(listB.get(i+rangeSize));
        max = maxHeap.peek();
    }
    return isConditionMet(max, sum, rangeSize, limit);
}

private boolean isConditionMet(Long max, Long sum, int rangeSize, long limit) {
    return max + sum * rangeSize < limit;
}
Also please pay attention to value ranges. Such big values can easily overflow long and may require specialized types like BigInteger. You should also consider how much memory is used by auxiliary datatypes.
The problem here seems to be the use of three nested for loops to calculate the max and sum for every window.
We can avoid these unnecessary iterations by reusing the results of the previous iteration in the new one, with the help of dynamic programming.
In my solution, I made two 2D arrays, one to store the max value for each window and the other to store the sum for each window. Storing the values of previous iterations greatly reduces the time complexity.
Here is the Java code:
import java.util.*;

public class MyClass {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
        List<Integer> A = Arrays.asList(2,1,3,4,5);
        List<Integer> B = Arrays.asList(3,6,1,3,4);
        System.out.println(MyClass.getMax(A, B, 25L));
    }

    public static int getMax(List<Integer> A, List<Integer> B, long limit) {
        int n = A.size();
        // dp1[i][j]: max of B over the window of length i ending at position j (1-based)
        // dp2[i][j]: sum of A over the same window
        int[][] dp1 = new int[n + 1][n + 1];
        int[][] dp2 = new int[n + 1][n + 1];
        for(int i = 1; i <= n; i++) {
            for(int j = i; j <= n; j++) {
                dp1[i][j] = Math.max(dp1[i - 1][j- 1], B.get(j - 1));
                dp2[i][j] = dp2[i - 1][j- 1] + A.get(j - 1);
            }
        }
        for(int i = 0; i <= n; i++) {
            for(int j = 0; j <= n; j++) {
                System.out.print("{" + dp1[i][j] + ", " + dp2[i][j] + "}, ");
            }
        }
        int kMax = 0;
        for(int i = 0; i <= n; i++) {
            for(int j = i; j <= n; j++) {
                if(dp1[i][j] + dp2[i][j] * i <= limit) {
                    kMax = i;
                }
            }
        }
        System.out.println("Max K: " + kMax);
        return kMax; // return the result instead of a constant 0
    }
}
If you only care about the algorithm and are not building an app or game, you don't have to use Java; you could try Python or C++ (or even C or C#). Python is often used for algorithm work.
If you do need Java, add breakpoints or make the program print everything it does (for example the j, i, k and result variables) to the console, so you can debug it more easily.

how to determine if a number is a smart number in java?

I have this question I am trying to solve. I have tried coding for the past 4 hours.
An integer is defined to be a Smart number if it is an element in the infinite sequence
1, 2, 4, 7, 11, 16 …
Note that 2-1=1, 4-2=2, 7-4=3, 11-7=4, 16-11=5 so for k>1, the kth element of the sequence is equal to the k-1th element + k-1. For example, for k=6, 16 is the kth element and is equal to 11 (the k-1th element) + 5 ( k-1).
Write function named isSmart that returns 1 if its argument is a Smart number, otherwise it returns 0. So isSmart(11) returns 1, isSmart(22) returns 1 and isSmart(8) returns 0
I have tried the following code:
import java.util.Arrays;

public class IsSmart {
    public static void main(String[] args) {
        // TODO Auto-generated method stub
        int x = isSmart(11);
    }

    public static int isSmart(int n) {
        int[] y = new int[n];
        int j = 0;
        for (int i = 1; i <= n; i++) {
            y[j] = i;
        }
        for (int i = 0; i <= y.length; i++) {
            int diff = 0;
            y[j] = y[i+1] - y[i] ;
            y[i] = diff;
        }
        for (int i = 0; i < y.length; i++) {
            if(n == y[i])
                return 1;
        }
        return 0;
    }
}
When I test it with 11 it is giving me 0 but it shouldn't. Any idea how to correct my mistakes?
It can be done in a simpler way as follows
import java.util.Arrays;

public class IsSmart {
    public static void main(String[] args) {
        int x = isSmart(11);
        System.out.println("Ans: "+x);
    }

    public static int isSmart(int n) {
        //------------ CHECK THIS LOGIC ------------//
        int[] y = new int[n];
        int diff = 1;
        for (int i = 1; i < n; i++) {
            y[0] =1;
            y[i] = diff + y[i-1];
            diff++; // the gap between consecutive elements grows by one each step
        }
        //------------ CHECK THIS LOGIC ------------//
        for (int i = 0; i < y.length; i++) {
            if(n == y[i])
                return 1;
        }
        return 0;
    }
}
One of the problems is the way that you're populating your array.
The array can be populated as such:
for(int i = 0; i < n; i++) {
    y[i] = (i == 0) ? 1 : y[i - 1] + i;
}
The overall implementation of the function isSmart can be simplified to:
public static int isSmart(int n) {
    int[] array = new int[n];
    for(int i = 0; i < n; i++) {
        array[i] = (i == 0) ? 1 : array[i - 1] + i;
    }
    for (int i = 0; i < array.length; i++) {
        if (array[i] == n) return 1;
    }
    return 0;
}
Note that you don't need to build an array:
public static int isSmart(int n) {
    int smart = 1;
    for (int i = 1; smart < n; i++) {
        smart = smart + i;
    }
    return smart == n ? 1 : 0;
}
Here is a naive way to think of it to get you started - you need to fill out the while() loop. The important thing to notice is that:
The next value of the sequence will be the number of items in the sequence + the last item in the sequence.
import java.util.ArrayList;

public class Test {
    public static void main(String[] args) {
    }

    public static int isSmart(int n) {
        ArrayList<Integer> sequence = new ArrayList<Integer>();
        // Start with 1 in the ArrayList
        sequence.add(1);
        // You need to keep track of the index, as well as
        // the next value you're going to add to your list
        int index = 1; // or number of elements in the sequence
        int nextVal = 1;
        while (nextVal < n) {
            // Three things need to happen in here:
            // 1) set nextVal equal to the sum of the current index + the value at the *previous* index
            // 2) add nextVal to the ArrayList
            // 3) increment index by 1
        }
        // Now you can check to see if your ArrayList contains n (is Smart)
        if (sequence.contains(n)) { return 1; }
        return 0;
    }
}
First think of a mathematical solution.
Smart numbers form a sequence:
a_0 = 1
a_(n+1) = n + a_n
This gives a function for smart numbers:
f(x) = ax² + bx + c
f(x + 1) = f(x) + x = ...
So the problem is to find for a given y a matching x.
You can do this by a binary search.
int isSmart(int n) {
    int xlow = 1;
    int xhigh = n; // Exclusive. For n == 0 return 1.
    while (xlow < xhigh) {
        int x = (xlow + xhigh)/2;
        int y = f(x);
        if (y == n) {
            return 1;
        }
        if (y < n) {
            xlow = x + 1;
        } else {
            xhigh = x;
        }
    }
    return 0;
}
Yet smarter would be to use the solution for x and look whether it is an integer:
ax² + bx + c' = 0 where c' = c - n
x = ...
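One way to make the "solve for x and check that it is an integer" idea concrete is sketched below. The coefficients are not given in the answer above, so they are derived here from the sequence in the question as an assumption: numbering the elements from k = 1, the k-th element is 1 + k(k-1)/2, so n is smart exactly when k² - k + 2(1 - n) = 0 has a positive integer root, that is, when 8n - 7 is a perfect square (the root (1 + sqrt(8n - 7))/2 is then automatically an integer because 8n - 7 is odd). The class name SmartNumberClosedForm is made up for the example.

```java
final class SmartNumberClosedForm {
    // Returns 1 if n is in the sequence 1, 2, 4, 7, 11, 16, ... and 0 otherwise.
    static int isSmart(int n) {
        if (n < 1) return 0;
        long d = 8L * n - 7; // discriminant of k^2 - k + 2(1 - n) = 0
        long root = (long) Math.sqrt((double) d);
        // Guard against floating-point rounding just below or above a perfect square.
        while (root * root > d) root--;
        while ((root + 1) * (root + 1) <= d) root++;
        return root * root == d ? 1 : 0;
    }
}
```

For example, n = 11 gives 8 * 11 - 7 = 81 = 9², so 11 is smart, while n = 8 gives 57, which is not a perfect square.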
I was playing around with this and I noticed something. The smart numbers are
1 2 4 7 11 16 22 29 ...
If you subtract one you get
0 1 3 6 10 15 21 28 ...
0 1 2 3 4 5 6 7 ...
The above sequence happens to be the sum of the first n numbers starting with 0 which is n*(n+1)/2. So add 1 to that and you get a smart number.
Since n and n+1 are next door to each other you can derive them by reversing the process.
Take 29, subtract 1 = 28, * 2 = 56. The sqrt(56) rounded down is 7, and 7 * 8 / 2 + 1 = 29. So 29 is the smart number at index 7 (counting from 0).
Using that information you can detect a smart number without a loop by simply reversing the process.
public static int isSmart(int v) {
    int vv = (v-1)*2;
    int sq = (int)Math.sqrt(vv);
    int chk = (sq*(sq+1))/2 + 1;
    return (chk == v) ? 1 : 0;
}
Using a version which supports longs, I have verified this against the iterative process for values from 1 to 10,000,000,000.

Smallest sum of triplet products where the middle element is removed using Dynamic Programming

I am given a sequence of N numbers (4 ≤ N ≤ 150). One index i (0 < i < N-1, so never the first or last element) is picked, and the i-th number is multiplied with its left and right neighbours, in other words with the numbers at i-1 and i+1. Then the i-th number is removed. This is done until the sequence has only two numbers left over. The goal is to find the smallest sum of these products, which obviously depends on the order in which the indices are picked.
E.g. for the sequence 44, 45, 5, 39, 15, 22, 10 the smallest sum would be 17775
using the indices in the following order: 1->3->4->5->2 which is the sum:
44*45*5 + 5*39*15 + 5*15*22 + 5*22*10 + 44*5*10 = 9900 + 2925 + 1650 + 1100 + 2200 = 17775
I have found a solution using a recursive function:
public static int smallestSum(List<Integer> values) {
    if (values.size() == 3)
        return values.get(0) * values.get(1) * values.get(2);
    else {
        int ret = Integer.MAX_VALUE;
        for (int i = 1; i < values.size() - 1; i++) {
            List<Integer> copy = new ArrayList<Integer>(values);
            copy.remove(i);
            int val = smallestSum(copy) + values.get(i - 1) * values.get(i) * values.get(i + 1);
            if (val < ret) ret = val;
        }
        return ret;
    }
}
However, this solution is only feasible for small N but not for a bigger amount of numbers. What I am looking for is a way to do this using an iterative Dynamic Programming approach.
The optimal substructure needed for a DP is that, given the identity of the last element removed, the elimination strategy for the elements to the left is independent of the elimination strategy for the elements to the right. Here's a new recursive function (smallestSumA, together with the version from the question and a test harness comparing the two) incorporating this observation:
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Foo {
    public static void main(String[] args) {
        Random r = new Random();
        for (int i = 0; i < 10000; i++) {
            List<Integer> values = new ArrayList<Integer>();
            for (int n = 3 + r.nextInt(8); n > 0; n--) {
                // the original line is missing here; the bound 100 is an assumption
                values.add(r.nextInt(100));
            }
            int a = smallestSumA(values, 0, values.size() - 1);
            int q = smallestSumQ(values);
            if (q != a) {
                // the original reporting code is missing; this line is a stand-in
                System.err.println("mismatch for " + values + ": " + q + " vs " + a);
            }
        }
    }

    public static int smallestSumA(List<Integer> values, int first, int last) {
        if (first + 2 > last)
            return 0;
        int ret = Integer.MAX_VALUE;
        for (int i = first + 1; i <= last - 1; i++) {
            int val = (smallestSumA(values, first, i)
                    + values.get(first) * values.get(i) * values.get(last) + smallestSumA(values, i, last));
            if (val < ret)
                ret = val;
        }
        return ret;
    }

    public static int smallestSumQ(List<Integer> values) {
        if (values.size() == 3)
            return values.get(0) * values.get(1) * values.get(2);
        else {
            int ret = Integer.MAX_VALUE;
            for (int i = 1; i < values.size() - 1; i++) {
                List<Integer> copy = new ArrayList<Integer>(values);
                copy.remove(i);
                int val = smallestSumQ(copy) + values.get(i - 1) * values.get(i) * values.get(i + 1);
                if (val < ret)
                    ret = val;
            }
            return ret;
        }
    }
}
Invoke as smallestSumA(values, 0, values.size() - 1).
To get the DP, observe that there are only N choose 2 different settings for first and last, and memoize. The running time is O(N^3).
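The memoization step could look roughly like the following. It is only a sketch under a few assumptions: the class name SmallestSumMemo is invented, the input is an int array rather than a List, the values are non-negative (so -1 can mark "not computed yet"), and long is used for the sums to avoid overflow.

```java
import java.util.Arrays;

final class SmallestSumMemo {
    static long smallestSum(int[] values) {
        int n = values.length;
        long[][] memo = new long[n][n];
        for (long[] row : memo) Arrays.fill(row, -1); // -1 = not computed yet
        return solve(values, 0, n - 1, memo);
    }

    // Minimum cost of removing every element strictly between first and last.
    private static long solve(int[] v, int first, int last, long[][] memo) {
        if (first + 2 > last) return 0; // nothing left between the two ends
        if (memo[first][last] != -1) return memo[first][last];
        long best = Long.MAX_VALUE;
        for (int i = first + 1; i <= last - 1; i++) {
            // i is the LAST element removed between first and last, so at that
            // moment its neighbours are exactly v[first] and v[last].
            long val = solve(v, first, i, memo)
                    + (long) v[first] * v[i] * v[last]
                    + solve(v, i, last, memo);
            best = Math.min(best, val);
        }
        memo[first][last] = best;
        return best;
    }
}
```

For the sequence 44, 45, 5, 39, 15, 22, 10 from the question, smallestSum returns 17775. Each of the O(N^2) (first, last) pairs is solved once with an O(N) loop, which gives the O(N^3) running time.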
If anyone is interested in a DP solution, based on David Eisenstat's recursive solution, here is an iterative one using DP (for large numbers it's useful to replace the ints with longs):
public static int smallestSum(List<Integer> values) {
    int[][] table = new int[values.size()][values.size()];
    for (int i = 2; i < values.size(); i++) {
        for (int j = 0; j + i < values.size(); j++) {
            int ret = Integer.MAX_VALUE;
            for (int k = j + 1; k <= j + i - 1; k++) {
                int val = table[j][k] + values.get(j) * values.get(k) * values.get(j + i) + table[k][j + i];
                if (val < ret) ret = val;
            }
            table[j][j + i] = ret;
        }
    }
    return table[0][values.size() - 1];
}

NumberOfDiscIntersections overflow in codility test

In the codility test NumberOfDiscIntersections I am getting perf 100% and correctness 87% with the one test failing being
arithmetic overflow tests
got -1 expected 2
I can't see what is causing that given that I am using long which is 64-bit. And even if I can get it to 100% perf 100% correctness I am wondering if there is a better way to do this that is not as verbose in Java.
edit: figured out a much better way to do this with two arrays rather than a pair class
// you can also use imports, for example:
import java.util.*;
// you can use System.out.println for debugging purposes, e.g.
// System.out.println("this is a debug message");
class Solution {
    public int solution(int[] A) {
        int j = 0;
        Pair[] arr = new Pair[A.length * 2];
        for (int i = 0; i < A.length; i++) {
            Pair s = new Pair(i - A[i], true);
            arr[j] = s;
            j++;
            Pair e = new Pair(i + A[i], false);
            arr[j] = e;
            j++;
        }
        Arrays.sort(arr, new Pair(0, true));
        long numIntersect = 0;
        long currentCount = 0;
        for (Pair p: arr) {
            if (p.start) {
                numIntersect += currentCount;
                if (numIntersect > 10000000) {
                    return -1;
                }
                currentCount++;
            } else {
                currentCount--;
            }
        }
        return (int) numIntersect;
    }

    static private class Pair implements Comparator<Pair> {
        private long x;
        private boolean start;

        public Pair(long x, boolean start) {
            this.x = x;
            this.start = start;
        }

        public int compare(Pair p1, Pair p2) {
            if (p1.x < p2.x) {
                return -1;
            } else if (p1.x > p2.x) {
                return 1;
            } else {
                if (p1.start && p2.start == false) {
                    return -1;
                } else if (p1.start == false && p2.start) {
                    return 1;
                } else {
                    return 0;
                }
            }
        }
    }
}
Look at this line:
Pair e = new Pair(i + A[i], false);
This is equivalent to Pair e = new Pair((long)(i + A[i]), false);
Both i and A[i] are ints, so the addition is done in int arithmetic and can overflow, because the value in A[i] can go up to Integer.MAX_VALUE; the conversion to long happens only after the add operation is completed.
To fix:
Pair e = new Pair((long)i + (long)A[i], false);
Note: I submitted with this fix and got 100%.
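The difference between widening after and before the addition can be seen in isolation with a tiny sketch (OverflowDemo and its two method names are made up for illustration):

```java
final class OverflowDemo {
    // The sum is computed in int arithmetic and only the (possibly wrapped) result is widened.
    static long addThenWiden(int i, int radius) {
        return i + radius;
    }

    // One operand is widened first, so the whole addition is done in long arithmetic.
    static long widenThenAdd(int i, int radius) {
        return (long) i + radius;
    }
}
```

With i = 1 and radius = Integer.MAX_VALUE, addThenWiden wraps around to -2147483648, while widenThenAdd gives 2147483648. A wrapped, negative end point sorts before the start points, which is the kind of thing that breaks the sweep in the arithmetic overflow test.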
My solution from today. O(N) time complexity. It rests on a simple assumption: the number of available pairs at the next point of the table is the difference between the total number of circles opened up to that moment and the circles that have been processed before. Maybe it's too simple :)
public int solution04(int[] A) {
    final int N = A.length;
    final int M = N + 2;
    int[] left = new int[M]; // values of nb of "left" edges of the circles in that point
    int[] sleft = new int[M]; // prefix sum of left[]
    int il, ir; // index of the "left" and of the "right" edge of the circle
    for (int i = 0; i < N; i++) { // counting left edges
        il = tl(i, A);
        left[il]++;
    }
    sleft[0] = left[0];
    for (int i = 1; i < M; i++) {// counting prefix sums for future use
        sleft[i] = sleft[i - 1] + left[i];
    }
    int o, pairs, total_p = 0, total_used=0;
    for (int i = 0; i < N; i++) { // counting pairs
        ir = tr(i, A, M);
        o = sleft[ir]; // nb of open till right edge
        pairs = o -1 - total_used;
        total_used++;
        total_p += pairs;
    }
    if(total_p > 10000000){
        total_p = -1;
    }
    return total_p;
}

int tl(int i, int[] A){
    int tl = i - A[i]; // index of "begin" of the circle
    if (tl < 0) {
        tl = 0;
    } else {
        tl = i - A[i] + 1;
    }
    return tl;
}

int tr(int i, int[] A, int M){
    int tr; // index of "end" of the circle
    if (Integer.MAX_VALUE - i < A[i] || i + A[i] >= M - 1) {
        tr = M - 1;
    } else {
        tr = i + A[i] + 1;
    }
    return tr;
}
My take on this, O(n):
public int solution(int[] A) {
    int[] startPoints = new int[A.length];
    int[] endPoints = new int[A.length];
    int tempPoint;
    int currOpenCircles = 0;
    long pairs = 0;
    //sum of starting and end points - how many circles open and close at each index?
    for(int i = 0; i < A.length; i++){
        tempPoint = i - A[i];
        startPoints[tempPoint < 0 ? 0 : tempPoint]++;
        tempPoint = i + A[i];
        if(A[i] < A.length && tempPoint < A.length) //first prevents int overflow, second chooses correct point
            endPoints[tempPoint]++;
        else
            endPoints[A.length - 1]++;
    }
    //find all pairs of new circles (combinations), then make pairs with exiting circles (multiplication)
    for(int i = 0; i < A.length; i++){
        if(startPoints[i] >= 2)
            pairs += ((long) startPoints[i] * (startPoints[i] - 1)) / 2; // long cast avoids int overflow
        pairs += (long) currOpenCircles * startPoints[i];
        currOpenCircles += startPoints[i];
        currOpenCircles -= endPoints[i];
        if(pairs > 10000000)
            return -1;
    }
    return (int) pairs;
}
An explanation of this part of Helsing's solution:
if(startPoints[i] >= 2) pairs += (startPoints[i] * (startPoints[i] - 1)) / 2;
It is based on the mathematical formula for combinations:
C(n, m) = n! / ((n - m)! * m!)
For pairs, m = 2, so:
C(n, 2) = n! / ((n - 2)! * 2)
This is equal to:
C(n, 2) = n * (n - 1) * (n - 2)! / ((n - 2)! * 2)
After simplification:
C(n, 2) = n * (n - 1) / 2
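As a quick sanity check of the simplified formula, it can be compared against a brute-force count of pairs. PairCount and its method names are hypothetical, and long is used so that large n does not overflow:

```java
final class PairCount {
    // Closed form: number of unordered pairs among n items.
    static long pairs(long n) {
        return n * (n - 1) / 2;
    }

    // Brute force: count every (i, j) with i < j.
    static long pairsBruteForce(int n) {
        long count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                count++;
            }
        }
        return count;
    }
}
```

For example, 4 circles starting at the same point form pairs(4) = 6 intersecting pairs among themselves.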
Not a very good performance, but using streams.
List<Long> list = IntStream.range(0, A.length).mapToObj(i -> Arrays.asList((long) i - A[i], (long) i + A[i]))
.sorted((o1, o2) -> {
int f = o1.get(0).compareTo(o2.get(0));
return f == 0 ? o1.get(1).compareTo(o2.get(1)) : f;
(acc, val) -> {
if (acc.isEmpty()) {
} else {
Iterator it = acc.iterator();
while (it.hasNext()) {
long el = (long) it.next();
if (val.get(0) <= el) {
long count = acc.get(0);
acc.set(0, ++count);
} else {
return (int) (list.isEmpty() ? 0 : list.get(0) > 10000000 ? -1 : list.get(0));
This one in Python passed all "Correctness tests" and failed all "Performance tests" due to O(n²), so I got 50% score. But it is very simple to understand. I just used the right radius (maximum) and checked if it was bigger or equal to the left radius (minimum) of the next circles. I also avoided to use sort and did not check twice the same circle. Later I will try to improve performance, but the problem here for me was the algorithm. I tried to find a very easy solution to help explain the concept. Maybe this will help someone.
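def solution(A):
    n = len(A)
    cnt = 0
    for j in range(1,n):
        for i in range(n-j):
            # right edge of circle i reaches the left edge of circle i+j
            if i + A[i] >= i + j - A[i+j]:
                cnt += 1
                if cnt > 1e7:
                    return -1
    return cnt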

Implementing a binary insertion sort using binary search in Java

I'm having trouble combining these two algorithms together. I've been asked to modify Binary Search to return the index that an element should be inserted into an array. I've been then asked to implement a Binary Insertion Sort that uses my Binary Search to sort an array of randomly generated ints.
My Binary Search works the way it's supposed to, returning the correct index whenever I test it alone. I wrote out Binary Insertion Sort to get a feel for how it works, and got that to work as well. As soon as I combine the two together, it breaks. I know I'm implementing them incorrectly together, but I'm not sure where my problem lies.
Here's what I've got:
public class Assignment3
{
    public static void main(String[] args)
    {
        int[] binary = { 1, 7, 4, 9, 10, 2, 6, 12, 3, 8, 5 };
        ModifiedBinaryInsertionSort(binary);
    }

    static int ModifiedBinarySearch(int[] theArray, int theElement)
    {
        int leftIndex = 0;
        int rightIndex = theArray.length - 1;
        int middleIndex = 0;
        while(leftIndex <= rightIndex)
        {
            middleIndex = (leftIndex + rightIndex) / 2;
            if (theElement == theArray[middleIndex])
                return middleIndex;
            else if (theElement < theArray[middleIndex])
                rightIndex = middleIndex - 1;
            else
                leftIndex = middleIndex + 1;
        }
        return middleIndex - 1;
    }

    static void ModifiedBinaryInsertionSort(int[] theArray)
    {
        int i = 0;
        int[] returnArray = new int[theArray.length + 1];
        for(i = 0; i < theArray.length; i++)
        {
            returnArray[ModifiedBinarySearch(theArray, theArray[i])] = theArray[i];
        }
        for(i = 0; i < theArray.length; i++)
        {
            System.out.print(returnArray[i] + " ");
        }
    }
}
The output I get when I run this is 1 0 0 0 0 2 0 0 3 5 12. Any suggestions?
UPDATE: updated ModifiedBinaryInsertionSort
static void ModifiedBinaryInsertionSort(int[] theArray)
{
    int index = 0;
    int element = 0;
    int[] returnArray = new int[theArray.length];

    for (int i = 1; i < theArray.length - 1; i++)
    {
        element = theArray[i];
        index = ModifiedBinarySearch(theArray, 0, i, element);
        returnArray[i] = element;

        while (index >= 0 && theArray[index] > element)
        {
            theArray[index + 1] = theArray[index];
            index = index - 1;
        }
        returnArray[index + 1] = element;
    }
}
Here is my method to sort an array of integers using binary search.
It modifies the array that is passed as argument.
public static void binaryInsertionSort(int[] a) {
    // nothing to sort
    if (a.length < 2)
        return;
    for (int i = 1; i < a.length; i++) {
        int lowIndex = 0;
        int highIndex = i;
        int b = a[i];

        //while loop for binary search
        while(lowIndex < highIndex) {
            int middle = lowIndex + (highIndex - lowIndex)/2; //avoid int overflow
            if (b >= a[middle]) {
                lowIndex = middle+1;
            }
            else {
                highIndex = middle;
            }
        }
        //replace elements of array
        System.arraycopy(a, lowIndex, a, lowIndex+1, i-lowIndex);
        a[lowIndex] = b;
    }
}
A binary insertion sort works like this: it creates a new empty array B and, for each element in the unsorted array A, binary searches the part of B that has been built so far (from left to right), shifts every element from the chosen location onward one position to the right, and inserts the element into the gap. So you build up an always-sorted array in B until it is the same size as A and contains everything in A.
Two things:
One, the binary search should be able to take an int startOfArray and an int endOfArray, and it will only binary search between those two points. This allows you to make it consider only the part of array B that is actually the sorted array.
Two, before inserting, you must move all elements from the insertion point onward one position to the right, then insert into the gap you've made.
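The two points above can be sketched roughly as follows. This is only an illustration, not the asker's code: the class name BinaryInsertSketch and the helper insertionIndex are made up here, and the search treats `end` as exclusive so that passing the current size of B searches exactly the sorted part.

```java
import java.util.Arrays;

final class BinaryInsertSketch {
    // Hypothetical helper: returns the index in b[start, end) where key should go
    // so the range stays sorted. Equal keys are placed after existing ones.
    static int insertionIndex(int[] b, int start, int end, int key) {
        int lo = start;
        int hi = end; // exclusive upper bound
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (b[mid] <= key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    // Builds a sorted copy of a by inserting each element into the sorted prefix of b.
    static int[] binaryInsertionSort(int[] a) {
        int[] b = new int[a.length];
        for (int size = 0; size < a.length; size++) {
            int key = a[size];
            // Point one: search only the part of b that is already filled.
            int pos = insertionIndex(b, 0, size, key);
            // Point two: shift b[pos, size) one slot right to open a gap at pos.
            System.arraycopy(b, pos, b, pos + 1, size - pos);
            b[pos] = key;
        }
        return b;
    }

    public static void main(String[] args) {
        int[] binary = { 1, 7, 4, 9, 10, 2, 6, 12, 3, 8, 5 };
        // prints [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]
        System.out.println(Arrays.toString(binaryInsertionSort(binary)));
    }
}
```

Returning a new array keeps the input untouched; sorting in place would work the same way with a single array, since the sorted prefix and the unsorted remainder never overlap.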
I realize this is old, but the answer to the question is that, perhaps a little unintuitively, "middleIndex - 1" will not be your insertion index in all cases.
If you run through a few cases on paper the problem should become apparent.
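To save the paper, here is a small check that can be run instead. It assumes the question's loop has the usual `else` before `leftIndex = middleIndex + 1`; the class name InsertionIndexCheck and the second method, leftIndexSearch, are hypothetical and only there for comparison. The second method is the same loop but returns leftIndex on a miss.

```java
final class InsertionIndexCheck {
    // The search loop from the question: a miss returns middleIndex - 1.
    static int questionSearch(int[] theArray, int theElement) {
        int leftIndex = 0;
        int rightIndex = theArray.length - 1;
        int middleIndex = 0;
        while (leftIndex <= rightIndex) {
            middleIndex = (leftIndex + rightIndex) / 2;
            if (theElement == theArray[middleIndex])
                return middleIndex;
            else if (theElement < theArray[middleIndex])
                rightIndex = middleIndex - 1;
            else
                leftIndex = middleIndex + 1;
        }
        return middleIndex - 1;
    }

    // Hypothetical variant: identical loop, but a miss returns leftIndex,
    // which is the first position holding a value greater than theElement.
    static int leftIndexSearch(int[] theArray, int theElement) {
        int leftIndex = 0;
        int rightIndex = theArray.length - 1;
        while (leftIndex <= rightIndex) {
            int middleIndex = (leftIndex + rightIndex) / 2;
            if (theElement == theArray[middleIndex])
                return middleIndex;
            else if (theElement < theArray[middleIndex])
                rightIndex = middleIndex - 1;
            else
                leftIndex = middleIndex + 1;
        }
        return leftIndex;
    }

    public static void main(String[] args) {
        int[] sorted = { 1, 3, 5, 7 };
        for (int key : new int[] { 0, 2, 4, 8 }) {
            System.out.println("key=" + key
                    + " question=" + questionSearch(sorted, key)
                    + " leftIndex=" + leftIndexSearch(sorted, key));
        }
    }
}
```

For the sorted array { 1, 3, 5, 7 } and the keys 0, 2, 4 and 8, the question's version returns -1, -1, 1 and 2, while the positions where those keys actually belong are 0, 1, 2 and 4.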
I have an extension method that solves this problem. To apply it to your situation, you would iterate through the existing list, inserting into an empty starting list.
public static void BinaryInsert<TItem, TKey>(this IList<TItem> list, TItem item, Func<TItem, TKey> sortfFunc)
    where TKey : IComparable
{
    if (list == null)
        throw new ArgumentNullException("list");

    int min = 0;
    int max = list.Count - 1;
    int index = 0;
    TKey insertKey = sortfFunc(item);

    while (min <= max)
    {
        index = (max + min) >> 1;
        TItem value = list[index];
        TKey compKey = sortfFunc(value);
        int result = compKey.CompareTo(insertKey);

        if (result == 0)
            break;              // equal key found: insert right here
        if (result > 0)
            max = index - 1;
        else
            min = index + 1;
    }

    if (index <= 0)
        index = 0;
    else if (index >= list.Count)
        index = list.Count;

    // The body of this check was lost in the original; stepping past a smaller
    // element is assumed. The bounds guard is added so an empty list is safe.
    if (index < list.Count && sortfFunc(list[index]).CompareTo(insertKey) < 0)
        index++;

    list.Insert(index, item);
}
Dude, I think there are some serious problems with your code. Unfortunately, you are missing the fruit (the logic) of this algorithm. Your main goal here is to get the index first; insertion is a cakewalk, but finding the index takes some sweat. Please don't look at this solution unless you have given it your best and are desperate for it. Never give up: you already know the logic, and your goal is to find it in yourself. Please let me know about any mistakes, discrepancies, etc. Happy coding!!
public class Insertion {
private int[] a;
int n;
int c;
public Insertion()
a = new int[10];
int find(int key)
int lowerbound = 0;
int upperbound = n-1;
c = (lowerbound + upperbound)/2;
return 0;
return c++;
return c;
if(a[c]>key && a[c-1]<key)
return c;
else if (a[c]<key && a[c+1]>key)
return c++;
upperbound = c-1;
lowerbound = c+1;
void insert(int key)
for(int k=n;k>c;k--)
void display()
for(int i=0;i<10;i++)
public static void main(String[] args)
Insertion i=new Insertion();

