Understanding Miller Rabin implementation - java

I'm learning about Miller Rabin, and I'm looking at the following implementation of the algorithm from https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Primality_Testing#Java
I feel like I have an okay understanding of the algorithm, but the implementation is not very easy to follow mainly because of the lack of documentation. It would be very helpful if someone could walk through the code and explain what we're doing at each step, and why. Referencing the algorithm would be very helpful.
Algorithm:
Input: n > 3, an odd integer to be tested for primality;
Input: k, a parameter that determines the accuracy of the test
Output: composite if n is composite, otherwise probably prime
write n − 1 as 2^s·d with d odd by factoring powers of 2 from n − 1
LOOP: repeat k times:
pick a randomly in the range [2, n − 2]
x ← a^d mod n
if x = 1 or x = n − 1 then do next LOOP
for r = 1 .. s − 1
x ← x^2 mod n
if x = 1 then return composite
if x = n − 1 then do next LOOP
return composite
return probably prime
Implementation:
import java.math.BigInteger;
import java.util.Random;
public class MillerRabin {
private static final BigInteger ZERO = BigInteger.ZERO;
private static final BigInteger ONE = BigInteger.ONE;
private static final BigInteger TWO = new BigInteger("2");
private static final BigInteger THREE = new BigInteger("3");
public static boolean isProbablePrime(BigInteger n, int k) {
if (n.compareTo(THREE) < 0)
return true;
int s = 0;
BigInteger d = n.subtract(ONE); // n-1
while (d.mod(TWO).equals(ZERO)) { //?
s++; //?
d = d.divide(TWO); //?
}
for (int i = 0; i < k; i++) { //LOOP: repeat k times
BigInteger a = uniformRandom(TWO, n.subtract(ONE)); //?
BigInteger x = a.modPow(d, n); //x = a^d mod n
if (x.equals(ONE) || x.equals(n.subtract(ONE))) // if x=1 or x = n-1, then do next LOOP
continue;
int r = 1;
for (; r < s; r++) { // for r = 1..s-1
x = x.modPow(TWO, n); //x = x ^ 2 mod n
if (x.equals(ONE)) //if x = 1, return false (composite)
return false;
if (x.equals(n.subtract(ONE))) //if x= n-1, look at the next a
break;
}
if (r == s) // None of the steps made x equal n-1.
return false; //this a is a witness, so n is definitely composite
}
return true; //probably prime
}
//this method is just to generate a random int
private static BigInteger uniformRandom(BigInteger bottom, BigInteger top) {
Random rnd = new Random();
BigInteger res;
do {
res = new BigInteger(top.bitLength(), rnd);
} while (res.compareTo(bottom) < 0 || res.compareTo(top) > 0);
return res;
}
}

This part of the code
while (d.mod(TWO).equals(ZERO)) { //?
s++; //?
d = d.divide(TWO); //?
}
corresponds to
write n − 1 as 2^s·d with d odd by factoring powers of 2 from n − 1
As long as d is even, it is divided by 2 and s is incremented. After the loop, d must be odd and s holds the number of factors of 2 in n − 1.
And this part
BigInteger a = uniformRandom(TWO, n.subtract(ONE)); //?
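is meant to implement
pick a randomly in the range [2, n − 2]
Note that uniformRandom accepts any value up to and including top, so this call actually draws from [2, n − 1]. The extra value a = n − 1 always passes the test (d is odd, so (n − 1)^d mod n = n − 1), which means the result is still correct, but passing n.subtract(TWO) as the upper bound would match the pseudocode exactly.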
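To make the decomposition step concrete, here is a minimal sketch that splits n − 1 into 2^s·d in the same way as the while loop above. The class name DecomposeDemo and the method decompose are made up for this illustration, and n = 221 is just a sample input (220 = 2^2 · 55).

```java
import java.math.BigInteger;

class DecomposeDemo {
    // Returns {s, d} such that n - 1 = 2^s * d with d odd. Assumes n > 1.
    static BigInteger[] decompose(BigInteger n) {
        BigInteger d = n.subtract(BigInteger.ONE);
        int s = 0;
        // Same idea as the while loop in the question: strip factors of 2.
        while (!d.testBit(0)) {   // lowest bit clear means d is even
            d = d.shiftRight(1);  // d = d / 2
            s++;
        }
        return new BigInteger[] { BigInteger.valueOf(s), d };
    }

    public static void main(String[] args) {
        BigInteger[] sd = decompose(BigInteger.valueOf(221));
        System.out.println("s = " + sd[0] + ", d = " + sd[1]); // s = 2, d = 55
    }
}
```

Testing the lowest bit and shifting is equivalent to the mod/divide pair in the question; it is used here only because it avoids allocating the constant TWO.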

Related

Time Complexity Improvement for Project Euler Solution

This is my solution to the first question in Project Euler.
Could someone please help reduce the time complexity of this working code?
Problem:
If we list all the natural numbers below 10 that are multiples of 3 or 5, we get 3, 5, 6 and 9. The sum of these multiples is 23.
Find the sum of all the multiples of 3 or 5 below 1000.
public class Sum {
private static final int n = 1000;
public static void main(String[] args) {
for (int i = 1, sum = 0; i <= n; i++) {
if ((i % 3 == 0) || (i % 5 == 0)) {
System.out.println(sum += i);
}
}
}
}
I am using the formula provided by phatfingers.
You can have numbers that are divisible by 15 (3 * 5), so you kinda have to subtract that amount.
The formula n * (n + 1) / 2 sums the natural numbers 1..n. The multiples of k up to n are k*1, k*2, ..., k*(n / k), so their sum is k times the sum of 1..(n / k). That is why n is first divided by k (integer division rounds down) and the result is multiplied by k.
public class Sum {
private static final int n = 15;
public static void main(String[] args) {
int result = compute(3, n) + compute(5, n) - compute(15, n);
System.out.println(result);
}
private static int compute(int k, int n) {
n = n / k;
return k * n * (n + 1) / 2;
}
}
Note:
You can declare your variables inside of the for-loop
compute treats n as inclusive, matching the i <= n loop in the question. The problem statement says "below 1000", so if n should be exclusive, change n = n / k to n = (n - 1) / k.
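A minimal sketch of the formula approach with an exclusive upper bound, since the problem statement asks for multiples below the limit. The class and method names are made up for this illustration, and long is used only to leave headroom for larger limits.

```java
class SumOfMultiples {
    // Sum of all positive multiples of k that are strictly below limit.
    static long sumMultiplesBelow(long k, long limit) {
        long count = (limit - 1) / k;       // how many multiples of k are < limit
        return k * count * (count + 1) / 2; // k * (1 + 2 + ... + count)
    }

    // Inclusion-exclusion: multiples of 15 are counted twice, so subtract them once.
    static long sumOf3Or5Below(long limit) {
        return sumMultiplesBelow(3, limit)
             + sumMultiplesBelow(5, limit)
             - sumMultiplesBelow(15, limit);
    }

    public static void main(String[] args) {
        System.out.println(sumOf3Or5Below(10));   // 23, matching the problem statement
        System.out.println(sumOf3Or5Below(1000)); // 233168
    }
}
```

This runs in constant time regardless of the limit, whereas the loop in the question is linear in n.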

Turn recursive function into iterative version

I need to turn the following recursive code into an iterative version and my mind is fried. I feel like I'm missing something obvious. Any help is appreciated.
public static int computeRecursive(int n){
if(n <= 1){
return n;
}
else{
return 2 * computeRecursive(n-2) + computeRecursive(n-1);
}
}
This is similar to an iterative fibonacci series, where you hold the initial two values of your function f() in two variables a and b. Then compute the result for the current N off of those previous two results:
public static int f(int n) {
if ( n <= 1 ) {
return n;
}
int result = 0;
int a = 0, // f(0) = 0
b = 1; // f(1) = 1
// start iteration at n=2 because we already have f(0) and f(1)
for(int i = 2; i <= n; i++) {
// f(n) = 2 * f(n-2) + f(n-1)
result = 2 * a + b;
// f(n-2) = f(n-1)
a = b;
// f(n-1) = f(n)
b = result;
}
return result;
}
In my opinion both recursive and iterative solutions are weak if you can just apply your maths skills and work out the formula.
In this case we have: f(n) = (2 ** n - (-1) ** n) / 3. Below is how you work it out.
f(0) = 0
f(1) = 1
f(n) = f(n-1) + 2 * f(n-2)
So the polynomial for this recurrence is:
r ** 2 = r + 2
If you solve that, you get the roots r1 = −1 and r2 = 2.
So the solution to the recurrence is of the form:
f(n) = c1 ∗ r1 ** n + c2 ∗ r2 ** n
To work out the values for c1 and c2 constants just use the initial condition f(0) = 0 and f(1) = 1 and you will get
c1 = -1/3 and c2 = 1/3
So the final formula for your iteration is
f(n) = (-1 * (-1) ** n + 2 ** n)/3 = (2 ** n -(-1) ** n)/3.
Once you know the formula implementing it in java or any other language is easy.
public static int f(int n) {
// Math.pow returns a double, so round and cast back to int
return n <= 1 ? n : (int) Math.round((Math.pow(2, n) - Math.pow(-1, n)) / 3);
}
Hope it helped
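The closed form f(n) = (2^n − (−1)^n) / 3 derived above can also be evaluated with integer arithmetic only, which sidesteps floating-point rounding. A minimal sketch, with made-up class and method names, that checks the formula against the recursive definition from the question:

```java
class ClosedForm {
    // The recursive definition from the question, used here only as a reference.
    static long recursive(int n) {
        return n <= 1 ? n : 2 * recursive(n - 2) + recursive(n - 1);
    }

    // f(n) = (2^n - (-1)^n) / 3 in pure integer arithmetic; valid for 0 <= n <= 62.
    static long closedForm(int n) {
        long sign = (n % 2 == 0) ? 1 : -1; // (-1)^n
        return ((1L << n) - sign) / 3;     // the numerator is always divisible by 3
    }

    public static void main(String[] args) {
        for (int n = 0; n <= 20; n++) {
            if (recursive(n) != closedForm(n)) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println(closedForm(10)); // 341
    }
}
```

The upper bound of 62 comes from 1L << n having to stay a positive long; it is a limit of this sketch, not of the formula.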
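You could try the code below. It is similar to the Fibonacci series.
public static int computeIterative(int n){
if(n<=1) return n; // base cases; also avoids indexing past a too-small array
int a[]=new int[n];
a[0]=1; a[1]=1; // a[i] holds f(i+1), so f(1) = 1 and f(2) = 1
for(int i=2;i<n;i++){
a[i]=2*a[i-2]+a[i-1];
}
return a[n-1];
}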

nth Binary palindrome with efficient time complexity

Given an integer N, I am trying to find the Nth binary palindrome. I have written the following code, but it is not efficient. Is there a more efficient way in terms of time complexity?
I was trying it out as an online problem and was supposed to produce output in 1 second or less, but for every input it takes 2 seconds.
public static Boolean Palind(String n){
String reverse = "";
int length = n.length();
for(int i = length - 1; i >=0;i--){
reverse = reverse + n.charAt(i);
}
if(n.equals(reverse)){
return true;
}
else{
return false;
}
}
public static int Magical(int n){
ArrayList<Integer> res = new ArrayList<Integer>();
for(int i = 1; i < Math.pow(2, n);i++){
if(Palind(Integer.toBinaryString(i))){
res.add(i);
}
}
return res.get(n-1);
}
The relevant OEIS entry (A006995) has a lot of nice tips if you read through it. For example, a(2^n-1)=2^(2n-2)-1 lets you skip right to the (2^n - 1)th palindrome really quickly.
It also gives several implementations. For example, the Smalltalk implementation works like this (note that the input value, n, starts with 1 for the first palindrome, 0):
public static final int nthBooleanPalindrome(int n) {
if (n == 1) return 0;
if (n == 2) return 1;
int m = 31 - Integer.numberOfLeadingZeros(n);
int c = 1 << (m - 1);
int b;
if (n >= 3*c) {
int a = n - 3*c;
int d = 2*c*c;
b = d + 1;
int k2 = 1;
for (int i = 1; i < m; i++) {
k2 <<= 1;
b += a*k2/c%2*(k2 + d/k2);
}
}
else {
int a = n - 2*c;
int d = c*c;
b = d + 1 + (n%2*c);
int k2 = 1;
for (int i = 1; i < m - 1; i++) {
k2 <<= 1;
b += a*k2/c%2*(k2 + d/k2);
}
}
return b;
}
Try something like this maybe?
public static void main(String[] args) {
for (int i = 1; i < 65535; i++) {
System.out.println(
i + ": " + getBinaryPalindrom(i) + " = " + Integer.toBinaryString(getBinaryPalindrom(i)));
}
}
public static int getBinaryPalindrom(int N) {
if (N < 4) {
switch (N) {
case 1:
return 0;
case 2:
return 1;
case 3:
return 3;
}
throw new IndexOutOfBoundsException("You need to supply N >= 1");
}
// second highest to keep the right length (highest is always 1)
final int bitAfterHighest = (N >>> (Integer.SIZE - Integer.numberOfLeadingZeros(N) - 2)) & 1;
// now remove the second highest bit to get the left half of our palindrom
final int leftHalf = (((N >>> (Integer.SIZE - Integer.numberOfLeadingZeros(N) - 1)) & 1) << (Integer.SIZE -
Integer.numberOfLeadingZeros(N) - 2)) | ((N << (Integer.numberOfLeadingZeros(N) + 2)) >>> (Integer.numberOfLeadingZeros(N) + 2));
// right half is just the left reversed
final int rightHalf = Integer.reverse(leftHalf);
if (Integer.numberOfLeadingZeros(leftHalf) < Integer.SIZE / 2) {
throw new IndexOutOfBoundsException("Too big to fit N=" + N + " into an int");
}
if (bitAfterHighest == 0) {
// First uneven-length palindromes
return (leftHalf << (Integer.SIZE - Integer.numberOfLeadingZeros(leftHalf)) - 1) | (rightHalf
>>> Integer.numberOfTrailingZeros(rightHalf));
} else {
// Then even-length palindromes
return (leftHalf << (Integer.SIZE - Integer.numberOfLeadingZeros(leftHalf))) | (rightHalf
>>> Integer.numberOfTrailingZeros(rightHalf));
}
}
The idea is that each number becomes a palindrome once its reverse is appended. To have the halves correctly aligned, they just need to be shifted into place.
The reason this gets a bit complex is that all odd-length palindromes for a given leftHalf length come before all even-length palindromes for that leftHalf length. Feel free to provide a better solution.
Since an int has 32 bits in Java, there is a limit on N.
int-Version on ideone.com
And a BigInteger-version to support big values. It is not as fast as the int-version as the byte[]-arrays which store the value of the BigInteger create some overhead.
public static void main(String[] args) {
for (BigInteger i = BigInteger.valueOf(12345678); i.compareTo(BigInteger.valueOf(12345778)) < 0; i = i
.add(BigInteger
.ONE)) {
final BigInteger curr = getBinaryPalindrom(i);
System.out.println(i + ": " + curr + " = " + curr.toString(2));
}
}
public static BigInteger getBinaryPalindrom(BigInteger n) {
if (n.compareTo(BigInteger.ZERO) <= 0) {
throw new IndexOutOfBoundsException("You need to supply N >= 1");
} else if (n.equals(BigInteger.valueOf(1))) {
return BigInteger.valueOf(0);
} else if (n.equals(BigInteger.valueOf(2))) {
return BigInteger.valueOf(1);
} else if (n.equals(BigInteger.valueOf(3))) {
return BigInteger.valueOf(3);
}
final int bitLength = n.bitLength() - 1;
// second highest to keep the right length (highest is always 1)
final boolean bitAfterHighest = n.testBit(bitLength - 1);
// now remove the second highest bit to get the left half of our palindrom
final BigInteger leftHalf = n.clearBit(bitLength).setBit(bitLength - 1);
// right half is just the left reversed
final BigInteger rightHalf;
{
byte[] inArray = leftHalf.toByteArray();
byte[] outArray = new byte[inArray.length];
final int shiftOffset = Integer.SIZE - Byte.SIZE;
for (int i = 0; i < inArray.length; i++) {
outArray[inArray.length - 1 - i] = (byte) (Integer.reverse(inArray[i]) >>> shiftOffset);
}
rightHalf = new BigInteger(1, outArray).shiftRight(outArray.length * Byte.SIZE - bitLength);
}
if (!bitAfterHighest) {
// First uneven-length palindromes
return leftHalf.shiftLeft(bitLength - 1).or(rightHalf);
} else {
// Then even-length palindromes
return leftHalf.shiftLeft(bitLength).or(rightHalf);
}
}
I have the same idea as @Kiran Kumar: you should not test numbers one by one to see whether each is a binary palindrome, which is too slow, but rather find the pattern the numbers follow.
Listing the palindromes as binary strings one by one reveals the pattern:
0
1
11
101
1001
1111
...
1......1
The rest is a counting argument:
There are 2^ceil((L-2)/2) binary palindromes of length L (for L >= 2).
Summing the counts of all shorter lengths gives the following mapping from length to cumulative count:
for (int i = 1; i < mapping.length; i++) {
mapping[i] = (long) (mapping[i - 1] + Math.pow(2, Math.ceil((i - 1) * 1.0 / 2)));
}
Once we find the length L for which N falls in [count(L), count(L+1)), we build the palindrome from the remaining offset:
public static long magical(long n) {
if (n == 0 || n == 1) {
return n;
}
long N = n - 2;
return Long.parseLong(concat(N), 2);
}
private static String concat(long N) {
int midLen = Arrays.binarySearch(indexRange, N);
if (midLen < 0) {
midLen = -midLen - 1;
}
long remaining = N - indexRange[midLen];
String mid = mirror(remaining, midLen);
return '1' + mid + '1';
}
private static String mirror(long n, int midLen) {
int halfLen = (int) Math.ceil(midLen * 1.0 / 2);
// produce fixed length binary string
final String half = Long.toBinaryString(n | (1L << halfLen)).substring(1);
if (midLen % 2 == 0) {
return half + new StringBuilder(half).reverse().toString();
} else {
return half + new StringBuilder(half).reverse().toString().substring(1);
}
}
The full code, with a test that produces the largest possible long, can be found in my git repo.
An idea for optimizing:
Let's look at the palindrome sequence 0, 1, 11, 101, 111, 1001, etc.
Every palindrome after 0 must begin and end with 1, so only the middle bits change, and the middle substring must itself be a palindrome for the full string to be a palindrome.
So take a 2-digit binary number: one palindrome is possible.
The binary of the decimal 3 is a palindrome: 11
For a 3-digit binary number, 2 palindromes are possible, 2*(number of 1-digit palindromes):
The binary of the decimal 5 is a palindrome: 101
The binary of the decimal 7 is a palindrome: 111
For a 5-digit binary number, 4 palindromes are possible, 2*(number of 3-digit palindromes):
10001, 10101, 11011, 11111
and so on.
Counting by length (1, 2, 3, 4, 5, 6, ... bits), the totals add up as 2 + 2^0 + 2^1 + 2^1 + 2^2 + 2^2 + ...,
and we solve for i to locate the palindrome.
Analysing this sequence gives, for even i, the equation 2^((i/2)+1) − 1 = N,
where N is the number of palindromes with at most i bits,
and i is the number of bits in the nth palindrome string.
Using this we can find the length of the string, and from that we can construct the string directly.
This might be complex, but it helps in solving higher values of N quickly.
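The counting idea in the answers above (group palindromes by bit length, with 2^ceil((L−2)/2) palindromes of each length L ≥ 2) can be turned into a direct computation. This is a minimal sketch, assuming 1-based indexing with 0 as the first palindrome, as in OEIS A006995. The class and method names are made up for this illustration.

```java
class BinaryPalindromes {
    // Returns the n-th binary palindrome (n = 1 gives 0, n = 2 gives 1, n = 3 gives 3, ...).
    static long nth(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be >= 1");
        }
        if (n <= 2) {
            return n - 1; // the one-bit palindromes 0 and 1
        }
        long index = n - 3; // zero-based position among palindromes of length >= 2
        int len = 2;
        while (true) {
            if (len > 63) {
                throw new ArithmeticException("result does not fit in a long");
            }
            long count = 1L << ((len - 1) / 2); // equals 2^ceil((len - 2) / 2)
            if (index < count) {
                break;
            }
            index -= count;
            len++;
        }
        // Left half: a leading 1 followed by index written in the remaining bits.
        int halfLen = (len + 1) / 2;
        long left = (1L << (halfLen - 1)) | index;
        // For odd lengths the middle bit is shared, so it is not mirrored.
        long toMirror = (len % 2 == 0) ? left : (left >> 1);
        long result = left;
        for (int i = 0; i < len / 2; i++) {
            result = (result << 1) | (toMirror & 1); // append bits in reverse order
            toMirror >>= 1;
        }
        return result;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 12; n++) {
            long p = nth(n);
            System.out.println(n + ": " + p + " = " + Long.toBinaryString(p));
        }
    }
}
```

The loop over lengths runs at most 62 times, so the cost is proportional to the bit length of the answer rather than to its value, unlike the brute-force scan in the question.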

How to find the floor and remainder of a fraction in Java?

Given integers 'a' and 'b', I would like a method that returns the floor and remainder of a / b such that:
a / b = floor + remainder / |b| (where |b| is the absolute value of b), and
0 <= remainder < |b|.
For example, 5/3 = 1 + 2/3.
Here's an attempt that works only for positive a and b:
public static long[] floorAndRemainder(long a, long b) {
long floor = a / b;
long remainder = a % b;
return new long[] { floor, remainder };
}
I need a function that works for all positive and negative numerators and denominators. For example,
-5/3 = -2 + 1/3
5/-3 = -2 + 1/3
-5/-3 = 1 + 2/3
Implementation 1: Floating Point
Uses floating point math to simplify the logic. Be warned that this will produce incorrect results for large numbers due to loss of precision.
public static long[] floorAndRemainder(long a, long b) {
long floor = (long) Math.floor(a / (double) b);
long remainder = Math.abs(a - floor * b);
return new long[] { floor, remainder };
}
Implementation 2: Find, then Correct
Finds the floor and remainder using integer division and modulus operators, then corrects for negative fractions. This shows that the remainder is relatively difficult to correct without using the floor.
public static long[] floorAndRemainder(long a, long b) {
long floor = a / b;
long remainder = a % b;
boolean isNegative = a < 0 ^ b < 0;
boolean hasRemainder = remainder != 0;
// Correct the floor.
if (isNegative && hasRemainder) {
floor--;
}
// Correct the remainder.
if (hasRemainder) {
if (isNegative) {
if (a < 0) { // then remainder < 0 and b > 0
remainder += b;
} else { // then remainder > 0 and b < 0
remainder = -remainder - b;
}
} else {
if (remainder < 0) {
remainder = -remainder;
}
}
}
return new long[] { floor, remainder };
}
Implementation 3: The Best Option
Finds the floor the same way as Implementation 2, then uses the floor to find the remainder like Implementation 1.
public static long[] floorAndRemainder(long a, long b) {
long floor = a / b;
if ((a < 0 ^ b < 0) && a % b != 0) {
floor--;
}
long remainder = Math.abs(a - floor * b);
return new long[] { floor, remainder };
}
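Since Java 8 the standard library also offers Math.floorDiv and Math.floorMod, which handle the rounding toward negative infinity directly. A minimal sketch under the same contract as the question (0 <= remainder < |b|); the class name FloorDivDemo is made up for this illustration:

```java
class FloorDivDemo {
    // Returns {floor, remainder} with a / b = floor + remainder / |b| and 0 <= remainder < |b|.
    static long[] floorAndRemainder(long a, long b) {
        long floor = Math.floorDiv(a, b);               // rounds toward negative infinity
        long remainder = Math.abs(Math.floorMod(a, b)); // floorMod takes the sign of b
        return new long[] { floor, remainder };
    }

    public static void main(String[] args) {
        long[][] cases = { { 5, 3 }, { -5, 3 }, { 5, -3 }, { -5, -3 } };
        for (long[] c : cases) {
            long[] r = floorAndRemainder(c[0], c[1]);
            System.out.println(c[0] + "/" + c[1] + " = " + r[0] + " + " + r[1] + "/" + Math.abs(c[1]));
        }
    }
}
```

Math.floorMod alone is not enough here because its result is negative when b is negative, hence the Math.abs.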

Class that checks for primes not working properly

I made a class called Primes for everything related to prime numbers. It contains a method called isPrime, which uses another method called sieveOfAtkin that creates a boolean array called sieve that has index values from 0 to 1000000. The user passes an integer n to the isPrime method. If sieve[n]=true, then the integer n is a prime number. Otherwise isPrime returns false. My problem is that when I tested this method using numbers that I know are prime, it always returns false. Take this line of code for example that tests whether 13 is a prime number:
public class Test {
public static void main(String[] args) {
Primes pr=new Primes(); // Creates Primes object
System.out.println(pr.isPrime(13));
}
}
The output is false, even though we know that 13 is a prime number. Here is my code for the entire Primes class https://github.com/javtastic/project_euler/blob/master/Primes.java
It uses a sieve of atkin, which is supposed to be the most efficient method of testing for primes. More info on that can be found here: http://en.wikipedia.org/wiki/Sieve_of_Atkin
I'm not entirely sure what I am doing wrong. I have been trying for hours to figure out what is causing this error and I still get the same results (everything is false). Perhaps I should find a different way of checking primes?
Use this:
public static boolean isPrime(int number) {
if (number < 2) { // 0, 1 and negative numbers are not prime
return false;
}
int sqrt = (int) Math.sqrt(number) + 1;
for (int i = 2; i < sqrt; i++) {
if (number % i == 0) {
return false;
}
}
return true;
}
The code works like a charm for me, no changes done. Code:
import java.util.Arrays;
public class Test {
// This class uses a Atkin sieve to determine all prime numbers less than int limit
// See http://en.wikipedia.org/wiki/Sieve_of_Atkin
private static final int limit = 1000000;
private final static boolean[] sieve = new boolean[limit + 1];
private static int limitSqrt = (int) Math.sqrt(limit);
public static void sieveOfAtkin() {
// there may be more efficient data structure
// arrangements than this (there are!) but
// this is the algorithm in Wikipedia
// initialize results array
Arrays.fill(sieve, false);
// the sieve works only for integers > 3, so
// set these trivially to their proper values
sieve[0] = false;
sieve[1] = false;
sieve[2] = true;
sieve[3] = true;
// loop through all possible integer values for x and y
// up to the square root of the max prime for the sieve
// we don't need any larger values for x or y since the
// max value for x or y will be the square root of n
// in the quadratics
// the theorem showed that the quadratics will produce all
// primes that also satisfy their wheel factorizations, so
// we can produce the value of n from the quadratic first
// and then filter n through the wheel quadratic
// there may be more efficient ways to do this, but this
// is the design in the Wikipedia article
// loop through all integers for x and y for calculating
// the quadratics
for (int x = 1 ; x <= limitSqrt ; x++) {
for (int y = 1 ; y <= limitSqrt ; y++) {
// first quadratic using m = 12 and r in R1 = {r : 1, 5}
int n = (4 * x * x) + (y * y);
if (n <= limit && (n % 12 == 1 || n % 12 == 5)) {
sieve[n] = !sieve[n];
}
// second quadratic using m = 12 and r in R2 = {r : 7}
n = (3 * x * x) + (y * y);
if (n <= limit && (n % 12 == 7)) {
sieve[n] = !sieve[n];
}
// third quadratic using m = 12 and r in R3 = {r : 11}
n = (3 * x * x) - (y * y);
if (x > y && n <= limit && (n % 12 == 11)) {
sieve[n] = !sieve[n];
}
// note that R1 union R2 union R3 is the set R
// R = {r : 1, 5, 7, 11}
// which is all values 0 < r < 12 where r is
// a relative prime of 12
// Thus all primes become candidates
}
}
// remove all perfect squares since the quadratic
// wheel factorization filter removes only some of them
for (int n = 5 ; n <= limitSqrt ; n++) {
if (sieve[n]) {
int x = n * n;
for (int i = x ; i <= limit ; i += x) {
sieve[i] = false;
}
}
}
}
// isPrime method checks to see if a number is prime using sieveOfAtkin above
// Works since sieve[0] represents the integer 0, sieve[1]=1, etc
public static boolean isPrime(int n) {
sieveOfAtkin();
return sieve[n];
}
public static void main(String[] args) {
System.out.println(isPrime(13));
}
}
Input - 13, Output - True
Input - 23, Output - True
Input - 33, Output - False
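The version above rebuilds the whole sieve on every isPrime call. A minimal sketch of building the table once and reusing it follows. Note that it swaps in a plain Sieve of Eratosthenes rather than the Sieve of Atkin, purely to keep the example short, and that the class name PrimeTable is made up for this illustration.

```java
class PrimeTable {
    private static final int LIMIT = 1_000_000;
    // Built exactly once, when the class is initialized.
    private static final boolean[] COMPOSITE = buildSieve(LIMIT);

    // Sieve of Eratosthenes: composite[i] ends up true when i has a proper divisor.
    private static boolean[] buildSieve(int limit) {
        boolean[] composite = new boolean[limit + 1];
        for (int i = 2; (long) i * i <= limit; i++) {
            if (!composite[i]) {
                for (int j = i * i; j <= limit; j += i) {
                    composite[j] = true;
                }
            }
        }
        return composite;
    }

    static boolean isPrime(int n) {
        if (n > LIMIT) {
            throw new IllegalArgumentException("n exceeds sieve limit: " + n);
        }
        return n >= 2 && !COMPOSITE[n]; // 0, 1 and negatives are not prime
    }

    public static void main(String[] args) {
        System.out.println(isPrime(13)); // true
        System.out.println(isPrime(23)); // true
        System.out.println(isPrime(33)); // false
    }
}
```

Each lookup after the first is a single array read, which matters if isPrime is called many times.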
