How to make tree mapping tail-recursive?
"Call stack" and "recursion" are merely popular design patterns that later got incorporated into most programming languages (and thus became mostly "invisible"). There is nothing that prevents you from reimplementing both with heap data structures. So, here is "the obvious" 1960's TAOCP retro-style solution:
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node

def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
  case class Frame(name: String, mapped: List[Node], todos: List[Node])

  @annotation.tailrec
  def step(stack: List[Frame]): Node = stack match {
    // "return / pop a stack-frame"
    case Frame(name, done, Nil) :: tail => {
      val ret = BranchNode(name, done.reverse)
      tail match {
        case Nil => ret
        case Frame(tn, td, tt) :: more => {
          step(Frame(tn, ret :: td, tt) :: more)
        }
      }
    }
    case Frame(name, done, x :: xs) :: tail => x match {
      // "recursion base"
      case l @ LeafNode(_) => step(Frame(name, f(l) :: done, xs) :: tail)
      // "recursive call"
      case BranchNode(n, cs) => step(Frame(n, Nil, cs) :: Frame(name, done, xs) :: tail)
    }
    case Nil => throw new Error("shouldn't happen")
  }

  root match {
    case l @ LeafNode(_) => f(l)
    case b @ BranchNode(n, cs) => step(List(Frame(n, Nil, cs)))
  }
}
The tail-recursive step function takes a reified stack with "stack frames". A "stack frame" stores the name of the branch node that is currently being processed, a list of child nodes that have already been processed, and the list of the remaining nodes that still must be processed later. This roughly corresponds to an actual stack frame of your recursive mapLeaves function.
With this data structure,
- returning from recursive calls corresponds to deconstructing a Frame object, and either returning the final result, or at least making the stack one frame shorter.
- recursive calls correspond to a step that prepends a Frame to the stack
- base case (invoking f on leaves) does not create or remove any frames
Once one understands how the usually invisible stack frames are represented explicitly, the translation is straightforward and mostly mechanical.
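The same translation can be seen in a smaller setting. The following is only a sketch, not part of the solution above: it counts leaves instead of mapping them, so nothing has to be rebuilt on the way back and the "frame" shrinks to a plain list of nodes that still have to be visited. The names countLeaves and deepChain are hypothetical helpers introduced for illustration; the node classes are copied from above so that the block runs on its own as a script.

```scala
trait Node { val name: String }
case class BranchNode(name: String, children: List[Node]) extends Node
case class LeafNode(name: String) extends Node

// Hypothetical helper: counts leaves with an explicit, heap-allocated
// list of pending nodes instead of the call stack.
def countLeaves(root: Node): Int = {
  @annotation.tailrec
  def loop(todo: List[Node], acc: Int): Int = todo match {
    // nothing left to visit: "return" the accumulated count
    case Nil => acc
    // "recursion base": a leaf adds one and is dropped from the todo list
    case LeafNode(_) :: rest => loop(rest, acc + 1)
    // "recursive call": the children are pushed in front of the remaining work
    case BranchNode(_, cs) :: rest => loop(cs ::: rest, acc)
  }
  loop(List(root), 0)
}

// Hypothetical helper: builds a chain of nested branches iteratively,
// so that constructing the test tree does not use deep recursion either.
def deepChain(depth: Int): Node =
  (1 to depth).foldLeft(LeafNode("leaf"): Node) { (child, _) =>
    BranchNode("branch", List(child))
  }

println(countLeaves(deepChain(100000))) // prints 1
```

mapLeaves needs the richer Frame because each branch has to be reassembled from its mapped children once they are all done; a fold like this one can get by with an accumulator.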
Example usage of mapLeaves:
val example = BranchNode("x", List(
  BranchNode("y", List(
    LeafNode("a"),
    LeafNode("b")
  )),
  BranchNode("z", List(
    LeafNode("c"),
    BranchNode("v", List(
      LeafNode("d"),
      LeafNode("e")
    ))
  ))
))

println(mapLeaves(example, { case LeafNode(n) => LeafNode(n.toUpperCase) }))
Output (indented):
BranchNode(x,List(
  BranchNode(y,List(
    LeafNode(A),
    LeafNode(B)
  )),
  BranchNode(z,List(
    LeafNode(C),
    BranchNode(v,List(
      LeafNode(D),
      LeafNode(E)
    ))
  ))
))
It might be easier to implement this using a technique called trampolining.
With a trampoline, you can use two functions that call each other (mutual recursion); with tailrec, you are limited to a single self-recursive function. As with tailrec, the recursion is executed as a plain loop instead of growing the call stack.
Trampolines are implemented in the Scala standard library in scala.util.control.TailCalls.
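To show the mechanics in isolation, here is a minimal sketch of two mutually recursive functions run on a trampoline. The object Parity and its methods isEven and isOdd are made up for illustration; only TailRec, done, tailcall and result come from the standard library.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// Hypothetical example object, used only to keep the two
// mutually recursive definitions in one scope.
object Parity {
  // done(...) wraps a finished result,
  // tailcall(...) suspends the next call instead of performing it right away
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))
}

// result runs the suspended calls one after another in a loop
println(Parity.isEven(1000000).result) // prints true
```

Written as plain mutual recursion, a million nested calls would be expected to overflow the default JVM stack. The tree version that follows uses the same three pieces: done, tailcall and result.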
import scala.util.control.TailCalls.{TailRec, done, tailcall}

def mapLeaves(root: Node, f: LeafNode => LeafNode): Node = {
  // two inner functions doing mutual recursion

  // maps a list of sibling nodes, one at a time
  def iterate(nodes: List[Node]): TailRec[List[Node]] = {
    nodes match {
      case x :: xs =>
        tailcall(deepMap(x)) // mutual recursion: deepMap maps x and everything below it
          .flatMap(node => iterate(xs).map(node :: _)) // TailRec supports flatMap
      case Nil => done(Nil)
    }
  }

  // recursively visits all branches
  def deepMap(node: Node): TailRec[Node] = {
    node match {
      case ln: LeafNode => done(f(ln))
      case bn: BranchNode =>
        tailcall(iterate(bn.children)) // mutual recursion: back into iterate
          .map(BranchNode(bn.name, _))
    }
  }

  deepMap(root).result // unwrap the result to a plain Node
}
Instead of TailCalls you could also use Eval from Cats or Trampoline from scalaz.
With that implementation, the function handled a deeply nested tree without problems:
def build(counter: Int): Node = {
  if (counter > 0) {
    BranchNode("branch", List(build(counter-1)))
  } else {
    LeafNode("leaf")
  }
}

val root = build(4000)
mapLeaves(root, x => x.copy(name = x.name.reverse)) // no problems
When I ran that example with your implementation, it caused a java.lang.StackOverflowError, as expected.