Tail Recursion in Scala

Ever since I started programming in Scala, I've tried hard to adopt the paradigms of functional programming as much as possible. Recently I had to work on a recursive piece of code, and my architect suggested that I use something known as tail recursion instead of traditional recursion. This post is all about it, so let's get started!

To understand (Tail) recursion you first have to understand recursion

In an ordinary recursive program, the recursive method will usually:

  • Take the input and partially process it to form a [partial result + some unprocessed data].
  • Keep the partial result [on its stack frame], then call itself, passing only the unprocessed data. Until that second call completes, it cannot return the final result.
  • The second call processes some more data [keeps its result] and makes a third call with the remainder of the unprocessed data. This process keeps repeating [and our call stack keeps growing].
  • With each invocation, the unprocessed data keeps shrinking, and we eventually reach a terminating condition when there is no unprocessed data left.
  • This last call completes and returns its result to the second-to-last call, which in turn does its final processing and returns its output to the third-to-last call. This goes on until we reach the first call, which can now compute the final result.
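
To make those steps concrete, here is a minimal sketch of a traditional recursive sum. The object and method names are mine, picked just for illustration.

```scala
object TraditionalSum {
  // Not tail recursive: the addition happens AFTER the recursive call returns,
  // so every call has to keep its `head` around while it waits.
  def sum(numbers: List[Int]): Int = numbers match {
    case Nil          => 0                // terminating condition: no unprocessed data left
    case head :: tail => head + sum(tail) // partial result (head) waits for sum(tail)
  }
}
```

Calling TraditionalSum.sum(List(1, 2, 3, 4)) unfolds as 1 + (2 + (3 + (4 + 0))), and the additions only happen on the way back up, which gives 10.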

Nobody likes StackOverFlow

The major problem with this approach is that each method call pushes a new stack frame onto the call stack, and that frame has to be kept in memory until all the succeeding recursive calls finish. This happens because each recursive call other than the final one is holding on to its partial result while waiting for the method it called to return. Recurse deeply enough and the JVM runs out of stack space and throws a StackOverflowError.

Tail Recursion – There is no going back

Tail recursion is a type of recursion where all the computation of the current call is done before the next call is made. The recursive call is the last action the method performs [a tail call], so nothing is left to do when it returns. Contrasting it with traditional recursion:

  • The recursive method computes the partial result and finishes its processing.
  • It then sends its result to the next call alongside the unprocessed data.
  • The next call does its bit of processing using the partial result and then sends its result to the next call.
  • This goes on until we reach the last call, which now has all the data it needs to compute the final result.
  • Since the calls don't need their stack frames anymore [because all the data is sent over], the Scala compiler is smart enough to reuse a single frame, effectively turning the recursion into a loop, so it runs in constant stack space.

This single-handedly eliminates StackOverflowErrors from the recursion, as long as our code is tail recursive.
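
As a rough sketch of the same idea in code, here is a sum written in tail recursive style. The names TailSum and acc are my own, and the Long accumulator is just a choice to avoid integer overflow on big inputs.

```scala
import scala.annotation.tailrec

object TailSum {
  // The running total travels with the call as `acc`, so nothing is left
  // to do in the current call once the recursive call is made.
  @tailrec
  def sum(numbers: List[Int], acc: Long = 0L): Long = numbers match {
    case Nil          => acc                   // last call already holds the final result
    case head :: tail => sum(tail, acc + head) // partial result is passed forward
  }
}
```

TailSum.sum(List.range(1, 1000001)) walks a million elements and returns 500000500000. A traditional recursive sum over a list that long would typically overflow a default-sized JVM stack.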

Tail Recursion == Iteration ?!!!

This epiphany hit me once I understood the advantages of tail recursion: from a computational resources standpoint, tail recursion is in many ways similar to iteration. Both use constant stack space and build the result in multiple steps. The major advantage of any recursive approach is expressiveness. A lot of algorithms have a Divide & Conquer approach at their core, and implementing them is often way easier with recursion than with iteration.
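
To see how close the two really are, here is a small sketch of factorial written both ways. The names are made up for this post. Notice that the parameters of the tail recursive version play exactly the role of the mutable variables in the loop.

```scala
import scala.annotation.tailrec

object FactorialTwoWays {
  // Tail recursive version: the state lives in the parameters.
  @tailrec
  def factorialRec(n: Int, acc: BigInt = BigInt(1)): BigInt =
    if (n <= 1) acc
    else factorialRec(n - 1, acc * n)

  // Iterative version: the same state lives in mutable variables.
  def factorialLoop(n: Int): BigInt = {
    var remaining = n
    var acc = BigInt(1)
    while (remaining > 1) {
      acc *= remaining
      remaining -= 1
    }
    acc
  }
}
```

Both FactorialTwoWays.factorialRec(5) and FactorialTwoWays.factorialLoop(5) give 120, and both run in constant stack space.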

I used tail recursion in a real-world problem which involved traversing a JSON document, which is akin to navigating a tree. My tail recursive version [which keeps the pending nodes in a queue] visits the tree breadth first, whereas the traditional recursive version resembles a depth first approach.

Sounds great ! How do I implement it ??

Glad you asked !!

  • Use @tailrec [import scala.annotation.tailrec]
    This makes Scala check the code at compile time: if the method cannot be optimized as a tail call, compilation fails instead of silently falling back to ordinary recursion
  • The recursive call should be the last action of your method [with @tailrec the compiler will enforce this, because there is no coming back to this stack frame]
  • Usually a single value in the form of an aggregation [Sum, Min, Max etc.] of results is sent over as an argument to the next call, but if you want to send more than one, you can pass an accumulator, i.e. put the values in a collection and pass that object as an argument
  • Lastly, use a modern IDE like IntelliJ, which shows a pictorial representation of traditional recursion vs tail recursion
[Images: Traditional Recursion and Tail Recursion]
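
Here is a small sketch of the accumulator idea from the list above, where more than one aggregate has to travel to the next call. Stats, Summary and summarize are hypothetical names I made up for this example, and I bundled the values in a case class rather than a collection, which is just one way to do it.

```scala
import scala.annotation.tailrec

object Stats {
  // Hypothetical container for several running aggregates.
  final case class Summary(sum: Long, min: Int, max: Int)

  def summarize(numbers: List[Int]): Option[Summary] = {
    // Nested helper: the accumulator carries all three aggregates to the next call.
    @tailrec
    def loop(remaining: List[Int], acc: Summary): Summary = remaining match {
      case Nil => acc
      case head :: tail =>
        loop(tail, Summary(acc.sum + head, math.min(acc.min, head), math.max(acc.max, head)))
    }

    numbers match {
      case Nil          => None // nothing to aggregate
      case head :: tail => Some(loop(tail, Summary(head, head, head)))
    }
  }
}
```

Stats.summarize(List(3, 1, 4, 1, 5)) returns Some(Summary(14, 1, 5)). If you move the recursive call out of tail position [say, by adding something to its result], the @tailrec annotation makes the compiler reject the method.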

Converting Traditional to Tail Recursion

Let's look at this with an example. The code snippet below traverses a JSON document with the intent of transforming its schema into another form. So we visit each node and add its data to a collection. If you imagine the JSON document as a tree, then the values are all leaf nodes which need to be visited.

Traditional Recursive Code

//
// JsonNode, getChildNodes, value and process are assumed to come from your JSON library / your own code
def traverse(node: JsonNode): Unit = {
	// For the input JSON node we got as parameter
	// if leaf, fetch and process it - end of recursion
	if (node.getChildNodes == null) process(node.value)
	else
		// if not leaf, call traverse again for every child
		node
			.getChildNodes
			.foreach(child => traverse(child))
}
//

The above traversal follows a depth first approach. Even after the traverse(child) call inside foreach is made, the stack frame cannot be removed, because the execution is not entirely complete and will eventually return to this line to process the next child.

Tail Recursive Code

//
import scala.annotation.tailrec
import scala.collection.mutable

// JsonNode, getChildNodes, value and process are assumed to come from your JSON library / your own code
@tailrec
def traverse(accumulator: mutable.Queue[JsonNode]): Unit = {
	if (accumulator.isEmpty) {
		// End of recursion as we have processed all elements
	}
	else {
		val currentNode = accumulator.dequeue()
		// if leaf, fetch and process it
		if (currentNode.getChildNodes == null) process(currentNode.value)
		else
			// if not leaf, add every child to the accumulator
			currentNode
				.getChildNodes
				.foreach(child => accumulator.enqueue(child))

		// we call traverse once we are done processing the current node
		traverse(accumulator)
	}
}
//

The above traversal follows a breadth first approach [because the pending nodes wait in a queue]. All the processing for the currentNode is done [get & add all children] by the time the next call is made, so the Scala compiler can reuse the stack frame instead of adding a new one.
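
If you want to play with both traversals without a JSON library, here is a self-contained sketch. Node, Leaf and Branch are a hypothetical stand-in for a JSON tree, not the types from my actual code, and I used an immutable Queue plus a second accumulator for the visited values so the result is easy to inspect.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

object TreeTraversal {
  // Hypothetical stand-in for a JSON tree: a leaf holds a value, a branch holds children.
  sealed trait Node
  final case class Leaf(value: String) extends Node
  final case class Branch(children: List[Node]) extends Node

  // Traditional recursion: depth first, each call waits for its children to finish.
  def depthFirst(node: Node): List[String] = node match {
    case Leaf(value)      => List(value)
    case Branch(children) => children.flatMap(depthFirst)
  }

  // Tail recursion: pending nodes wait in `pending`, collected values in `visited`.
  @tailrec
  def breadthFirst(pending: Queue[Node], visited: Vector[String] = Vector.empty): Vector[String] =
    pending.dequeueOption match {
      case None                           => visited
      case Some((Leaf(value), rest))      => breadthFirst(rest, visited :+ value)
      case Some((Branch(children), rest)) => breadthFirst(rest ++ children, visited)
    }
}
```

For the tree Branch(List(Branch(List(Leaf("a"), Leaf("b"))), Leaf("c"))), depthFirst visits a, b, c while breadthFirst(Queue(tree)) visits c, a, b, because the shallower leaf c is reached before the deeper ones.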

tl;dr

By changing the way we think about the recursive calls, we can write code in tail recursive format which gives us the expressiveness of recursion alongside the memory performance of iteration!