In my computer architecture class I just learned that running an algebraic expression involving multiplication through a multiplication circuit can be more costly than running it through an addition circuit, if the number of required multiplications is less than 3. Example: 3x. If I'm doing this type of computation a few billion times, does it pay off to write it as x + x + x, or does the JIT optimizer handle this?
I wouldn't expect a huge difference between writing it one way or the other.
The compiler will probably take care of making all of those equivalent.
You can try each method and measure how long it takes; that should give you a good hint to answer your own question.
Here's some code that does the same calculations 10 million times using different approaches (x + x + x, 3*x, and a bit shift followed by a subtraction).
They seem to all take approx the same amount of time as measured by System.nanoTime.
Sample output for one run:
sum : 594599531
mult : 568783654
shift : 564081012
You can also take a look at this question, which discusses how compiler optimizations can probably handle these and more complex cases: Is shifting bits faster than multiplying and dividing in Java? .NET?
Code:
import java.util.Random;
public class TestOptimization {
public static void main(String args[]) {
Random rn = new Random();
long l1 = 0, l2 = 0, l3 = 0;
long nano1 = System.nanoTime();
for (int i = 1; i < 10000000; i++) {
int num = rn.nextInt(100);
l1 += sum(num);
}
long nano2 = System.nanoTime();
for (int i = 1; i < 10000000; i++) {
int num = rn.nextInt(100);
l2 += mult(num);
}
long nano3 = System.nanoTime();
for (int i = 1; i < 10000000; i++) {
int num = rn.nextInt(100);
l3 += shift(num);
}
long nano4 = System.nanoTime();
System.out.println(l1);
System.out.println(l2);
System.out.println(l3);
System.out.println("sum : " + (nano2 - nano1));
System.out.println("mult : " + (nano3 - nano2));
System.out.println("shift : " + (nano4 - nano3));
}
private static long sum(long x) {
return x + x + x;
}
private static long mult(long x) {
return 3 * x;
}
private static long shift(long x) {
return (x << 2) - x;
}
}
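For what it's worth, a variant of the test above that feeds all three forms the same inputs and repeats the measurement a few times (so the early passes act as warm-up) might look like the sketch below. The class name TripleForms, the fixed seed, and the pass count are my own choices for illustration, and passing the forms as LongUnaryOperator adds some call overhead of its own, so treat the numbers as rough. For serious measurements a harness such as JMH is the usual recommendation.

```java
import java.util.Random;
import java.util.function.LongUnaryOperator;

class TripleForms {
    static long sum(long x)   { return x + x + x; }
    static long mult(long x)  { return 3 * x; }
    static long shift(long x) { return (x << 2) - x; }

    // Applies one form to a fixed pseudo-random sequence and returns a checksum,
    // so the work cannot be discarded and every form sees identical inputs.
    static long run(LongUnaryOperator form, int iterations) {
        Random rn = new Random(42); // fixed seed (illustrative choice)
        long acc = 0;
        for (int i = 0; i < iterations; i++) {
            acc += form.applyAsLong(rn.nextInt(100));
        }
        return acc;
    }

    public static void main(String[] args) {
        LongUnaryOperator[] forms = {TripleForms::sum, TripleForms::mult, TripleForms::shift};
        String[] names = {"sum", "mult", "shift"};
        for (int pass = 0; pass < 5; pass++) { // the first passes double as warm-up
            for (int f = 0; f < forms.length; f++) {
                long start = System.nanoTime();
                long checksum = run(forms[f], 10_000_000);
                long elapsed = System.nanoTime() - start;
                System.out.println("pass " + pass + " " + names[f] + " : " + elapsed
                        + " ns (checksum " + checksum + ")");
            }
        }
    }
}
```

The checksums should be identical for all three forms; if they are not, the forms are not computing the same thing.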
So I have a task to calculate Euler's number using multiple threads, using this formula: sum( ((3k)^2 + 1) / ((3k)!) ), for k = 0...infinity.
import java.math.BigDecimal;
import java.math.BigInteger;
import java.io.FileWriter;
import java.io.IOException;
import java.math.RoundingMode;
class ECalculator {
private BigDecimal sum;
private BigDecimal[] series;
private int length;
public ECalculator(int threadCount) {
this.length = threadCount;
this.sum = new BigDecimal(0);
this.series = new BigDecimal[threadCount];
for (int i = 0; i < this.length; i++) {
this.series[i] = BigDecimal.ZERO;
}
}
public synchronized void addToSum(BigDecimal element) {
this.sum = this.sum.add(element);
}
public void addToSeries(int id, BigDecimal element) {
if (id - 1 < length) {
this.series[id - 1] = this.series[id - 1].add(element);
}
}
public synchronized BigDecimal getSum() {
return this.sum;
}
public BigDecimal getSeriesSum() {
BigDecimal result = BigDecimal.ZERO;
for (int i = 0; i < this.length; i++) {
result = result.add(this.series[i]);
}
return result;
}
}
class ERunnable implements Runnable {
private final int id;
private final int threadCount;
private final int threadRemainder;
private final int elements;
private final boolean quietFlag;
private ECalculator eCalc;
public ERunnable(int threadCount, int threadRemainder, int id, int elements, boolean quietFlag, ECalculator eCalc) {
this.id = id;
this.threadCount = threadCount;
this.threadRemainder = threadRemainder;
this.elements = elements;
this.quietFlag = quietFlag;
this.eCalc = eCalc;
}
@Override
public void run() {
if (!quietFlag) {
System.out.println(String.format("Thread-%d started.", this.id));
}
long start = System.currentTimeMillis();
int k = this.threadRemainder;
int iteration = 0;
BigInteger currentFactorial = BigInteger.valueOf(intFactorial(3 * k));
while (iteration < this.elements) {
if (iteration != 0) {
for (int i = 3 * (k - threadCount) + 1; i <= 3 * k; i++) {
currentFactorial = currentFactorial.multiply(BigInteger.valueOf(i));
}
}
this.eCalc.addToSeries(this.id, new BigDecimal(Math.pow(3 * k, 2) + 1).divide(new BigDecimal(currentFactorial), 100, RoundingMode.HALF_UP));
iteration += 1;
k += this.threadCount;
}
long stop = System.currentTimeMillis();
if (!quietFlag) {
System.out.println(String.format("Thread-%d stopped.", this.id));
System.out.println(String.format("Thread %d execution time: %d milliseconds", this.id, stop - start));
}
}
public int intFactorial(int n) {
int result = 1;
for (int i = 1; i <= n; i++) {
result *= i;
}
return result;
}
}
public class TaskRunner {
public static final String DEFAULT_FILE_NAME = "result.txt";
public static void main(String[] args) throws InterruptedException {
int threadCount = 2;
int precision = 10000;
int elementsPerTask = precision / threadCount;
int remainingElements = precision % threadCount;
boolean quietFlag = false;
calculate(threadCount, elementsPerTask, remainingElements, quietFlag, DEFAULT_FILE_NAME);
}
public static void writeResult(String filename, String result) {
try {
FileWriter writer = new FileWriter(filename);
writer.write(result);
writer.close();
} catch (IOException e) {
System.out.println("An error occurred.");
e.printStackTrace();
}
}
public static void calculate(int threadCount, int elementsPerTask, int remainingElements, boolean quietFlag, String outputFile) throws InterruptedException {
long start = System.currentTimeMillis();
Thread[] threads = new Thread[threadCount];
ECalculator eCalc = new ECalculator(threadCount);
for (int i = 0; i < threadCount; i++) {
if (i == 0) {
threads[i] = new Thread(new ERunnable(threadCount, i, i + 1, elementsPerTask + remainingElements, quietFlag, eCalc));
} else {
threads[i] = new Thread(new ERunnable(threadCount, i, i + 1, elementsPerTask, quietFlag, eCalc));
}
threads[i].start();
}
for (int i = 0; i < threadCount; i++) {
threads[i].join();
}
String result = eCalc.getSeriesSum().toString();
if (!quietFlag) {
System.out.println("E = " + result);
}
writeResult(outputFile, result);
long stop = System.currentTimeMillis();
System.out.println("Calculated in: " + (stop - start) + " milliseconds" );
}
}
I stripped out the prints, etc. in the code that have no effect. My problem is that the more threads I use the slower it gets. Currently the fastest run I have is for 1 thread. I am sure the factorial calculation is causing some issues. I tried using a thread pool but still got the same times.
How can I make it so that running it with more threads, up until some point, will speed up the calculation process?
How would one go about calculating these big factorials?
The precision parameter that is passed is the amount of elements in the sum that are used. Can I set the BigDecimal scale to be somehow dependent on that precision so I don't hard code it?
EDIT
I updated the code block to be in 1 file only and runnable without external libs.
EDIT 2
I found out that the factorial code messes up with the time. If I let the threads ramp up to some high precision without calculating factorials the time goes down with increasing threads. Yet I cannot implement the factorial calculating in any way while keeping the time decreasing.
EDIT 3
Adjusting code to address answers.
private static BigDecimal partialCalculator(int start, int threadCount, int id) {
BigDecimal nBD = BigDecimal.valueOf(start);
BigDecimal result = nBD.multiply(nBD).multiply(BigDecimal.valueOf(9)).add(BigDecimal.valueOf(1));
for (int i = start; i > 0; i -= threadCount) {
BigDecimal iBD = BigDecimal.valueOf(i);
BigDecimal iBD1 = BigDecimal.valueOf(i - 1);
BigDecimal iBD3 = BigDecimal.valueOf(3).multiply(iBD);
BigDecimal prevNumerator = iBD1.multiply(iBD1).multiply(BigDecimal.valueOf(9)).add(BigDecimal.valueOf(1));
// 3 * i * (3 * i - 1) * (3 * i - 2);
BigDecimal divisor = iBD3.multiply(iBD3.subtract(BigDecimal.valueOf(1))).multiply(iBD3.subtract(BigDecimal.valueOf(2)));
result = result.divide(divisor, 10000, RoundingMode.HALF_EVEN)
.add(prevNumerator);
}
return result;
}
public static void main(String[] args) {
int threadCount = 3;
int precision = 6;
ExecutorService executorService = Executors.newFixedThreadPool(threadCount);
ArrayList<Future<BigDecimal> > futures = new ArrayList<Future<BigDecimal> >();
for (int i = 0; i < threadCount; i++) {
int start = precision - i;
System.out.println(start);
final int id = i + 1;
futures.add(executorService.submit(() -> partialCalculator(start, threadCount, id)));
}
BigDecimal result = BigDecimal.ZERO;
try {
for (int i = 0; i < threadCount; i++) {
result = result.add(futures.get(i).get());
}
} catch (Exception e) {
e.printStackTrace();
}
executorService.shutdown();
System.out.println(result);
}
Seems to be working properly for 1 thread but messes up the calculation for multiple.
After a review of the updated code, I've made the following observations:
First of all, the program runs for a fraction of a second. That means that this is a micro-benchmark. Several key features in Java make micro-benchmarks difficult to implement reliably. See How do I write a correct micro-benchmark in Java? For example, if the program doesn't run enough repetitions, the "just in time" compiler doesn't have time to kick in to compile it to native code, and you end up benchmarking the interpreter. It seems possible that in your case the JIT compiler takes longer to kick in when there are multiple threads.
As an example, to make your program do more work, I changed the BigDecimal precision from 100 to 10,000 and added a loop around the main method. The execution times were measured as follows:
1 thread:
Calculated in: 2803 milliseconds
Calculated in: 1116 milliseconds
Calculated in: 1040 milliseconds
Calculated in: 1066 milliseconds
Calculated in: 1036 milliseconds
2 threads:
Calculated in: 2354 milliseconds
Calculated in: 856 milliseconds
Calculated in: 624 milliseconds
Calculated in: 659 milliseconds
Calculated in: 664 milliseconds
4 threads:
Calculated in: 1961 milliseconds
Calculated in: 797 milliseconds
Calculated in: 623 milliseconds
Calculated in: 536 milliseconds
Calculated in: 497 milliseconds
The second observation is that there is a significant part of the workload that does not benefit from multiple threads: every thread is computing every factorial. This means the speed-up cannot be linear - as described by Amdahl's law.
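To put a rough number on that limit, Amdahl's law says that if a fraction p of the work parallelises perfectly over n threads, the speed-up is at most 1 / ((1 - p) + p / n). A small sketch follows; the class name and the value p = 0.5 are only illustrative assumptions, not something measured from your program.

```java
class Amdahl {
    // Upper bound on the speed-up when parallelFraction of the work
    // is spread perfectly over the given number of threads.
    static double speedup(double parallelFraction, int threads) {
        return 1.0 / ((1.0 - parallelFraction) + parallelFraction / threads);
    }

    public static void main(String[] args) {
        for (int n : new int[] {1, 2, 4, 8}) {
            System.out.printf("p=0.50, n=%d -> at most %.3fx%n", n, speedup(0.5, n));
        }
    }
}
```

With half of the work serial, the bound never reaches 2x no matter how many threads are added.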
So how can we get the result without computing factorials? One way is with Horner's method. As an example, consider the simpler series sum(1/k!) which also converges to e but a little slower than yours.
Let's say you want to compute sum(1/k!) up to k = 100. With Horner's method you start from the end and extract common factors:
sum(1/k!, k=0..100) = 1/100! + 1/99! + 1/98! + ... + 1/1! + 1/0!
= (...(((1/100 + 1)/99 + 1)/98 + ...)/2 + 1)/1 + 1
See how you start with 1, divide by 100 and add 1, divide by 99 and add 1, divide by 98 and add 1, and so on? That makes a very simple program:
private static BigDecimal serialHornerMethod() {
BigDecimal accumulator = BigDecimal.ONE;
for (int k = 10000; k > 0; k--) {
BigDecimal divisor = new BigDecimal(k);
accumulator = accumulator.divide(divisor, 10000, RoundingMode.HALF_EVEN)
.add(BigDecimal.ONE);
}
return accumulator;
}
OK, that's a serial method; how do you make it parallel? Here's an example for two threads. First split the series into even and odd terms:
1/100! + 1/99! + 1/98! + 1/97! + ... + 1/1! + 1/0! =
(1/100! + 1/98! + ... + 1/0!) + (1/99! + 1/97! + ... + 1/1!)
Then apply Horner's method to both the even and odd terms:
1/100! + 1/98! + 1/96! + ... + 1/2! + 1/0! =
(((1/(100*99) + 1)/(98*97) + 1)/(96*95) + ...)/(2*1) + 1
and:
1/99! + 1/97! + 1/95! + ... + 1/3! + 1/1! =
(((1/(99*98) + 1)/(97*96) + 1)/(95*94) + ...)/(3*2) + 1
This is just as easy to implement as the serial method, and you get pretty close to linear speedup going from 1 to 2 threads:
private static BigDecimal partialHornerMethod(int start) {
BigDecimal accumulator = BigDecimal.ONE;
for (int i = start; i > 0; i -= 2) {
int f = i * (i + 1);
BigDecimal divisor = new BigDecimal(f);
accumulator = accumulator.divide(divisor, 10000, RoundingMode.HALF_EVEN)
.add(BigDecimal.ONE);
}
return accumulator;
}
// Usage:
ExecutorService executorService = Executors.newFixedThreadPool(2);
Future<BigDecimal> submit = executorService.submit(() -> partialHornerMethod(10000));
Future<BigDecimal> submit1 = executorService.submit(() -> partialHornerMethod(9999));
BigDecimal result = submit1.get().add(submit.get());
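For anyone who wants to run the two-thread split end to end, here is a self-contained sketch that puts the serial and split versions side by side. The class name SplitHorner, the scale of 50 digits, and the 200 terms in main are assumptions chosen to keep it quick. The partial method uses the same divisors i * (i + 1) as the snippet above, so the two halves together reach one term further than the serial sum, which is negligible at this scale.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class SplitHorner {
    static final int SCALE = 50; // digits after the decimal point (illustrative)

    // Serial Horner evaluation of sum(1/k!, k = 0..n).
    static BigDecimal serial(int n) {
        BigDecimal acc = BigDecimal.ONE;
        for (int k = n; k > 0; k--) {
            acc = acc.divide(BigDecimal.valueOf(k), SCALE, RoundingMode.HALF_EVEN)
                     .add(BigDecimal.ONE);
        }
        return acc;
    }

    // Horner evaluation of every second term, with divisors i * (i + 1).
    static BigDecimal partial(int start) {
        BigDecimal acc = BigDecimal.ONE;
        for (int i = start; i > 0; i -= 2) {
            BigDecimal divisor = BigDecimal.valueOf((long) i * (i + 1));
            acc = acc.divide(divisor, SCALE, RoundingMode.HALF_EVEN).add(BigDecimal.ONE);
        }
        return acc;
    }

    // Runs the two halves on a two-thread pool and adds the results.
    static BigDecimal parallel(int n) {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            Future<BigDecimal> first = pool.submit(() -> partial(n));
            Future<BigDecimal> second = pool.submit(() -> partial(n - 1));
            return first.get().add(second.get());
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("serial:   " + serial(200));
        System.out.println("parallel: " + parallel(200));
    }
}
```

The two printed values should agree except possibly in the last digit or two, where rounding differs.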
There is a lot of contention between the threads: they all compete to get a lock on the ECalculator object after every little bit of computation, because of this method:
public synchronized void addToSum(BigDecimal element) {
this.sum = this.sum.add(element);
}
In general having threads compete for frequent access to a common resource leads to bad performance, because you're asking for the operating system to intervene and tell the program which thread can continue. I haven't tested your code to confirm that this is the issue because it's not self-contained.
To fix this, have the threads accumulate their results separately, and merge results after the threads have finished. That is, create a sum variable in ERunnable, and then change the methods:
// ERunnable.run:
this.sum = this.sum.add(new BigDecimal(Math.pow(3 * k, 2) + 1).divide(new BigDecimal(factorial(3 * k)), 100, RoundingMode.HALF_UP));
// TaskRunner.calculate:
for (int i = 0; i < threadCount; i++) {
threads[i].join();
eCalc.addToSum(/* recover the sum computed by thread */);
}
By the way, it would be easier if you used the higher-level java.util.concurrent API instead of creating thread objects yourself. You could wrap the computation in a Callable, which can return a result.
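As a rough sketch of that idea: each task below keeps its own local sum and returns it through a Callable, and the results are merged only once, after the tasks finish, so no lock is taken inside the loop. The class and method names (LocalSums, partialSum, calculate) and the parameter values in main are hypothetical. Each task still builds its own factorials incrementally, so this only addresses the contention, not the duplicated factorial work.

```java
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class LocalSums {
    // Sums ((3k)^2 + 1) / (3k)! for k = first, first + stride, ... below terms.
    // Nothing here touches shared state.
    static BigDecimal partialSum(int first, int stride, int terms, int scale) {
        BigDecimal local = BigDecimal.ZERO;
        BigInteger factorial = BigInteger.ONE; // invariant: factorial == upTo!
        int upTo = 0;
        for (int k = first; k < terms; k += stride) {
            for (int i = upTo + 1; i <= 3 * k; i++) {
                factorial = factorial.multiply(BigInteger.valueOf(i));
            }
            upTo = 3 * k;
            BigDecimal numerator = BigDecimal.valueOf(9L * k * k + 1);
            local = local.add(
                    numerator.divide(new BigDecimal(factorial), scale, RoundingMode.HALF_UP));
        }
        return local;
    }

    static BigDecimal calculate(int threadCount, int terms, int scale) {
        ExecutorService pool = Executors.newFixedThreadPool(threadCount);
        try {
            List<Future<BigDecimal>> futures = new ArrayList<>();
            for (int t = 0; t < threadCount; t++) {
                final int first = t;
                Callable<BigDecimal> task = () -> partialSum(first, threadCount, terms, scale);
                futures.add(pool.submit(task));
            }
            BigDecimal sum = BigDecimal.ZERO;
            for (Future<BigDecimal> f : futures) {
                sum = sum.add(f.get()); // merge once per task, after it has finished
            }
            return sum;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("E = " + calculate(4, 30, 60));
    }
}
```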
Q2 How would one go about calculating these big factorials?
Usually you don't. Instead, you reformulate the problem so that it does not involve the direct computation of factorials. One technique is Horner's method.
Q3 The precision parameter that is passed is the amount of elements in the sum that are used. Can I set the BigDecimal scale to be somehow dependent on that precision so I don't hard code it?
Sure, why not. You can work out the error bound from the number of elements (it's proportional to the last term in the series) and set the BigDecimal scale to that.
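One possible way to do that, as a sketch only: estimate log10 of the last summed term by adding logarithms (so nothing overflows) and use the number of leading zero decimals, plus a few guard digits, as the scale. The names ScaleEstimate, log10Factorial and scaleForTerms, and the idea of a guardDigits parameter, are my own assumptions here.

```java
class ScaleEstimate {
    // log10(n!) as a sum of logarithms, which stays small even for large n.
    static double log10Factorial(int n) {
        double sum = 0.0;
        for (int i = 2; i <= n; i++) {
            sum += Math.log10(i);
        }
        return sum;
    }

    // Decimal place at which the last summed term ((3k)^2 + 1) / (3k)!,
    // with k = terms - 1, first becomes non-zero, plus some guard digits.
    static int scaleForTerms(int terms, int guardDigits) {
        int k = terms - 1;
        double log10Term = Math.log10(9.0 * k * k + 1) - log10Factorial(3 * k);
        return (int) Math.ceil(-log10Term) + guardDigits;
    }

    public static void main(String[] args) {
        for (int terms : new int[] {4, 10, 100, 1000}) {
            System.out.println(terms + " terms -> scale " + scaleForTerms(terms, 10));
        }
    }
}
```

For example, with 4 terms the last term is 82/362880, roughly 0.000226, so the estimate is 4 decimal places before guard digits.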
I'm quite new to coding, as you all can see from the clumsy code below. However, looking at this code you can see what I'm getting at. The code basically does what it's supposed to, but I would like to write it as a loop to make it more efficient. Could someone maybe point me in the right direction? I have done some digging and thought about recursion, but I haven't been able to figure out how to apply it here.
public static void main(String[] args) {
double a = 10;
double b = 2;
double c = 3;
double avg = (a + b + c)/3;
double avg1 = (avg + b + c)/3;
double avg2 = (avg1 + b + c)/3;
double avg3 = (avg2 + b + c)/3;
System.out.println(avg+ "\n" + avg1+ "\n"+ avg2 + "\n"+ avg3);
}
Functionally, this would be equivalent to what you have done:
public static void main(String[] args) {
double a = 10;
double b = 2;
double c = 3;
double avg = (a + b + c)/3;
System.out.println(avg);
for (int i=0; i<3; i++) {
avg = (avg + b + c)/3;
System.out.println(avg);
}
}
But you should also know that shorter code does not always mean more efficient code. The solution may be more concise, but I doubt there will be any change in performance.
If by efficiency you mean shorter code, you can do it like this.
public static void main(String[] args) {
double a = 10;
double b = 2;
double c = 3;
for (int i = 0; i < 4; i++) {
a = (a + b + c) / 3;
System.out.println(a);
}
}
I have no idea what this calculation represents (some sort of specialised weighted average?) but rather than use repetition and loops, you can reach the exact same calculation by using a bit of algebra and refactoring the terms:
public static double directlyCalculateWeightedAverage(double a, double b,
double c) {
return a / 81 + 40 * b / 81 + 40 * c / 81;
}
This reformulation is reached because the factor a appears just once in the mix and is then divided by 3^4, which is 81. Then each of b and c appears at various levels of division, so that b sums to this:
b/81 + b/27 + b/9 + b/3
== b/81 + 3b/81 + 9b/81 + 27b/81
== 40b/81
and c is treated exactly the same.
Which gives the direct calculation
a/81 + 40b/81 + 40c/81
Assuming your formula does not change, I'd recommend using this direct approach rather than resorting to repeated calculations and loops.
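If the number of repetitions might change, the same algebra seems to generalise: after n rounds the weight on a is 1/3^n, and the geometric sum 1/3 + 1/9 + ... + 1/3^n equals (1 - 1/3^n)/2, which for n = 4 gives 40/81 as above. A small sketch that checks this against the loop (class and method names are mine, for illustration):

```java
class RepeatedAverage {
    // The repeated calculation, written as a loop.
    static double byLoop(double a, double b, double c, int n) {
        double avg = a;
        for (int i = 0; i < n; i++) {
            avg = (avg + b + c) / 3;
        }
        return avg;
    }

    // Closed form: a / 3^n + (b + c) * (1 - 1 / 3^n) / 2.
    static double direct(double a, double b, double c, int n) {
        double shrink = Math.pow(3, -n);
        return a * shrink + (b + c) * (1 - shrink) / 2;
    }

    public static void main(String[] args) {
        System.out.println(byLoop(10, 2, 3, 4));
        System.out.println(direct(10, 2, 3, 4));
    }
}
```

Both lines should print approximately the same value, up to floating-point rounding.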
Your problem can be solved by 2 approaches: iterative (with a loop) or recursive (with a recursive function).
Iterative approach : for loop
The for loop allows you to repeat a group of instructions a given number of times.
In your case, you could write the following :
double a = 10, b = 2, c = 3;
double avg = a;
for (int i = 0; i < 4; i++) {
avg = (avg + b + c) / 3;
System.out.println(avg);
}
This will print the 4 first results of your calculation.
In my example, I overwrite the variable avg to only keep the last result, which might not be what you want. To keep the result of each loop iteration, you may store the result in an array.
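Storing each intermediate result in an array could look like the following sketch (the class name AverageHistory and the method name averages are just illustrative):

```java
import java.util.Arrays;

class AverageHistory {
    // Returns the result of every iteration instead of only the last one.
    static double[] averages(double a, double b, double c, int steps) {
        double[] results = new double[steps];
        double avg = a;
        for (int i = 0; i < steps; i++) {
            avg = (avg + b + c) / 3;
            results[i] = avg; // keep this iteration's value
        }
        return results;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(averages(10, 2, 3, 4)));
    }
}
```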
Recursive approach
In Java, there is no such thing as a standalone function; every function is a method of a class. For a simple recursive helper like this one, a static method works well:
private static double recursiveAvg(double avg, int count) {
// Always check that the condition for your recursion is valid !
if (count == 0) {
return avg;
}
// Apply your formula
avg = (avg + 2 + 3) / 3;
// Call the same function with the new avg value, and decrease the iteration count.
return recursiveAvg(avg, count - 1);
}
public static void main(String[] args) {
// Start from a = 10, and repeat the operation 4 times.
double avg = recursiveAvg(10, 4);
System.out.println(avg);
}
Always check for a condition that will end the recursion. In our example, it's the number of times the operation should be performed.
Note that most programmers prefer the iterative approach : easier to write and read, and less error prone.
I had a problem where I had to calculate the sum of large powers of the numbers in an array and return the result. For example, if arr=[10,12,34,56], the output should be
10^1+12^2+34^3+56^4. The output could be very large, so we were asked to take the result mod 10^10+11 and return that. I did it easily in Python, but in Java I initially used BigInteger and got TLE (time limit exceeded) on half the test cases. So I switched to Long and computed the powers with modular exponentiation, but then I got wrong output, all negative, as it obviously exceeded the limit.
Here is my code using Long and Modular exponential.
static long power(long x, long y, long p)
{
long res = 1; // Initialize result
x = x % p; // Update x if it is more than or
// equal to p
while (y > 0)
{
// If y is odd, multiply x with result
if ((y & 1)==1)
res = (res*x) % p;
// y must be even now
y = y>>1; // y = y/2
x = (x*x) % p;
}
return res;
}
static long solve(int[] a) {
// Write your code here
Long[] arr = new Long[a.length];
for (int i = 0; i < a.length; i++) {
arr[i] = setBits(new Long(a[i]));
}
Long Mod = new Long("10000000011");
Long c = new Long(0);
for (int i = 0; i < arr.length; i++) {
c += power(arr[i], new Long(i + 1),Mod) % Mod;
}
return c % Mod;
}
static long setBits(Long a) {
Long count = new Long(0);
while (a > 0) {
a &= (a - 1);
count++;
}
return count;
}
Then I also tried binary exponentiation, but nothing worked for me. How do I achieve this without BigInteger, as easily as I did in Python?
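You have added an extra zero to the value of mod; it should be 1000000011. With the 11-digit value, the intermediate products res*x and x*x can reach roughly 10^20, which is more than a long can hold (about 9.2 * 10^18), so they wrap around to negative numbers.
Hope this solves it.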
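If the modulus really did have to be as large as 10^10 + 11, one way to stay within long without BigInteger is to replace the plain multiplication with a double-and-add modular multiplication, so that no intermediate value exceeds twice the modulus. This is only a sketch under the assumption of non-negative inputs and a modulus below 2^62; the names ModPow, mulMod, powMod and sumOfPowers are mine, and the bit-counting step from the question is left out.

```java
class ModPow {
    // (a * b) % m by repeated doubling. Intermediate values stay below 2 * m,
    // so this assumes 0 < m < 2^62 and non-negative a and b.
    static long mulMod(long a, long b, long m) {
        a %= m;
        b %= m;
        long result = 0;
        while (b > 0) {
            if ((b & 1) == 1) {
                result = (result + a) % m;
            }
            a = (a + a) % m;
            b >>= 1;
        }
        return result;
    }

    // Binary exponentiation built on mulMod.
    static long powMod(long base, long exp, long m) {
        long result = 1 % m;
        base %= m;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = mulMod(result, base, m);
            }
            base = mulMod(base, base, m);
            exp >>= 1;
        }
        return result;
    }

    // a[0]^1 + a[1]^2 + ... + a[n-1]^n, reduced mod m.
    static long sumOfPowers(int[] a, long m) {
        long total = 0;
        for (int i = 0; i < a.length; i++) {
            total = (total + powMod(a[i], i + 1, m)) % m;
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(sumOfPowers(new int[] {10, 12, 34, 56}, 10000000011L));
    }
}
```

The price is a loop of up to about 62 steps per multiplication, which is still far cheaper than BigInteger arithmetic in most cases.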
I read in a couple of blogs that in Java the modulo/remainder operator is slower than bitwise AND. So I wrote the following program to test this.
public class ModuloTest {
public static void main(String[] args) {
final int size = 1024;
int index = 0;
long start = System.nanoTime();
for(int i = 0; i < Integer.MAX_VALUE; i++) {
getNextIndex(size, i);
}
long end = System.nanoTime();
System.out.println("Time taken by Modulo (%) operator --> " + (end - start) + "ns.");
start = System.nanoTime();
final int shiftFactor = size - 1;
for(int i = 0; i < Integer.MAX_VALUE; i++) {
getNextIndexBitwise(shiftFactor, i);
}
end = System.nanoTime();
System.out.println("Time taken by bitwise AND --> " + (end - start) + "ns.");
}
private static int getNextIndex(int size, int nextInt) {
return nextInt % size;
}
private static int getNextIndexBitwise(int size, int nextInt) {
return nextInt & size;
}
}
But in my runtime environment (MacBook Pro 2.9GHz i7, 8GB RAM, JDK 1.7.0_51) I am seeing otherwise. The bitwise AND is significantly slower, in fact twice as slow as the remainder operator.
I would appreciate it if someone can help me understand if this is intended behavior or I am doing something wrong?
Thanks,
Niranjan
Your code reports bitwise-and being much faster on each Mac I've tried it on, both with Java 6 and Java 7. I suspect the first portion of the test on your machine happened to coincide with other activity on the system. You should try running the test multiple times to verify you aren't seeing distortions based on that. (I would have left this as a 'comment' rather than an 'answer', but apparently you need 50 reputation to do that -- quite silly, if you ask me.)
For starters, the bitwise AND trick only works with non-negative (natural number) dividends and power-of-2 divisors. So if you need negative dividends, floats, or non-powers of 2, stick with the default % operator.
My tests (with JIT warmup and 1M random iterations), on an i7 with a ton of cores and a busload of RAM, show about 20% better performance from the bitwise operation. This can vary per run, depending on how the process scheduler runs the code.
using Scala 2.11.8 on JDK 1.8.91
4Ghz i7-4790K, 8 core AMD, 32GB PC3 19200 ram, SSD
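To make the restriction mentioned above concrete (non-negative dividends, power-of-2 divisors), here is a small sketch. Note that the mask is the divisor minus one. The class and method names are just for illustration.

```java
class MaskVsMod {
    // Remainder via mask; only meaningful when powerOfTwo is a power of 2.
    static int byMask(int value, int powerOfTwo) {
        return value & (powerOfTwo - 1);
    }

    public static void main(String[] args) {
        System.out.println(37 % 8);               // 5
        System.out.println(byMask(37, 8));        // 5, same as % for a non-negative dividend
        System.out.println(-5 % 8);               // -5, % keeps the sign of the dividend
        System.out.println(byMask(-5, 8));        // 3, differs from % for a negative dividend
        System.out.println(Math.floorMod(-5, 8)); // 3, the mask agrees with floorMod instead
        System.out.println(10 % 6);               // 4
        System.out.println(byMask(10, 6));        // 0, wrong because 6 is not a power of 2
    }
}
```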
This example in particular will always give you a wrong result. Moreover, I believe that any program which is calculating the modulo by a power of 2 will be faster than bitwise AND.
REASON: When you use N % X where X is kth power of 2, only last k bits are considered for modulo, whereas in case of the bitwise AND operator the runtime actually has to visit each bit of the number under question.
Also, I would like to point out that the HotSpot JVM optimizes repetitive calculations of a similar nature (one example being branch prediction). In your case, the method which is using the modulo just returns the last 10 bits of the number because 1024 is the 10th power of 2.
Try using some prime number value for size and check the same result.
Disclaimer: Micro benchmarking is not considered good.
Is this method missing something?
public static void oddVSmod(){
float tests = 100000000;
oddbit(tests);
modbit(tests);
}
public static void oddbit(float tests){
for(int i=0; i<tests; i++)
if((i&1)==1) {System.out.print(" "+i);}
System.out.println();
}
public static void modbit(float tests){
for(int i=0; i<tests; i++)
if((i%2)==1) {System.out.print(" "+i);}
System.out.println();
}
With that, I used the NetBeans built-in profiler (advanced mode) to run this. I set the variable tests up to 10x10^8, and every time it showed that modulo is faster than bitwise.
Thank you all for valuable inputs.
@pamphlet: Thank you very much for the concerns, but negative comments are fine with me. I confess that I did not do proper testing as suggested by AndyG. AndyG could have used a softer tone, but it's okay; sometimes negatives help in seeing the positive. :)
That said, I changed my code (as shown below) in a way that I can run that test multiple times.
public class ModuloTest {
public static final int SIZE = 1024;
public int usingModuloOperator(final int operand1, final int operand2) {
return operand1 % operand2;
}
public int usingBitwiseAnd(final int operand1, final int operand2) {
return operand1 & operand2;
}
public void doCalculationUsingModulo(final int size) {
for(int i = 0; i < Integer.MAX_VALUE; i++) {
usingModuloOperator(1, size);
}
}
public void doCalculationUsingBitwise(final int size) {
for(int i = 0; i < Integer.MAX_VALUE; i++) {
usingBitwiseAnd(i, size);
}
}
public static void main(String[] args) {
final ModuloTest moduloTest = new ModuloTest();
final int invocationCount = 100;
// testModuloOperator(moduloTest, invocationCount);
testBitwiseOperator(moduloTest, invocationCount);
}
private static void testModuloOperator(final ModuloTest moduloTest, final int invocationCount) {
for(int i = 0; i < invocationCount; i++) {
final long startTime = System.nanoTime();
moduloTest.doCalculationUsingModulo(SIZE);
final long timeTaken = System.nanoTime() - startTime;
System.out.println("Using modulo operator // Time taken for invocation counter " + i + " is " + timeTaken + "ns");
}
}
private static void testBitwiseOperator(final ModuloTest moduloTest, final int invocationCount) {
for(int i = 0; i < invocationCount; i++) {
final long startTime = System.nanoTime();
moduloTest.doCalculationUsingBitwise(SIZE);
final long timeTaken = System.nanoTime() - startTime;
System.out.println("Using bitwise operator // Time taken for invocation counter " + i + " is " + timeTaken + "ns");
}
}
}
I called testModuloOperator() and testBitwiseOperator() in a mutually exclusive way. The result was consistent with the idea that bitwise AND is faster than the modulo operator. I ran each calculation 100 times and recorded the execution times, then removed the first five and last five recordings and used the rest to calculate the average time. Below are my test results.
Using modulo operator, the avg. time for 90 runs: 8388.89ns.
Using bitwise-AND operator, the avg. time for 90 runs: 722.22ns.
Please suggest if my approach is correct or not.
Thanks again.
Niranjan
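A few things may be worth checking in the approach above. The return values of both methods are discarded, so the JIT is free to remove the calls entirely; the modulo loop passes the constant 1 while the bitwise loop passes i; and i & 1024 tests a single bit, whereas the mask equivalent to i % 1024 is 1023. A sketch that feeds both loops the same inputs, uses SIZE - 1 as the mask, and accumulates the results into a checksum so the work stays observable could look like this (the class and method names are hypothetical, and a harness such as JMH would still be more trustworthy):

```java
class ConsumedResults {
    static final int SIZE = 1024;

    // Adds up i % SIZE so the remainder is actually used.
    static long sumModulo(int limit) {
        long checksum = 0;
        for (int i = 0; i < limit; i++) {
            checksum += i % SIZE;
        }
        return checksum;
    }

    // Same computation with the mask SIZE - 1 (valid because i >= 0).
    static long sumMask(int limit) {
        long checksum = 0;
        for (int i = 0; i < limit; i++) {
            checksum += i & (SIZE - 1);
        }
        return checksum;
    }

    public static void main(String[] args) {
        int limit = 100_000_000;
        for (int round = 0; round < 5; round++) { // early rounds double as warm-up
            long t0 = System.nanoTime();
            long viaModulo = sumModulo(limit);
            long t1 = System.nanoTime();
            long viaMask = sumMask(limit);
            long t2 = System.nanoTime();
            System.out.println("round " + round + ": modulo " + (t1 - t0) + " ns, mask "
                    + (t2 - t1) + " ns, checksums " + viaModulo + " / " + viaMask);
        }
    }
}
```

If the two checksums differ, the two loops are not doing equivalent work and the timings cannot be compared.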
private static double [] sigtab = new double[1001]; // values of f(x) for x values
static {
for(int i=0; i<1001; i++) {
double ifloat = i;
ifloat /= 100;
sigtab[i] = 1.0/(1.0 + Math.exp(-ifloat));
}
}
public static double fast_sigmoid (double x) {
if (x <= -10)
return 0.0;
else if (x >= 10)
return 1.0;
else {
double normx = Math.abs(x*100);
int i = (int)normx;
double lookup = sigtab[i] + (sigtab[i+1] - sigtab[i])*(normx - Math.floor(normx));
if (x > 0)
return lookup;
else // (x < 0)
return (1 - lookup);
}
}
Anyone know why this "fast sigmoid" actually runs slower than the exact version using Math.exp?
You should profile your code, but I'll bet it's the call to Math.floor taking around half your CPU cycles (it is slow because it calls the native method StrictMath.floor(double), incurring the JNI overhead.)
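If Math.floor does turn out to be the hot spot (only profiling can confirm that), it can be dropped here, because normx is never negative and the int cast already truncates it. A sketch of the same lookup without the floor call; the class name SigmoidLut and the renamed identifiers are mine:

```java
class SigmoidLut {
    // Sigmoid sampled at x = 0.00, 0.01, ..., 10.00.
    private static final double[] SIGTAB = new double[1001];
    static {
        for (int i = 0; i < SIGTAB.length; i++) {
            SIGTAB[i] = 1.0 / (1.0 + Math.exp(-i / 100.0));
        }
    }

    static double fastSigmoid(double x) {
        if (x <= -10) return 0.0;
        if (x >= 10) return 1.0;
        double normx = Math.abs(x * 100);
        int i = (int) normx;      // truncation equals floor because normx >= 0
        double frac = normx - i;  // replaces normx - Math.floor(normx)
        double lookup = SIGTAB[i] + (SIGTAB[i + 1] - SIGTAB[i]) * frac;
        return x > 0 ? lookup : 1 - lookup;
    }

    public static void main(String[] args) {
        System.out.println(fastSigmoid(1.234) + " vs " + 1.0 / (1.0 + Math.exp(-1.234)));
    }
}
```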
It is possible to compute (less-accurate) versions of sigmoid functions faster than the (exact) hardware implementations. Here's an example for tanh, which should be easy to transform to your function (is it expit(-x)?)
Two tricks that are used here are often useful in LUT-based approximations:
Simulate rounding by adding a large constant (forcing the FPU to truncate the sum, because it has too few bits to represent it exactly)
Make your table size a power of 2 (means one less multiply per call)
public static float fastTanH(float x) {
if (x<0) return -fastTanH(-x);
if (x>8) return 1f;
float xp = TANH_FRAC_BIAS + x;
short ind = (short) Float.floatToRawIntBits(xp);
float tanha = TANH_TAB[ind];
float b = xp - TANH_FRAC_BIAS;
x -= b;
return tanha + x * (1f - tanha*tanha);
}
private static final int TANH_FRAC_EXP = 6; // LUT precision == 2 ** -6 == 1/64
private static final int TANH_LUT_SIZE = (1 << TANH_FRAC_EXP) * 8 + 1;
private static final float TANH_FRAC_BIAS =
Float.intBitsToFloat((0x96 - TANH_FRAC_EXP) << 23);
private static float[] TANH_TAB = new float[TANH_LUT_SIZE];
static {
for (int i = 0; i < TANH_LUT_SIZE; ++ i) {
TANH_TAB[i] = (float) Math.tanh(i / 64.0);
}
}
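In case the bias trick in fastTanH looks like magic, here it is in isolation, as a sketch with my own names (BiasRounding, gridIndex). The constant works out to 2^17; floats in [2^17, 2^18) are spaced 2^-6 apart, so adding it rounds x to the nearest 1/64 and leaves the table index in the low mantissa bits. This assumes 0 <= x <= 8, as in the code above.

```java
class BiasRounding {
    static final int FRAC_EXP = 6; // grid step is 2^-6 == 1/64
    // Exponent field 0x96 - 6 == 144, i.e. the float value 2^(144 - 127) == 2^17.
    static final float BIAS = Float.intBitsToFloat((0x96 - FRAC_EXP) << 23);

    // Index of the nearest grid point to x, for 0 <= x <= 8.
    static int gridIndex(float x) {
        float biased = BIAS + x; // the addition rounds x to a multiple of 1/64
        return (short) Float.floatToRawIntBits(biased); // low bits hold round(x * 64)
    }

    public static void main(String[] args) {
        System.out.println(BIAS);            // 131072.0
        System.out.println(gridIndex(0.5f)); // 32
        System.out.println(gridIndex(1.3f)); // 83
        System.out.println(gridIndex(8f));   // 512
    }
}
```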
Do you mean looking up in an array of double elements and performing some calculus should be faster than calculating it on the spot?
Although the CPU only has basic operations, it can handle an exponentiation pretty easily. I'd say in less than 5 basic operations.
What you are doing here is somewhat complex and actually requires fetching some elements from memory. 64 bits * 1001 surely fits in your cache, but cache access time certainly does not match register access time.
This case does not surprise me in the least.