I am trying to write a tail-recursive fold function for a binary tree, given the following definitions:
```scala
// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```
The non-tail-recursive implementation is straightforward:
```scala
def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = t match {
  case Leaf(v)      => map(v)
  case Branch(l, r) => red(fold(l)(map)(red), fold(r)(map)(red))
}
```
But now I'm struggling to write a tail-recursive version of `fold` so that `@annotation.tailrec` can be applied to it.
During my research, I found several examples of tail-recursive functions on trees that, for instance, compute the sum of all leaves by maintaining their own stack, which is essentially a `List[Tree[Int]]`. But as I understand it, this approach only works for operations like addition, where it doesn't matter whether you evaluate the left or the right operand first. For a generalized fold, that order is very relevant. To illustrate what I mean, here are some example trees:
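For reference, the explicit-stack technique described above might look roughly like the following sketch, assuming the `Tree` definitions from the question (repeated so the snippet is self-contained, in the same script style as the fiddle). The names `sumLeaves` and `loop` are illustrative and not taken from any of the examples mentioned. Note that this version does visit leaves left to right (`l :: r :: rest`), but it folds every leaf into one accumulator and so discards how the leaves were grouped; that is only harmless because `+` is both associative and commutative.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Sums all leaf values, using a List[Tree[Int]] as an explicit stack of
// subtrees still to visit instead of relying on the JVM call stack.
def sumLeaves(t: Tree[Int]): Int = {
  @tailrec
  def loop(stack: List[Tree[Int]], acc: Int): Int = stack match {
    case Nil                  => acc
    case Leaf(v) :: rest      => loop(rest, acc + v)
    // Replace a branch by its two children; left is processed first.
    case Branch(l, r) :: rest => loop(l :: r :: rest, acc)
  }
  loop(List(t), 0)
}

val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
println(sumLeaves(bal)) // prints 10
```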
```scala
val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)
```
Using these trees, I want to create a deep copy of each one with the `fold` method, for example:
```scala
val oldNewPairs = trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))
```
Then I check that every copy is equal to its original:
```scala
val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)
```
Can someone give me some pointers please?
You can find the code used in this question in ScalaFiddle: https://scalafiddle.io/sf/eSKJyp2/15
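One possible direction, sketched here under assumptions rather than as a definitive answer: keep two explicit lists, one of pending work and one of already-computed results, and add a marker meaning "combine the two most recent results". `Step`, `Visit` and `Combine` are hypothetical helper types introduced only for this sketch. Because a branch schedules its left subtree before its right subtree, and each subtree leaves exactly one result behind, the right result is on top of the result list when `Combine` is reached, so `red` receives its arguments in the original order and with the original grouping.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper type: one unit of pending work.
sealed trait Step[+A]
// Fold this subtree and push its result onto the result list.
case class Visit[A](tree: Tree[A]) extends Step[A]
// Pop the two most recent results and push red(left, right).
case object Combine extends Step[Nothing]

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {
  @tailrec
  def loop(todo: List[Step[A]], done: List[B]): B = todo match {
    // No work left: exactly one result remains, the fold of t.
    case Nil => done.head
    case Visit(Leaf(v)) :: rest =>
      loop(rest, map(v) :: done)
    case Visit(Branch(l, r)) :: rest =>
      // Left is scheduled before right, and Combine after both.
      loop(Visit(l) :: Visit(r) :: Combine :: rest, done)
    case Combine :: rest =>
      done match {
        // The right subtree's result was pushed last, so it is on top.
        case rightResult :: leftResult :: others =>
          loop(rest, red(leftResult, rightResult) :: others)
        case _ =>
          throw new IllegalStateException("fewer than two results to combine")
      }
  }
  loop(List(Visit(t)), Nil)
}

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

// Deep copy via fold, as in the question, then compare with the originals.
val conditionHolds =
  trees.forall(t => fold(t)(v => Leaf(v): Tree[Int])((l, r) => Branch(l, r)) == t)
println("Condition holds: " + conditionHolds) // prints "Condition holds: true"
```

The same idea should work for non-commutative reducers too, for example `fold(left)(_.toString)((a, b) => s"($a $b)")` yields `((1 2) 3)`, matching the non-tail-recursive version. The trade-off is that the two lists live on the heap, so very deep trees cost memory instead of JVM stack frames.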