If you recall, a while back when I was demonstrating some Functional Data Structures I mentioned that some of the functions were not tail-recursive, and that this is something we would probably want to do something about. Which raises the question: how exactly do we go about making a function tail-recursive? I am going to attempt to address that question here.

One of the first problems with creating a tail recursive function is figuring out whether a function is tail recursive in the first place. Sadly this isn’t something that is always obvious. There has been some discussion about generating a compiler warning if a function is not tail recursive, which sounds like a dandy idea since the compiler knows enough to know how to optimize tail recursive functions for us. But we don’t have that yet, so we’re going to have to try and figure it out on our own. So here are some things to look for:

  1. When is the recursive call made? And more importantly, is there anything that happens after it? Even something simple like adding a number to the result of the function can cause a function to not be tail-recursive.
  2. Are there multiple recursive calls? This sort of thing happens quite a bit when processing tree-like data structures. If you need to apply a function recursively to 2 sub-sets of elements and then combine the results, chances are the function is not tail-recursive.
  3. Is there any exception handling in the body of the function? This includes use and using declarations. Since there are multiple possible return paths, the compiler can’t turn the recursive call into a tail call.
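To make the checklist above a bit more concrete, here is a small sketch with one function per item. None of these come from the earlier post; the names (countItems, Shape, countLeaves, countDown) are made up purely for illustration, and each one is deliberately NOT tail-recursive.

```fsharp
// 1. Work after the recursive call: the "1 +" runs only once the call
//    returns, so every level of recursion needs its own stack frame.
let rec countItems items =
    match items with
    | [] -> 0
    | _ :: rest -> 1 + countItems rest

// 2. Multiple recursive calls: both results still have to be added
//    together, so neither call is in tail position.
type Shape =
    | Leaf
    | Branch of Shape * Shape

let rec countLeaves shape =
    match shape with
    | Leaf -> 1
    | Branch(l, r) -> countLeaves l + countLeaves r

// 3. Exception handling: the handler has to stay in place until the
//    recursive call returns, so the call inside "try" is not a tail call.
let rec countDown n =
    try
        if n = 0 then 0 else countDown (n - 1)
    with _ -> -1
```

All three work fine on small inputs; the trouble only shows up once the recursion gets deep enough to exhaust the stack.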

Now that we have a chance of identifying non-tail-recursive functions, let’s take a look at how to make a function tail-recursive. There may be cases where, for various reasons, it’s not possible to make a function tail-recursive, but it is worth trying because a StackOverflowException cannot be caught and will cause the process to exit regardless (yes, I know this from previous experience :)).

Accumulators

One of the primary ways of making a function tail-recursive is to provide some kind of accumulator as one of the function parameters, so that you can build up the final result and pass it along using the accumulator. When the recursion is complete, you return the accumulator. A very simple example of this would be creating a sum function on a list of ints. A simple non-tail-recursive version of this function might look like:

let rec sum (items:int list) =
    match items with
    | [] -> 0
    | i::rest -> i + (sum rest)

And making it tail-recursive by using an accumulator would look like this:

let rec sum (items:int list) acc =
    match items with
    | [] -> acc
    | i::rest -> sum rest (i+acc)
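One common way to use this (a sketch, not something from the original post) is to hide the accumulator inside a local helper so callers don't have to remember to pass 0. The helper name loop is just a convention I'm assuming here.

```fsharp
// Public function keeps the simple signature; the accumulator lives
// in the inner, tail-recursive helper.
let sum (items: int list) =
    let rec loop remaining acc =
        match remaining with
        | [] -> acc
        | i :: rest -> loop rest (i + acc) // nothing left to do after this call
    loop items 0

let small = sum [1; 2; 3]

// A list long enough that a non-tail-recursive version would be at risk
// of overflowing the stack.
let large = sum (List.replicate 1000000 1)
```

Because the recursive call is the very last thing loop does, the compiler can turn it into a plain loop, so the size of the list no longer matters for stack usage.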

 

Continuations

If you can’t use an accumulator as part of your function, another (reasonably) simple approach is to use a continuation function. Basically the approach here is to take the work you would be doing after the recursive call and put it into a function that gets passed along and executed when the recursive call is complete. For an example, we’re going to use the insert function from my Functional Data Structures post. Here is the original function:

let rec insert value tree =
    match (value,tree) with
    | (_,Empty) -> Tree(Empty,value,Empty)
    | (v,Tree(a,y,b)) as s ->
        if v < y then
            Tree(insert v a,y,b)
        elif v > y then
            Tree(a,y,insert v b)
        else snd s

This is slightly more tricky since we need to build a new tree with the result, and the position of the result will also vary. So let’s add a continuation function to this and see what changes:

let rec insert value tree cont =
    match (value,tree) with
    | (_,Empty) -> Tree(Empty,value,Empty) |> cont
    | (v,Tree(a,y,b)) as s ->
        if v < y then
            insert v a (fun t -> Tree(t,y,b) |> cont)
        elif v > y then
            insert v b (fun t -> Tree(a,y,t) |> cont)
        else snd s |> cont

For the initial call of this function you’ll want to pass in the built-in id function, which just returns whatever you pass to it. As you can see the function is a little more involved, but still reasonably easy to follow. The key is that the recursive call is now the last thing the function does: the work that used to happen afterwards (building the new Tree node) moves into the new continuation, which finishes by calling the continuation it was handed. If you instead pipe the result of the recursive call into cont, there is still work waiting after the call returns, and the function is right back to not being tail-recursive.
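Here is a self-contained sketch of how calling a continuation-based insert might look. The Tree type below is my assumption about the shape used in the Functional Data Structures post, and insertK and toList are hypothetical names used only for this example.

```fsharp
// Assumed shape of the tree type from the earlier post.
type Tree<'a> =
    | Empty
    | Tree of Tree<'a> * 'a * Tree<'a>

// Continuation-passing insert: each recursive call is the last thing done,
// and the "rebuild the node" work is deferred into the continuation.
let rec insertK value tree cont =
    match tree with
    | Empty -> cont (Tree(Empty, value, Empty))
    | Tree(a, y, b) ->
        if value < y then insertK value a (fun t -> cont (Tree(t, y, b)))
        elif value > y then insertK value b (fun t -> cont (Tree(a, y, t)))
        else cont tree

// Callers start the chain with id, which hands the finished tree straight back.
let insert value tree = insertK value tree id

// In-order traversal, only here to inspect the result
// (this helper is itself not tail-recursive).
let rec toList tree =
    match tree with
    | Empty -> []
    | Tree(a, y, b) -> toList a @ [y] @ toList b

let sample = [5; 3; 8; 1; 4] |> List.fold (fun acc v -> insert v acc) Empty
let inOrder = toList sample
```

One caveat worth hedging: whether the calls made through the continuation are really emitted as tail calls depends on compiler settings (tail calls are typically switched off in Debug builds), so it's worth testing deep inputs in the configuration you actually ship.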

These two techniques are the primary means of converting a non-tail-recursive function to a tail-recursive function. There is also a more generalized technique known as a “trampoline” which can also be used to eliminate the accumulation of stack frames (among other things). I’ll leave that as a topic for another day, though.

Another thing worth pointing out is that the built-in fold functions available in the F# standard library are already tail-recursive. So if you make use of fold you don’t have to worry about how to make your function tail-recursive. Yet another reason to make fold your go-to function.

Let’s say that you’ve been working hard on this really awesome data structure. It’s fast, it’s space efficient, it’s immutable, it’s everything anyone could dream of in a data structure. But you only have time to implement one function for processing the data in your new miracle structure, so what would it be?

Ok, not a terribly realistic scenario, but bear with me here, there is a point to this. The answer to this question, of course, is that you would implement fold. Why, you might ask? Because if you have a fold implementation then it is possible to implement just about any other function you want in terms of fold. Don’t believe me? Well, I’ll show you, and in showing you I’ll also demonstrate how finding the right abstraction in a functional language can reduce the size and complexity of your codebase in amazing ways.

Now, to get started, let’s take a look at what exactly the fold function is:

val fold: folder:('State -> 'T -> 'State) -> state:'State -> list:'T list -> 'State

In simple terms it iterates over the items in the structure and applies a function to each one. That function takes the accumulated state so far plus the current element, and returns the new state. Ok, maybe that didn’t come through quite as simply as I would have hoped. So how about we start with a pretty straightforward example: sum.

 
let sum (list:int list) = List.fold (fun s i -> s + i) 0 list

Here we are folding over a list of integers, but in theory the data structure could be just about anything. Each item in the list gets added to the previous total. The first item is added to the initial value passed in to the fold, so for items [1;2;3] we start by adding 1 to 0, then 2 to 1, then 3 to 3, and the result is 6. We could even get kinda crazy with the point-free style and use the fact that the + operator is a function which takes two arguments and returns their sum…which happens to exactly match our folding function.

let sum (list:int list) = List.fold (+) 0 list
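If it helps to see the mechanics, here is a hand-rolled fold written as a plain tail-recursive function. This is only a sketch of the behavior (myFold is a made-up name, and I'm not claiming this is how the library implements List.fold internally).

```fsharp
// The state plays the role of the accumulator: each step feeds the
// current state and the next item to the folder, then moves on.
let rec myFold folder state list =
    match list with
    | [] -> state
    | item :: rest -> myFold folder (folder state item) rest

let total = myFold (+) 0 [1; 2; 3]

// Keep every intermediate total to make the 0 -> 1 -> 3 -> 6 steps visible.
let steps =
    myFold (fun acc i -> (List.head acc + i) :: acc) [0] [1; 2; 3]
    |> List.rev
```

Here steps records the running totals in the order they were produced, which matches the walkthrough above.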

So that’s pretty cool right? Now it seems like you could also very easily create a Max function for your structure by using the built-in max operator, or a Min function using the min operator.

let max (list:int list) = List.fold (max) (System.Int32.MinValue) list
let min (list:int list) = List.fold (min) (System.Int32.MaxValue) list

But I did say that you could create any other processing function, right? So how about something a little trickier, like Map? It may not be quite as obvious, but the implementation is actually equally simple. First let’s take a quick look at the signature of the map function to refresh our memories:

val map: mapping:('T -> 'U) -> list:'T list -> 'U list

So how do we implement that in terms of fold? Again, we’ll use List because it’s simple enough to see what goes on internally:

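let map (mapping:'a -> 'b) (list:'a list) = List.fold (fun l i -> mapping i::l) [] list |> List.rev

Pretty cool right? Use the Cons operator (::) and a mapping function with an initial value of an empty list. The one wrinkle is that fold walks the list front to back while Cons adds to the front, so the result comes out backwards and needs a final List.rev to restore the original order. So that’s pretty fun, how about another classic like filter? Also pretty similar: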

let filter (pred:'a -> bool) (list:'a list) = List.fold (fun l i -> if pred i then i::l else l) [] list |> List.rev

Now we’re on a roll, so how about the choose function (like map, only the mapping function returns an Option and any None value gets left out)? No problem.

let choose (chooser:'a -> 'b option) (list:'a list) = List.fold (fun l i -> match chooser i with | Some v -> v::l | _ -> l) [] list |> List.rev
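Something to watch with list-building folds like map, filter and choose: consing onto the accumulator in a left-to-right fold produces the elements in reverse order. An alternative sketch is to use List.foldBack, which processes the list from the right so the output keeps its original order. The names mapBack, filterBack and chooseBack are hypothetical, chosen just for this example.

```fsharp
// List.foldBack takes the folder, then the list, then the initial state,
// and the folder receives the item first and the accumulator second.
let mapBack (mapping: 'a -> 'b) (list: 'a list) =
    List.foldBack (fun i l -> mapping i :: l) list []

let filterBack (pred: 'a -> bool) (list: 'a list) =
    List.foldBack (fun i l -> if pred i then i :: l else l) list []

let chooseBack (chooser: 'a -> 'b option) (list: 'a list) =
    List.foldBack
        (fun i l ->
            match chooser i with
            | Some v -> v :: l
            | None -> l)
        list
        []

let doubled = mapBack (fun x -> x * 2) [1; 2; 3]
let odds = filterBack (fun x -> x % 2 = 1) [1; 2; 3; 4; 5]
let bigOnes = chooseBack (fun x -> if x > 1 then Some (string x) else None) [1; 2; 3]
```

Whether you prefer fold plus a final reverse or foldBack is mostly a matter of taste; the point is simply that the direction of the fold decides the order of the result.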

Ok, so now how about toMap?

let toMap (list:('a*'b) list) = List.fold (fun m (k,v) -> Map.add k v m) Map.empty list

And collect (here just collapsing a list of lists into a single list, which is what the library calls List.concat)?

let collect (list:'a list list) = List.fold (fun l li -> List.fold (fun l' i' -> i'::l') l li) [] list |> List.rev

In this case we’re nesting a fold inside a fold, but it still works. And now, just for fun, here are exists, tryFind, findIndex, etc.

let exists (pred:'a -> bool) (list:'a list) = List.fold (fun b i -> pred i || b) false list
let tryFind (pred:'a -> bool) (list:'a list) = List.fold (fun o i -> if Option.isNone o && pred i then Some i else o) None list
let findIndex (pred:'a -> bool) (list:'a list) = List.fold (fun (idx,o) i -> if Option.isNone o && pred i then (idx + 1,Some idx) else (idx + 1,o)) (0,None) list |> snd |> Option.get
let forall (pred:'a -> bool) (list:'a list) = List.fold (fun b i -> pred i && b) true list
let iter (f:'a -> unit) (list:'a list) = List.fold (fun _ i -> f i) () list
let length (list:'a list) = List.fold (fun c _ -> c + 1) 0 list
let partition (pred:'a -> bool) (list:'a list) = List.fold (fun (t,f) i -> if pred i then (i::t,f) else (t,i::f)) ([],[]) list |> fun (t,f) -> List.rev t,List.rev f

It’s worth pointing out that some of these aren’t the most efficient implementations. For example, exists, tryFind and findIndex ideally would have some short-circuit behavior so that when the item is found the list isn’t traversed any more. And then there are things like rev, sort, etc. which could be built in terms of fold, I guess, but the simpler and more efficient implementations would be done using simpler recursive processing. I can’t help but find the simplicity of the fold abstraction very appealing; it makes me ever so slightly giddy (strange, I know).
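For comparison, here is roughly what the short-circuiting versions might look like with plain recursion, along with rev written as a fold. This is only a sketch; existsEarly, tryFindEarly, revFold and countingPred are names I made up for the example.

```fsharp
// Stops as soon as the predicate succeeds; "||" only evaluates its
// right-hand side when needed, and the recursive call is in tail position.
let rec existsEarly pred list =
    match list with
    | [] -> false
    | item :: rest -> pred item || existsEarly pred rest

// Returns the first match and never looks at the rest of the list.
let rec tryFindEarly pred list =
    match list with
    | [] -> None
    | item :: rest -> if pred item then Some item else tryFindEarly pred rest

// rev really is just a fold: cons each item onto the accumulator.
let revFold list = List.fold (fun acc i -> i :: acc) [] list

// Count how many times the predicate runs to show the early exit.
let calls = ref 0
let countingPred x =
    calls.Value <- calls.Value + 1
    x = 2

let found = existsEarly countingPred [1; 2; 3; 4]
```

With the fold-based exists the predicate would run once per element no matter what; here it stops being called once a match has been found.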