Inside F#

Catamorphisms, part six

Posted by Brian on June 2, 2008

Oops!… I did it again.

I completely botched a key aspect of my previous blog entry.  Fortunately I have some alert readers who are keeping me honest.  Whereas last time I had to correct a blunder about run-time performance, this time I have to correct my implementation because I failed to make it properly tail-recursive.  It’s a learning opportunity, both for you and for me!

The problem

Last time I showed this code for KFoldTree and Change5to0bst:

// First (broken) attempt at a continuation-passing fold over Tree.
// nodeF receives the node value, a "left" and a "right" continuation-taker,
// and the node itself (t) - the client lambda below has the shape
// (fun x kl kr t -> ...), so t must be passed here just as it is for leafV.
// NOTE: the Loop calls inside "k (Loop left)" / "k (Loop right)" are NOT
// tail calls - k can only run after Loop returns - which is the bug this
// post is about.
let KFoldTree nodeF leafV tree =
    let rec Loop t =
        match t with
        | Node(x,left,right) -> nodeF x (fun k -> k (Loop left)) (fun k -> k (Loop right)) t
        | Leaf -> leafV t
    Loop tree

// Change5to0bst : Tree<int> -> Tree<int>
// Client of the first KFoldTree: walks the search path for 5 and replaces
// a node holding 5 with one holding 0, reusing the untouched subtrees.
// Every call written here is a tail call; the non-tail calls hide inside
// KFoldTree's continuations kl/kr.
let Change5to0bst tree =
    KFoldTree (fun x kl kr t ->
        let (Node(_,oldL,oldR)) = t
        if x < 5 then
            kr (fun newR ->
                Node (x, oldL, newR))
        elif x > 5 then
            kl (fun newL ->
                Node (x, newL, oldR))
        else
            Node(0,oldL,oldR)
    ) (fun t -> t) tree

and I looked very carefully at Change5to0bst to ensure that every call was a tail call.  It is.  The problem is, the "Loop" calls in KFoldTree are not tail calls!  For example:

(fun k -> k (Loop left))

Here, we will make a recursive call to "Loop", but when that call returns, there is still "more work to do" (we must pass that result to "k").  Thus, the call to Loop is not a tail call here.  Oops!  The implication is that we must keep a stack frame alive for the duration of each recursive call.  And so if we write…

// CreateZeroRightTree : int -> Tree<int>
// Builds a chain of `size` nodes, each holding 0, linked through the right child.
let CreateZeroRightTree size =
    let rec Loop t n =
        if (n < size) then
            Loop (Node(0,Leaf,t)) (n+1)
        else
            t
    Loop Leaf 0
// make a big tree of 2 million nodes all going to the right
let bigTree = CreateZeroRightTree (2 * 1000 * 1000)
// call our supposedly-tail-recursive function on it
// (with the first KFoldTree this ends in a StackOverflowException)
Change5to0bst bigTree

…sure enough – kaboom!  StackOverflowException.  Clearly I failed to test my code from my previous blog entry.

The fix

The fix involves explicitly passing the continuations throughout the computation.  The definition of KFold actually gets a little simpler, though the client code becomes slightly more complicated.  Here’s the new KFold:

// Corrected continuation-passing fold over Tree.
// Loop now takes the continuation k explicitly and hands it to nodeF/leafV.
// (Loop left) and (Loop right) are partial applications, not calls: the
// client decides when to invoke them, always in tail position.
let KFoldTree nodeF leafV tree =
    let rec Loop t k =
        match t with
        | Node(x,left,right) -> nodeF x (Loop left) (Loop right) t k
        | Leaf -> leafV t k
    // start with the identity continuation so the final result is returned as-is
    Loop tree (fun x -> x)

Relative to the previous (broken) version, "Loop" takes an extra continuation parameter and passes it as an extra argument to the client functions ("nodeF" and "leafV").  Note that the "Loop" calls got simpler, since for example

(fun k -> Loop left k)

can just be written as

(Loop left)

thanks to currying.

Here is how this affects the client:

// Change5to0bst written against the corrected KFoldTree: both client
// lambdas take a final continuation k and pass their result to it
// instead of returning it.
let Change5to0bst tree =
    KFoldTree (fun x kl kr t k ->
        let (Node(_,oldL,oldR)) = t
        if x < 5 then
            kr (fun newR ->
                k (Node(x, oldL, newR)))
        elif x > 5 then
            kl (fun newL ->
                k (Node(x, newL, oldR)))
        else
            k (Node(0,oldL,oldR))
    ) (fun t k -> k t) tree

The two client lambdas now take an extra final parameter "k", and everywhere that the client used to "return a final value", now we are calling the continuation "k" on that final value.  Apart from that, the code is otherwise unchanged.

The lesson

Tail recursion is subtle – especially when dealing with mutually recursive functions/lambdas.  If you are going to try to be tail-recursive, test your code on large inputs to ensure you got it right!  I have not tested the new "Eval" function (see below); I'll leave that as a good exercise for you – create some large data and find out whether I got the definitions of KFoldExpr and Eval right!

Other bits

Again, we can express XFold in terms of the new (corrected) KFold, and we can generalize this new KFold to other discriminated union types (like Expr).  For examples, see today’s source code.

Sorry that today’s blog entry is so short on prose, but I didn’t originally intend to spend time writing a blog entry today.  The error in my previous blog entry was sufficiently grievous, though, that I felt compelled to correct it immediately.

The source code

open System

// handy operator
// handy operator: true exactly when x and y are the same object (reference identity)
let (===) x y = Object.ReferenceEquals(x, y)

/// A binary tree carrying a value of type 'a at each Node.
/// Node is (data, left subtree, right subtree); Leaf is the empty tree.
/// (Type variables use an ASCII apostrophe - typographic quotes do not compile.)
type Tree<'a> =
    | Node of (*data*)'a * (*left*)Tree<'a> * (*right*)Tree<'a>
    | Leaf

// Balanced binary search tree holding 1..7:
//     4
//  2     6
// 1 3   5 7
// The comma after the left subtree is required: Node takes a 3-tuple
// (data, left, right), and without it the right subtree is not an argument.
let tree7 = Node(4, Node(2, Node(1, Leaf, Leaf), Node(3, Leaf, Leaf)),
                    Node(6, Node(5, Leaf, Leaf), Node(7, Leaf, Leaf)))

/// Continuation-passing fold over a Tree.
/// nodeF gets: the node's value, the left and right sub-folds (each still
/// waiting for a continuation), the node itself, and the current continuation.
/// leafV gets: the leaf and the current continuation.
/// The outermost continuation is the identity, so the fold's result is returned unchanged.
let KFoldTree nodeF leafV tree =
    let rec go node cont =
        match node with
        | Leaf -> leafV node cont
        | Node(value, lhs, rhs) -> nodeF value (go lhs) (go rhs) node cont
    go tree id

/// Change5to0bst : Tree<int> -> Tree<int>
/// Treats the tree as a binary search tree and follows the path where 5 would
/// live. If a node holding 5 is found, it is rebuilt holding 0. Nodes along
/// the path are rebuilt; subtrees off the path are reused as-is.
/// Written in continuation-passing style on top of KFoldTree so that every
/// call is a tail call and deep trees do not exhaust the stack.
let Change5to0bst tree =
    KFoldTree (fun x kl kr t k ->
        // KFoldTree only invokes this lambda for Node values, so the Leaf case
        // is unreachable. Matching explicitly (instead of an irrefutable
        // `let Node(...) = t`) avoids the incomplete-pattern warning FS0025.
        match t with
        | Node(_, oldL, oldR) ->
            if x < 5 then
                kr (fun newR -> k (Node(x, oldL, newR)))
            elif x > 5 then
                kl (fun newL -> k (Node(x, newL, oldR)))
            else
                k (Node(0, oldL, oldR))
        | Leaf -> invalidArg "t" "Change5to0bst: node function was applied to a Leaf"
    ) (fun t k -> k t) tree

/// CreateZeroRightTree : int -> Tree<int>
/// Builds a right-leaning chain of `size` nodes, every one holding 0.
let CreateZeroRightTree size =
    let rec build acc count =
        if count >= size then acc
        else build (Node(0, Leaf, acc)) (count + 1)
    build Leaf 0

// a deliberately degenerate tree: 2 million nodes, each hanging off its parent's right side
let bigTree = CreateZeroRightTree 2000000
// run the continuation-passing Change5to0bst over it; it should complete without a StackOverflowException
Change5to0bst bigTree

// XFoldTree : ('a -> 'r -> 'r -> Tree<'a> -> 'r) -> (Tree<'a> -> 'r) -> Tree<'a> -> 'r
// Direct-style fold built on KFoldTree: folds the left subtree, then the right,
// then hands both results (plus the original node) to nodeF.
let XFoldTree nodeF leafV tree =
    let onNode x foldLeft foldRight node k =
        foldLeft (fun lacc ->
            foldRight (fun racc ->
                k (nodeF x lacc racc node)))
    let onLeaf leaf k = k (leafV leaf)
    KFoldTree onNode onLeaf tree

// Other useful Tree boilerplate from previous blogs

/// Produces Node(x, l, r), but hands back `orig` itself when its value is equal
/// to x and both children are the very same objects as l and r, so unchanged
/// nodes are shared rather than reallocated. `orig` is expected to be a Node.
let XNode (x, l, r) orig =
    match orig with
    | Node(xo, lo, ro) when xo = x && lo === l && ro === r -> orig
    | Node _ -> Node(x, l, r)
/// Returns the given Leaf unchanged (the argument is expected to be a Leaf).
let XLeaf orig =
    match orig with
    | Leaf -> orig
/// Ordinary fold over a Tree, built on XFoldTree by ignoring the original node
/// passed to each handler.
let FoldTree nodeF leafV tree =
    let onNode x lacc racc _ = nodeF x lacc racc
    let onLeaf _ = leafV
    XFoldTree onNode onLeaf tree

// another example to suggest that the XFold written in terms of the KFold is also still tail-recursive
/// Replaces every 5 in the tree with 0; nodes that do not change are reused via XNode/XLeaf.
let XChange5to0 tree =
    let swap x = if x = 5 then 0 else x
    XFoldTree (fun x l r orig -> XNode (swap x, l, r) orig) XLeaf tree
XChange5to0 bigTree  // no StackOverflowException

///////////////////////////////////////////////////////////////////////////////////

// types capable of representing a small integer expression language

/// Binary operators of the language.
type Op = Plus | Minus

/// Expressions of the language.
type Expr =
    /// an integer constant
    | Literal of int
    /// left, op, right
    | BinaryOp of Expr * Op * Expr
    /// cond, then, else; a cond evaluating to 0 counts as false
    | IfThenElse of Expr * Expr * Expr
    /// prints the inner expression's value, then returns that value
    | Print of Expr

/// Sample expressions used to exercise Eval below.
let exprs =
    [ Literal 42;
      BinaryOp (Literal 1, Plus, Literal 1);
      IfThenElse (Literal 1, Print (Literal 42), Print (Literal 0)) ]

/// Continuation-passing fold over Expr. Each handler receives the sub-folds for
/// the expression's children (each still waiting for a continuation), the
/// expression itself, and the current continuation; litF receives the literal's
/// int instead of sub-folds. The outermost continuation is the identity.
let KFoldExpr litF binF ifF printF expr =
    let rec go e cont =
        match e with
        | Literal n -> litF n e cont
        | BinaryOp(lhs, op, rhs) -> binF (go lhs) op (go rhs) e cont
        | IfThenElse(cond, thenE, elseE) -> ifF (go cond) (go thenE) (go elseE) e cont
        | Print inner -> printF (go inner) e cont
    go expr id

/// Eval : Expr -> int
/// Evaluates an expression in continuation-passing style via KFoldExpr.
/// Binary operands are evaluated left then right. IfThenElse evaluates only the
/// chosen branch, treating a 0 condition as false. Print writes "<value>" to
/// stdout and then yields the value.
let Eval expr =
    let evalLiteral n _ k = k n
    let evalBinary evalL op evalR _ k =
        evalL (fun l ->
            evalR (fun r ->
                match op with
                | Plus -> k (l + r)
                | Minus -> k (l - r)))
    let evalIf evalC evalT evalE _ k =
        evalC (fun c -> if c = 0 then evalE k else evalT k)
    let evalPrint evalInner _ k =
        evalInner (fun v ->
            printf "<%d>" v
            k v)
    KFoldExpr evalLiteral evalBinary evalIf evalPrint expr

for expr in exprs do
    printfn "%d" (Eval expr)
// 42
// 2
// <42>42