Scala recursion vs loop: performance and runtime considerations
Scala static methods for factorial(n) (coded with Scala 2.12.x, Java 8):
object Factorial {
  /*
   * For large N, this throws a StackOverflowError
   */
  def recursive(n: BigInt): BigInt = {
    if (n < 0) {
      throw new ArithmeticException
    } else if (n <= 1) {
      1
    } else {
      n * recursive(n - 1)
    }
  }

  /*
   * A tail-recursive method is compiled into a loop, so it does not overflow the stack
   * (@tailrec makes compilation fail if the recursive call is not in tail position)
   */
  @scala.annotation.tailrec
  def recursiveTail(n: BigInt, acc: BigInt = 1): BigInt = {
    if (n < 0) {
      throw new ArithmeticException
    } else if (n <= 1) {
      acc
    } else {
      recursiveTail(n - 1, n * acc)
    }
  }

  /*
   * A while loop
   */
  def loop(n: BigInt): BigInt = {
    if (n < 0) {
      throw new ArithmeticException
    } else if (n <= 1) {
      1
    } else {
      // Declared as BigInt: `var acc = 1` would infer Int and overflow silently
      var acc: BigInt = 1
      var idx = 1
      while (idx <= n) {
        acc = idx * acc
        idx += 1
      }
      acc
    }
  }
}
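One thing worth checking in `loop`: if the accumulator is declared as `var acc = 1`, Scala infers `Int`, so `idx * acc` is 32-bit `Int` arithmetic that silently wraps around, and the value is only widened to `BigInt` when it is returned. The sketch below contrasts the two declarations (the object and method names are hypothetical). Note that a spec checking only `isInstanceOf[BigInt]` passes for both, and that cheap `Int` multiplication could make such a loop look faster than it really is.

```scala
object AccumulatorTypeDemo {
  // Accumulator inferred as Int: the multiplication wraps around silently,
  // and the Int result is only converted to BigInt at the end.
  def intAccumulator(n: Int): BigInt = {
    var acc = 1 // inferred type: Int
    var idx = 1
    while (idx <= n) {
      acc = idx * acc
      idx += 1
    }
    BigInt(acc)
  }

  // Accumulator declared as BigInt: arbitrary precision, no overflow.
  def bigIntAccumulator(n: Int): BigInt = {
    var acc: BigInt = 1
    var idx = 1
    while (idx <= n) {
      acc = BigInt(idx) * acc
      idx += 1
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    // 12! = 479001600 still fits in an Int, so both agree
    println(intAccumulator(12) == bigIntAccumulator(12)) // true
    // 13! does not fit in an Int, so the Int version is wrong
    println(intAccumulator(13) == bigIntAccumulator(13)) // false
  }
}
```

From 34! on, the product has at least 32 factors of 2, so the `Int` version returns 0 for every larger input.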
Specs:
class FactorialSpecs extends SpecHelper {
  private val smallInt = 10
  private val largeInt = 10000

  describe("Factorial.recursive") {
    it("returns 1 for 0") {
      assert(Factorial.recursive(0) == 1)
    }
    it("returns 1 for 1") {
      assert(Factorial.recursive(1) == 1)
    }
    it("returns 2 for 2") {
      assert(Factorial.recursive(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursive(smallInt) == 3628800)
    }
    it("throws StackOverflow for large inputs") {
      intercept[java.lang.StackOverflowError] {
        Factorial.recursive(Int.MaxValue)
      }
    }
  }

  describe("Factorial.recursiveTail") {
    it("returns 1 for 0") {
      assert(Factorial.recursiveTail(0) == 1)
    }
    it("returns 1 for 1") {
      assert(Factorial.recursiveTail(1) == 1)
    }
    it("returns 2 for 2") {
      assert(Factorial.recursiveTail(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.recursiveTail(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.recursiveTail(largeInt).isInstanceOf[BigInt])
    }
  }

  describe("Factorial.loop") {
    it("returns 1 for 0") {
      assert(Factorial.loop(0) == 1)
    }
    it("returns 1 for 1") {
      assert(Factorial.loop(1) == 1)
    }
    it("returns 2 for 2") {
      assert(Factorial.loop(2) == 2)
    }
    it("returns a result, for small inputs") {
      assert(Factorial.loop(smallInt) == 3628800)
    }
    it("returns a result, for large inputs") {
      assert(Factorial.loop(largeInt).isInstanceOf[BigInt])
    }
  }
}
Benchmarks:
import org.scalameter.api._

class BenchmarkFactorials extends Bench.OfflineReport {
  val gen: Gen[Int] = Gen.range("N")(1, 1000, 100) // scalastyle:ignore

  performance of "Factorial" in {
    measure method "loop" in {
      using(gen) in {
        n => Factorial.loop(n)
      }
    }
    measure method "recursive" in {
      using(gen) in {
        n => Factorial.recursive(n)
      }
    }
    measure method "recursiveTail" in {
      using(gen) in {
        n => Factorial.recursiveTail(n)
      }
    }
  }
}
Benchmark results (loop is much faster):
[info] Test group: Factorial.loop
[info] - Factorial.loop.Test-9 measurements:
[info] - at N -> 1: passed
[info] (mean = 0.01 ms, ci = <0.00 ms, 0.02 ms>, significance = 1.0E-10)
[info] - at N -> 101: passed
[info] (mean = 0.01 ms, ci = <0.01 ms, 0.01 ms>, significance = 1.0E-10)
[info] - at N -> 201: passed
[info] (mean = 0.02 ms, ci = <0.02 ms, 0.02 ms>, significance = 1.0E-10)
[info] - at N -> 301: passed
[info] (mean = 0.03 ms, ci = <0.02 ms, 0.03 ms>, significance = 1.0E-10)
[info] - at N -> 401: passed
[info] (mean = 0.03 ms, ci = <0.03 ms, 0.04 ms>, significance = 1.0E-10)
[info] - at N -> 501: passed
[info] (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info] - at N -> 601: passed
[info] (mean = 0.04 ms, ci = <0.04 ms, 0.05 ms>, significance = 1.0E-10)
[info] - at N -> 701: passed
[info] (mean = 0.05 ms, ci = <0.05 ms, 0.05 ms>, significance = 1.0E-10)
[info] - at N -> 801: passed
[info] (mean = 0.06 ms, ci = <0.05 ms, 0.06 ms>, significance = 1.0E-10)
[info] - at N -> 901: passed
[info] (mean = 0.06 ms, ci = <0.05 ms, 0.07 ms>, significance = 1.0E-10)
[info] Test group: Factorial.recursive
[info] - Factorial.recursive.Test-10 measurements:
[info] - at N -> 1: passed
[info] (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info] - at N -> 101: passed
[info] (mean = 0.05 ms, ci = <0.01 ms, 0.09 ms>, significance = 1.0E-10)
[info] - at N -> 201: passed
[info] (mean = 0.03 ms, ci = <0.02 ms, 0.05 ms>, significance = 1.0E-10)
[info] - at N -> 301: passed
[info] (mean = 0.07 ms, ci = <0.00 ms, 0.13 ms>, significance = 1.0E-10)
[info] - at N -> 401: passed
[info] (mean = 0.09 ms, ci = <0.01 ms, 0.18 ms>, significance = 1.0E-10)
[info] - at N -> 501: passed
[info] (mean = 0.10 ms, ci = <0.03 ms, 0.17 ms>, significance = 1.0E-10)
[info] - at N -> 601: passed
[info] (mean = 0.11 ms, ci = <0.08 ms, 0.15 ms>, significance = 1.0E-10)
[info] - at N -> 701: passed
[info] (mean = 0.13 ms, ci = <0.11 ms, 0.14 ms>, significance = 1.0E-10)
[info] - at N -> 801: passed
[info] (mean = 0.16 ms, ci = <0.13 ms, 0.19 ms>, significance = 1.0E-10)
[info] - at N -> 901: passed
[info] (mean = 0.21 ms, ci = <0.15 ms, 0.27 ms>, significance = 1.0E-10)
[info] Test group: Factorial.recursiveTail
[info] - Factorial.recursiveTail.Test-11 measurements:
[info] - at N -> 1: passed
[info] (mean = 0.00 ms, ci = <0.00 ms, 0.01 ms>, significance = 1.0E-10)
[info] - at N -> 101: passed
[info] (mean = 0.04 ms, ci = <0.03 ms, 0.05 ms>, significance = 1.0E-10)
[info] - at N -> 201: passed
[info] (mean = 0.12 ms, ci = <0.05 ms, 0.20 ms>, significance = 1.0E-10)
[info] - at N -> 301: passed
[info] (mean = 0.16 ms, ci = <-0.03 ms, 0.34 ms>, significance = 1.0E-10)
[info] - at N -> 401: passed
[info] (mean = 0.12 ms, ci = <0.09 ms, 0.16 ms>, significance = 1.0E-10)
[info] - at N -> 501: passed
[info] (mean = 0.17 ms, ci = <0.15 ms, 0.19 ms>, significance = 1.0E-10)
[info] - at N -> 601: passed
[info] (mean = 0.23 ms, ci = <0.19 ms, 0.26 ms>, significance = 1.0E-10)
[info] - at N -> 701: passed
[info] (mean = 0.25 ms, ci = <0.18 ms, 0.32 ms>, significance = 1.0E-10)
[info] - at N -> 801: passed
[info] (mean = 0.28 ms, ci = <0.21 ms, 0.36 ms>, significance = 1.0E-10)
[info] - at N -> 901: passed
[info] (mean = 0.32 ms, ci = <0.17 ms, 0.46 ms>, significance = 1.0E-10)
I know everyone has already answered the question, but I thought I might add one optimization: converting the pattern matching to simple if statements can speed up the tail recursion.
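To make the comparison concrete, here is one tail-recursive factorial written both ways; `@tailrec` accepts both and they return the same value. This is only a sketch with hypothetical names: whether the `if` form is measurably faster on your setup likely depends on the compiler version and the JIT, so it is something to measure rather than assume.

```scala
import scala.annotation.tailrec

object MatchVsIf {
  // Tail recursion driven by a pattern match with a guard
  @tailrec
  def factMatch(n: Int, acc: BigInt = 1): BigInt = n match {
    case _ if n <= 1 => acc
    case _           => factMatch(n - 1, BigInt(n) * acc)
  }

  // The same recursion written as a plain if/else
  @tailrec
  def factIf(n: Int, acc: BigInt = 1): BigInt =
    if (n <= 1) acc else factIf(n - 1, BigInt(n) * acc)

  def main(args: Array[String]): Unit = {
    println(factMatch(20) == factIf(20)) // true
    println(factIf(20))                  // 2432902008176640000
  }
}
```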
final object Factorial {
  type Out = BigInt

  def calculateByRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    n match {
      case _ if n == 1 => 1
      case _ => n * calculateByRecursion(n - 1)
    }
  }

  def calculateByForLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var accumulator: Out = 1
    for (i <- 1 to n)
      accumulator = i * accumulator
    accumulator
  }

  def calculateByWhileLoop(n: Int): Out = {
    require(n > 0, "n must be positive")
    var acc: Out = 1
    var i = 1
    while (i <= n) {
      acc = i * acc
      i += 1
    }
    acc
  }

  def calculateByTailRecursion(n: Int): Out = {
    require(n > 0, "n must be positive")
    @annotation.tailrec
    def fac(n: Int, acc: Out): Out = if (n == 1) acc else fac(n - 1, n * acc)
    fac(n, 1)
  }

  def calculateByTailRecursionUpward(n: Int): Out = {
    require(n > 0, "n must be positive")
    @annotation.tailrec
    def fac(i: Int, acc: Out): Out = if (i == n) n * acc else fac(i + 1, i * acc)
    fac(1, 1)
  }

  // Runs f; prints a marker and returns false if it throws (e.g. StackOverflowError)
  def attempt(f: () => Unit): Boolean = {
    try {
      f()
      true
    } catch {
      case _: Throwable =>
        println(" <<<<< Failed...")
        false
    }
  }

  def comparePerformance(n: Int): Unit = {
    // Prints the first, middle and last digits of the result, plus the elapsed time
    def showOutput[A](msg: String, data: (Long, A), showResult: Boolean = true): Unit =
      if (showResult) {
        val res = data._2.toString
        val pref = res.substring(0, 5)
        val midd = res.substring((res.length - 5) / 2, (res.length - 5) / 2 + 10)
        val suff = res.substring(res.length - 5)
        printf("%s returned %s in %d ms\n", msg, s"$pref...$midd...$suff", data._1)
      } else {
        printf("%s in %d ms\n", msg, data._1)
      }

    // Times a single call in wall-clock milliseconds
    def measure[A](f: () => A): (Long, A) = {
      val start = System.currentTimeMillis
      val o = f()
      (System.currentTimeMillis - start, o)
    }

    attempt(() => showOutput("By for loop", measure(() => calculateByForLoop(n))))
    attempt(() => showOutput("By while loop", measure(() => calculateByWhileLoop(n))))
    attempt(() => showOutput("By non-tail recursion", measure(() => calculateByRecursion(n))))
    attempt(() => showOutput("By tail recursion", measure(() => calculateByTailRecursion(n))))
    attempt(() => showOutput("By tail recursion upward", measure(() => calculateByTailRecursionUpward(n))))
  }
}
My results:
scala> Factorial.comparePerformance(20000)
By for loop returned 18192...5708616582...00000 in 179 ms
By while loop returned 18192...5708616582...00000 in 159 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 18192...5708616582...00000 in 169 ms
By tail recursion upward returned 18192...5708616582...00000 in 174 ms
By for loop returned 18192...5708616582...00000 in 212 ms
By while loop returned 18192...5708616582...00000 in 156 ms
By non-tail recursion returned 18192...5708616582...00000 in 155 ms
By tail recursion returned 18192...5708616582...00000 in 166 ms
By tail recursion upward returned 18192...5708616582...00000 in 137 ms
scala> Factorial.comparePerformance(200000)
By for loop returned 14202...0169293868...00000 in 17467 ms
By while loop returned 14202...0169293868...00000 in 17303 ms
By non-tail recursion <<<<< Failed...
By tail recursion returned 14202...0169293868...00000 in 18477 ms
By tail recursion upward returned 14202...0169293868...00000 in 17188 ms
For loops are not actually loops; they're for comprehensions over a Range. If you want a real loop, you need to use while. (Actually, I think the BigInt multiplication here is heavyweight enough that it shouldn't matter, but you'll notice the difference if you're multiplying Ints.)
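Roughly speaking, the Scala 2 compiler rewrites a for loop without yield into a foreach call on the Range, passing the loop body as a closure; a while loop needs neither the Range nor the closure. The sketch below writes the three forms out by hand (method names are hypothetical) and checks that they compute the same product.

```scala
object ForDesugaring {
  // Written as a for comprehension over a Range
  def withFor(n: Int): BigInt = {
    var acc = BigInt(1)
    for (i <- 1 to n) acc = BigInt(i) * acc
    acc
  }

  // What the for comprehension above desugars to: foreach with a closure
  // that captures and updates `acc`
  def withForeach(n: Int): BigInt = {
    var acc = BigInt(1)
    (1 to n).foreach(i => acc = BigInt(i) * acc)
    acc
  }

  // A plain while loop: no Range object and no closure
  def withWhile(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) {
      acc = BigInt(i) * acc
      i += 1
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    val n = 25
    println(withFor(n) == withForeach(n) && withForeach(n) == withWhile(n)) // true
  }
}
```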
Also, you have confused yourself by using BigInt. The bigger your BigInt is, the slower your multiplication. Your loop counts up while your tail recursion counts down, which means that the latter has more big numbers to multiply: after k steps, counting up has built k!, whereas counting down has already built n × (n-1) × ... × (n-k+1), which is never smaller.
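One rough way to see this effect is to add up the bit length of the accumulator after every multiplication, as a proxy for how much work the BigInt multiplications do. The proxy is an assumption for illustration, not a precise cost model, and the method names are hypothetical.

```scala
object CountDirection {
  // 1 * 2 * ... * n (counting up): sum the accumulator's bit length after each step
  def bitsCountingUp(n: Int): Long = {
    var acc = BigInt(1)
    var total = 0L
    var i = 1
    while (i <= n) {
      acc = BigInt(i) * acc
      total += acc.bitLength
      i += 1
    }
    total
  }

  // n * (n - 1) * ... * 1 (counting down): track the same quantity
  def bitsCountingDown(n: Int): Long = {
    var acc = BigInt(1)
    var total = 0L
    var i = n
    while (i >= 1) {
      acc = BigInt(i) * acc
      total += acc.bitLength
      i -= 1
    }
    total
  }

  def main(args: Array[String]): Unit = {
    val n = 10
    // After k steps, counting up holds k! and counting down holds n!/(n-k)!,
    // so counting down carries larger intermediate values.
    println(s"up: ${bitsCountingUp(n)}, down: ${bitsCountingDown(n)}") // up: 98, down: 152
  }
}
```

Both directions end with the same n!; only the size of the intermediate products differs.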
If you fix these two issues, you will find that sanity is restored: loops and tail recursion run at the same speed, while regular recursion and for comprehensions are both slower. (Regular recursion may not be slower if JVM optimizations make it equivalent.)
(Also, the stack overflow disappearing on a later run is probably because the JVM starts inlining, and may either make the call tail-recursive itself or unroll the recursion far enough that you no longer overflow.)
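Because the JIT can change the compiled code between runs, timing a single call with `System.currentTimeMillis` on a cold JVM is easy to misread. A minimal hand-rolled alternative is sketched below: it calls the function a few times to warm up, then reports the fastest `System.nanoTime` duration. The helper name and the warm-up and repetition counts are arbitrary assumptions; a harness such as ScalaMeter (used in the question) or JMH handles this more rigorously.

```scala
object WarmTiming {
  // Hypothetical helper: warm up, then return (fastest run in nanoseconds, last result)
  def bestOfNanos[A](f: () => A, warmups: Int, runs: Int): (Long, A) = {
    require(runs > 0, "runs must be positive")
    var w = 0
    while (w < warmups) { f(); w += 1 } // warm-up calls, results discarded
    var best = Long.MaxValue
    var last: Option[A] = None
    var r = 0
    while (r < runs) {
      val start = System.nanoTime()
      val out = f()
      val elapsed = System.nanoTime() - start
      if (elapsed < best) best = elapsed
      last = Some(out)
      r += 1
    }
    (best, last.get) // safe: runs > 0, so at least one result was recorded
  }

  // Function under test: a plain while-loop factorial
  def whileLoop(n: Int): BigInt = {
    var acc = BigInt(1)
    var i = 1
    while (i <= n) {
      acc = BigInt(i) * acc
      i += 1
    }
    acc
  }

  def main(args: Array[String]): Unit = {
    val (nanos, result) = bestOfNanos(() => whileLoop(2000), warmups = 5, runs = 10)
    println(s"while loop: ${result.bitLength} bits in ${nanos / 1000} microseconds")
  }
}
```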
Finally, you're getting poor results with for and while because you're multiplying with the small number on the right rather than on the left. It turns out that Java's BigInteger (which Scala's BigInt wraps) multiplies faster with the smaller number on the left.
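If you want to check the operand-order effect on your own JDK, the rough sketch below times `small * big` against `big * small` after a warm-up pass. It is not a rigorous benchmark: the numbers can vary with the JDK version and JIT state, and the helper name and repetition counts are hypothetical choices. Both orders must, of course, produce the same product.

```scala
object OperandOrder {
  // Hypothetical helper: run `mul` `reps` times; return (elapsed nanoseconds, last product)
  def timeReps(reps: Int, mul: () => BigInt): (Long, BigInt) = {
    require(reps > 0, "reps must be positive")
    var product = BigInt(0)
    val start = System.nanoTime()
    var i = 0
    while (i < reps) {
      product = mul()
      i += 1
    }
    (System.nanoTime() - start, product)
  }

  def main(args: Array[String]): Unit = {
    val big: BigInt = (1 to 5000).foldLeft(BigInt(1))((acc, i) => acc * BigInt(i)) // 5000!
    val small: BigInt = BigInt(4999)
    // Warm up both orders before timing
    timeReps(10000, () => small * big)
    timeReps(10000, () => big * small)
    val (leftNanos, p1) = timeReps(10000, () => small * big)
    val (rightNanos, p2) = timeReps(10000, () => big * small)
    println(s"same product: ${p1 == p2}")
    println(s"small * big: ${leftNanos / 1000000} ms, big * small: ${rightNanos / 1000000} ms")
  }
}
```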