Kiran
Kiran Author of uvCharts, Marvelous and Claft

Demystifying Tail Call Optimization

Tail call optimization is becoming increasingly popular, driven by the interest in functional programming and by its inclusion in newer versions of popular languages. So what exactly is a tail call, and how does it help optimize the flow of a program?

In computer science, a tail call is a subroutine call performed as the final action of a procedure. If a tail call might lead to the same subroutine being called again later in the call chain, the subroutine is said to be tail-recursive, which is a special case of recursion. Tail recursion is particularly useful, and often easy to handle in implementations.

  • Wikipedia

Let's set aside the standard definition of a tail call for a while and look at how traditional recursive functions work, with an example.

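#include <stdio.h>

int factorial(int n) {
  if (n < 0) return -1;  /* factorial is undefined for negative numbers */
  if (n <= 1) return 1;
  return n * factorial(n - 1);  /* the multiplication runs after the recursive call returns */
}

int main(void) {
  int n = 5;
  printf("%d\n", factorial(n));
  return 0;
}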

That’s a pretty simple example of recursion that calculates the factorial of 5 in C. Let us see what the call stack looks like when the last factorial call is about to return and start the series of returns.

printf("%d\n", factorial(5))
  -> return 5 * factorial(4) :: factorial(5) returns to main function
    -> return 4 * factorial(3) :: factorial(4) returns to factorial(5) instance in the call stack
      -> return 3 * factorial(2) :: factorial(3) returns to factorial(4) instance in the call stack
        -> return 2 * factorial(1)
          -> return 1

The function call stack at this point looks like this:

- main :: n = 5
- factorial(5) :: [n = 5, return_address: main]
- factorial(4) :: [n = 4, return_address: factorial(5)]
- factorial(3) :: [n = 3, return_address: factorial(4)]
- factorial(2) :: [n = 2, return_address: factorial(3)]
- factorial(1) :: [n = 1, return_address: factorial(2)]

With every call during the recursion the call stack keeps growing, and the local variables of a function live on until the function itself returns a value to its caller (the function that invoked it).

The call stack has to be maintained because each factorial call for n > 1 returns n multiplied by the return value of the next factorial call, so the state of each call has to be remembered until that value arrives.
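One way to see this growth is to count how many factorial frames are alive at the same time. The sketch below is an instrumented variant written only for illustration: factorial_counted, live_frames and max_live_frames are hypothetical names that do not appear in the original example, and the counters assume a single-threaded program.

```c
/* Hypothetical counters, added only for illustration. */
static int live_frames = 0;      /* calls that have started but not yet returned */
static int max_live_frames = 0;  /* highest value live_frames has reached */

int factorial_counted(int n) {
  if (n < 0) return -1;

  live_frames++;
  if (live_frames > max_live_frames) max_live_frames = live_frames;

  /* The multiplication happens after the recursive call returns,
     so this frame has to stay alive until then. */
  int result = (n <= 1) ? 1 : n * factorial_counted(n - 1);

  live_frames--;
  return result;
}
```

Calling factorial_counted(5) returns 120 and leaves max_live_frames at 5: one frame for each value of n from 5 down to 1, matching the stack listing above.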

Now think: what if the return value did not depend on the state of the current function call? Then there would be no need to keep that state on the call stack, and the next call could return its result directly to the caller of the current call. This optimization is called Tail Call Optimization and is pretty much the de facto standard in functional languages. Re-read the Wikipedia definition above now, and it should start to make sense.

Let us rewrite the factorial example above in tail call form. Remember that whether the optimization actually happens depends on the compiler, runtime or interpreter, and varies between programming languages.

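#include <stdio.h>

int _factorial(int n, int acc) {
  if (n < 0) return -1;
  if (n <= 1) return acc;  /* acc already holds the finished product */
  return _factorial(n - 1, n * acc);  /* tail call: nothing is left to do afterwards */
}

int factorial(int n) {
  return _factorial(n, 1);
}

int main(void) {
  int n = 5;
  printf("%d\n", factorial(n));
  return 0;
}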

Here we delegate the factorial call to another function, _factorial, which takes two arguments instead of one. The second argument carries the product accumulated so far, so each call can hand its result straight back to its caller with no further work.

Let us see what the factorial call looks like in this case, in a runtime that performs tail call optimization:

printf("%d\n", factorial(5))
  -> return _factorial(5, 1) :: return to main
    -> return _factorial(4, 5) :: return to factorial
    -> return _factorial(3, 20) :: return to factorial
    -> return _factorial(2, 60) :: return to factorial
    -> return _factorial(1, 120) :: return to factorial

The call stack at this point is similar to:

- main :: n = 5
- factorial(5) :: [n = 5, return_address: main]
- _factorial(1, 120) :: [n = 1, return_address: factorial(5)]

If you look at it now, the _factorial calls keep replacing each other, and the last one returns its value to the original caller, factorial, which in turn returns it directly to main. (The call from factorial to _factorial is itself a tail call, so a runtime may reuse that frame as well.)
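In effect, a runtime that performs this optimization runs the tail recursive function as a loop: the arguments are overwritten and control jumps back to the top instead of pushing a new frame. Here is a hand-written sketch of that transformation, assuming the same behaviour as _factorial above. The name factorial_loop is hypothetical, and real compiler output will look different.

```c
/* Hand-written equivalent of _factorial(n, acc) with the tail call
   replaced by a jump back to the top of the function. */
int factorial_loop(int n, int acc) {
  for (;;) {
    if (n < 0) return -1;
    if (n <= 1) return acc;
    /* Same effect as calling _factorial(n - 1, n * acc),
       but the current frame is reused. */
    acc = n * acc;
    n = n - 1;
  }
}
```

Only one frame exists for the whole computation, which is why the stack listing above stays the same size no matter how large n is.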

Before I end this blog post, let us look at a few examples of tail recursive functions and a few non-tail recursive ones. The examples are in Elixir, but you’ll be able to follow:

defmodule TailRecursive do
  def gcd(a, 0), do: a
  def gcd(0, a), do: a
  def gcd(a, b), do: gcd(b, rem(a, b))

  def len([], length), do: length
  def len([_h | t], length), do: len(t, 1 + length)
end

defmodule NonTailRecursive do
  def fib(n) when n < 0, do: raise "InvalidFibonacciInvocation"
  def fib(n) when n <= 1, do: n
  def fib(n), do: fib(n-1) + fib(n-2)

  def sum([]), do: 0
  def sum([h | t]), do: h + sum(t)
end

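Notice how the part after do: differs: in the tail recursive examples the recursive call is the entire expression, while in the non-tail recursive ones its result is still combined with something else (an addition, in both cases here).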
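The same contrast can be written in C, the language of the factorial example. The sketch below shows the list sum in both forms. The names sum_plain and sum_acc are hypothetical, and an array with a length stands in for the Elixir list.

```c
#include <stddef.h>

/* Not a tail call: the addition runs after the recursive call returns. */
int sum_plain(const int *items, size_t count) {
  if (count == 0) return 0;
  return items[0] + sum_plain(items + 1, count - 1);
}

/* Tail call: the recursive call is the whole return expression,
   and the running total travels in the acc argument. */
int sum_acc(const int *items, size_t count, int acc) {
  if (count == 0) return acc;
  return sum_acc(items + 1, count - 1, acc + items[0]);
}
```

As with factorial, moving the pending work into an accumulator argument is what turns the recursive call into a tail call. Whether a given C compiler then eliminates the call is still up to the compiler and its optimization settings.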

Many functional languages such as Haskell, Erlang and Elixir already perform tail call optimization, and Scala does so for self-recursive calls. Some mainstream languages are moving in the same direction: the ES6 specification for JavaScript includes tail call optimization, though whether you get it in practice depends on the engine.

Tail Call Optimization is merely a complex term for a rather simple concept: how the function call stack is managed while a program executes.
