Tail-recursive fold on a binary tree in Scala

I am trying to write a tail-recursive fold function for a binary tree, given the following definitions:

```scala
// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```

The non-tail-recursive implementation is quite simple:

```scala
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = t match {
  case Leaf(v)      => map(v)
  case Branch(l, r) => red(fold(l)(map)(red), fold(r)(map)(red))
}
```
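
For comparison, here is a small sketch of what this recursive fold already makes easy. The helpers `leafCount`, `depth` and `render` are hypothetical names chosen for illustration, not taken from the book. Both recursive calls in the `Branch` case are arguments to `red`, so neither is in tail position, which is why `@annotation.tailrec` rejects this version.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Non-tail-recursive fold from the question: the results of the two
// recursive calls are passed to `red`, so the JVM stack grows with tree depth.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = t match {
  case Leaf(v)      => map(v)
  case Branch(l, r) => red(fold(l)(map)(red), fold(r)(map)(red))
}

// Hypothetical helpers expressed through fold.
def leafCount[A](t: Tree[A]): Int = fold(t)(_ => 1)(_ + _)
def depth[A](t: Tree[A]): Int = fold(t)(_ => 0)((dl, dr) => 1 + (dl max dr))
// A non-commutative reducer: the output reveals the order and shape of reductions.
def render[A](t: Tree[A]): String = fold(t)(_.toString)((l, r) => s"($l,$r)")

val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
println(leafCount(left)) // 3
println(depth(left))     // 2
println(render(left))    // ((1,2),3)
```

Because `render` is not commutative, swapping the arguments of its reducer would print `(3,(2,1))` instead, which is exactly the kind of ordering a general tail-recursive fold has to preserve.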

But now I'm struggling to write a tail-recursive version of fold that compiles with `@annotation.tailrec`.

During my research, I found several examples of tail-recursive tree functions that, for instance, sum all leaves by maintaining their own stack, which is essentially a `List[Tree[Int]]`. As I understand it, though, that only works for something like addition, where it does not matter whether the left or the right operand is evaluated first. For a generalized fold, the order matters a great deal. To illustrate what I mean, here are some example trees:

```scala
val leafs = Branch(Leaf(1), Leaf(2))
val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal   = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb   = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
```
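
The explicit-stack approach for summing leaves can be sketched roughly as follows, assuming the same `Tree` definitions; `sumLeaves` and `leavesInOrder` are hypothetical helper names. The sketch also shows why it does not generalize directly: `left` and `right` yield the same sequence of leaves, so a stack of pending subtrees alone loses the information needed to rebuild their shape.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Tail-recursive leaf sum with an explicit stack of pending subtrees.
// The order in which subtrees are popped is irrelevant because + is
// associative and commutative.
def sumLeaves(t: Tree[Int]): Int = {
  @tailrec
  def loop(stack: List[Tree[Int]], acc: Int): Int = stack match {
    case Nil                  => acc
    case Leaf(v) :: rest      => loop(rest, acc + v)
    case Branch(l, r) :: rest => loop(l :: r :: rest, acc)
  }
  loop(List(t), 0)
}

// Same traversal, collecting leaves left to right. It forgets where each
// Branch started and ended, so different shapes give the same result.
def leavesInOrder[A](t: Tree[A]): List[A] = {
  @tailrec
  def loop(stack: List[Tree[A]], acc: List[A]): List[A] = stack match {
    case Nil                  => acc.reverse
    case Leaf(v) :: rest      => loop(rest, v :: acc)
    case Branch(l, r) :: rest => loop(l :: r :: rest, acc)
  }
  loop(List(t), Nil)
}

val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))

println(sumLeaves(left))      // 6
println(leavesInOrder(left))  // List(1, 2, 3)
println(leavesInOrder(right)) // List(1, 2, 3)
```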

Based on these trees, I want to create a deep copy with the given fold method, for example:

```scala
val oldNewPairs = trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))
```

And then check that every copy is equal to its original:

```scala
val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)
```

Can someone give me some pointers please?

You can find the code used in this question in ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15

1 answer

You can get a tail-recursive solution if you stop relying on the function call stack and instead use a stack managed by your own code, together with an accumulator:

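```scala
import scala.annotation.tailrec

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  // Marker pushed after a branch's children. Because Tree is sealed,
  // this must live in the same source file as Tree.
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(toVisit.tail, acc :+ leafRes)
        case Branch(l, r) =>
          // Visit left, then right, then combine them at the stub.
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // The last two results belong to this branch's left and right child.
          foldImp(toVisit.tail, acc.dropRight(2) ++ Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t :: Nil, Vector.empty).head
}
```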

The idea is to accumulate leaf results from left to right, keep track of the parent-child relationship by pushing a stub node (`BranchStub`) after each branch's children, and, whenever a stub is popped during the traversal, reduce the last two accumulator entries with your `red` function.

This solution can be optimized, but it is already a tail-recursive implementation.
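
To see the stub mechanism in action, the sketch below reruns the same Vector-based algorithm but also records the accumulator after every step that changes it. `foldTrace` is a hypothetical name used only for illustration; the traversal logic mirrors the fold above.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Same algorithm as the Vector-based fold, returning the result together
// with every intermediate accumulator (oldest first).
def foldTrace[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): (B, List[Vector[B]]) = {
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def loop(toVisit: List[Tree[A]], acc: Vector[B], trace: List[Vector[B]]): (Vector[B], List[Vector[B]]) =
    toVisit match {
      case Nil => (acc, trace.reverse)
      case Leaf(v) :: rest =>
        val next = acc :+ map(v)
        loop(rest, next, next :: trace)
      case Branch(l, r) :: rest =>
        // No accumulator change here, so nothing is recorded.
        loop(l :: r :: BranchStub :: rest, acc, trace)
      case BranchStub :: rest =>
        val next = acc.dropRight(2) :+ acc.takeRight(2).reduce(red)
        loop(rest, next, next :: trace)
    }

  val (acc, trace) = loop(t :: Nil, Vector.empty, Nil)
  (acc.head, trace)
}

val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val (result, steps) = foldTrace(left)(_.toString)((a, b) => s"($a,$b)")
println(result) // ((1,2),3)
steps.foreach(println)
```

The recorded steps for `left` grow as `1`, then `1, 2`, collapse to `(1,2)` at the first stub, grow again with `3`, and collapse to `((1,2),3)` at the second stub, so the reducer always sees the left child's result before the right child's.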

EDIT:

It can be simplified a little by changing the accumulator's data structure to a `List` used as a stack:

```scala
import scala.annotation.tailrec

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(toVisit.tail, map(v) :: acc)
        case Branch(l, r) =>
          // Right is pushed first, so the left result ends up on top of acc.
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          // acc.take(2) is List(leftResult, rightResult).
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t :: Nil, Nil).head
}
```
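
A rough way to check both things the question cares about, stack safety and a faithful deep copy, is sketched below. It repeats the List-based fold so it runs on its own; `leftComb` is a hypothetical helper that builds a left-leaning tree with 100,000 branches using a loop. At that depth the original recursive fold is likely to throw a `StackOverflowError` on a default JVM thread stack, while this version keeps its pending work on the heap.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// List-based tail-recursive fold from the answer's edit.
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  case object BranchStub extends Tree[Nothing]
  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if (toVisit.isEmpty) acc
    else toVisit.head match {
      case Leaf(v)      => foldImp(toVisit.tail, map(v) :: acc)
      case Branch(l, r) => foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
      case BranchStub   => foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
    }
  foldImp(t :: Nil, Nil).head
}

// Hypothetical helper: Branch(...Branch(Branch(Leaf(0), Leaf(1)), Leaf(2))..., Leaf(n)),
// built with foldLeft so that construction itself does not recurse.
def leftComb(n: Int): Tree[Int] =
  (1 to n).foldLeft(Leaf(0): Tree[Int])((acc, i) => Branch(acc, Leaf(i)))

val deep = leftComb(100000)
val leafCount = fold(deep)(_ => 1)(_ + _)
val leafSum = fold(deep)(_.toLong)(_ + _)
println(leafCount) // 100001
println(leafSum)   // 5000050000

// Deep-copy check on the question's small trees.
val leafs = Branch(Leaf(1), Leaf(2))
val left  = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal   = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb   = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
val copiesEqual = trees.forall(t => fold[Int, Tree[Int]](t)(Leaf(_))(Branch(_, _)) == t)
println("Condition holds: " + copiesEqual)
```

Avoid calling `toString`, `==` or `hashCode` on `deep`: the versions the compiler generates for case classes are themselves recursive and would overflow the stack even though the fold does not.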
