1. Overview

In this tutorial, we’ll look at loops and recursion in Scala and how they interact with the JVM stack. We’ll also cover common pitfalls to avoid, as well as optimization techniques that help keep our code concise and free of stack errors.

2. Stacks and Stack Frames

To understand functional loops in Scala, it’s important to understand what stacks are.

2.1. Stacks

A stack, also known as a call stack, is described by the official docs as a data structure that holds local variables and partial results and plays a part in method invocation and return.

Each JVM thread has a private JVM stack, created at the same time as the thread, which stores stack frames. The JVM performs only two operations directly on a Java stack: it pushes and pops stack frames.

The default stack size varies between computer architectures, as seen in this documentation. Although the stack size can be increased, it’s better to design code in a stack-safe and efficient manner than to increase the stack size unnecessarily.

The stack contains information about what methods the thread has called to reach the current point of execution. The call stack changes as the thread executes its code.

The stack also contains all local variables for each method being executed (all methods on the call stack). A thread can only access its own thread stack.

Local variables created by a thread are not visible to any other thread. Even if two threads are executing the exact same code, each thread still creates the local variables of that code in its own thread stack. Thus, each thread has its own version of each local variable.
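As a minimal sketch of this isolation, the snippet below runs the same method on two threads. The object name ThreadLocalsDemo and the methods countTo and run are hypothetical names chosen for illustration. Each thread gets its own counter variable in its own stack, so neither thread disturbs the other's count:

```scala
// Hypothetical demo object; none of these names come from the article.
object ThreadLocalsDemo {

  // `counter` is a local variable, so every invocation keeps its own copy
  // in a frame on the stack of whichever thread made the call.
  def countTo(limit: Int): Int = {
    var counter = 0
    while (counter < limit) counter += 1
    counter
  }

  // Runs countTo on two threads at once and returns both results.
  def run(): (Int, Int) = {
    val results = new Array[Int](2) // shared only to collect the outcomes
    val first = new Thread(() => { results(0) = countTo(1000) })
    val second = new Thread(() => { results(1) = countTo(2000) })
    first.start()
    second.start()
    first.join() // join() makes each thread's write visible here
    second.join()
    (results(0), results(1))
  }
}
```

Calling ThreadLocalsDemo.run() should return (1000, 2000) regardless of how the two threads interleave, because the only state they share is the results array, and each thread writes to a different slot.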

2.2. Stack Frames

Stack frames (or frames) make up the Java Stack. A stack frame contains the state of a single Java method invocation. When a thread invokes a method or a function, the Java virtual machine pushes a new frame onto that thread’s Java stack. When the method completes, the JVM pops and discards the frame for that method.

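This stack frame has three parts: local variables, an operand stack, and frame data (such as the reference to the run-time constant pool).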

The size of a stack frame varies depending on the local variables and operand stack of that method, and is measured in words.

When the JVM invokes a function or a method, it checks the class data to determine the number of words required by the method in the local variables and operand stack and then creates an appropriately sized stack frame for that method before pushing it onto the stack.

Each invocation of a method or function leads to the creation of a stack frame. The stack frame created contains all information about that method, and during execution, the code can only access the values in the current stack frame.
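We can observe this one-frame-per-invocation behavior with a small sketch. The helper names below (currentDepth, oneLevel, twoLevels, extraFrames) are hypothetical. The sketch assumes that the length of the current thread's stack trace is a reasonable proxy for the number of frames on the stack:

```scala
// Hypothetical helpers for illustration.

// Number of frames currently on the calling thread's stack.
def currentDepth(): Int = Thread.currentThread().getStackTrace.length

def oneLevel(): Int = currentDepth() // adds one frame of its own
def twoLevels(): Int = oneLevel()    // adds a second frame on top of that

// Both measurements start from the same caller, so they share a baseline.
def extraFrames(): Int = twoLevels() - oneLevel()

assert(extraFrames() == 1)
```

The extra call through twoLevels accounts for exactly one additional frame, which is the frame the JVM pushed for that invocation.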

3. Loops

Scala, like most programming languages, provides us with the ability to write loops. As an example, we’ll attempt to get the sum of a List of numbers both imperatively and functionally.

Here is an example of imperatively calculating the sum of a List of numbers:

def getSumOfList(list: List[Int]): Int = {
  var sum = 0
  for (i <- 0 until list.length) {
    sum += list(i)
  }
  sum
}
assert(getSumOfList((1 to 10).toList) == 55)

Although this produces the correct result, it’s not functional, and the mutable sum could easily lead to problems such as race conditions if we decide to split the List into smaller Lists to be summed by different threads and then accumulated. Scala discourages the use of vars except where they’re essential.

Another way we can implement a loop is via recursion.

3.1. Recursion

Recursion is the process in which a function calls itself directly or indirectly. Many problems can be solved using recursion because it breaks a problem into smaller tasks until it reaches a base condition (or end condition), at which point the recursion stops and the partial results are combined.
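To illustrate the "indirectly" part, here is a small sketch of two functions that call each other rather than themselves. The names isEven and isOdd are hypothetical, and the sketch assumes a non-negative argument:

```scala
// Hypothetical example of indirect (mutual) recursion.
// Assumes n >= 0; a negative n would never reach the base condition.
def isEven(n: Int): Boolean =
  if (n == 0) true      // base condition
  else isOdd(n - 1)     // isEven calls isOdd ...

def isOdd(n: Int): Boolean =
  if (n == 0) false     // base condition
  else isEven(n - 1)    // ... which calls isEven again

assert(isEven(10))
assert(isOdd(7))
```

Each call reduces n by one, so the problem shrinks until it reaches the base condition n == 0.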

The use of recursion is a more functional approach to writing loops than using a for loop. Scala highly encourages the use of functional loops.

Recursive algorithms can sometimes create extremely deep call stacks and exhaust the stack space.

Here’s an example of calculating the same sum of a List using recursion:

def getSumOfListRecursive(list: List[Int]): Int = {
  list match {
    case Nil => 0
    case head :: tail =>
      head  +  getSumOfListRecursive(tail)
  }
}
assert(getSumOfListRecursive((1 to 10).toList) == 55)

We can see that there are no vars in sight. Our code is also more efficient, as we’re no longer doing an indexed lookup on a List, which is a linear-time operation.

Instead of using vars, every time getSumOfListRecursive calls itself, it pushes a new stack frame onto the call stack with its own set of variables and then pops it when that particular invocation completes.

This solution seems concise enough, but let’s try to sum up a very large List as shown in this example:

getSumOfListRecursive((1 to 10000).toList)

We run into this error:

Exception in thread "main" java.lang.StackOverflowError

This StackOverflowError means that the recursion needed more stack frames than the thread’s stack could hold.

Whenever a recursive function calls itself, the JVM pushes a new stack frame holding the state of that invocation onto the stack. Because of this, each level of recursion requires a new stack frame.

As a result, the recursive function consumes more and more memory allocated to the stack. If the sum function calls itself a thousand times, a thousand stack frames are created.
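To get a feel for how finite the stack is, we could count how many nested calls fit before it overflows. The sketch below uses a hypothetical helper, depthBeforeOverflow. The number it returns is not fixed: it depends on the JVM, the configured stack size, and the size of each frame, so it should be treated only as a rough indication:

```scala
// Hypothetical helper: counts how many nested calls fit before the stack overflows.
def depthBeforeOverflow(): Int = {
  var depth = 0 // a var is used only so the count survives the stack unwinding

  // Not tail-recursive: the addition runs after the recursive call returns,
  // so every call needs its own stack frame.
  def descend(): Int = {
    depth += 1
    1 + descend()
  }

  // By the time the handler runs, the deep frames have already been popped.
  try { descend(); () } catch { case _: StackOverflowError => () }
  depth
}

println(s"Nested calls before overflow: ${depthBeforeOverflow()}")
```

Catching StackOverflowError is done here purely for demonstration; in real code, it’s better to avoid the overflow than to recover from it.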

Does this mean that we can’t use recursion on large data structures? Luckily, Scala provides an efficient mechanism, called tail recursion, that enables us to use recursion on large data structures.

3.2. Tail Recursive Functions

A tail-recursive function is a function whose last action is a direct call to itself. When recursive functions are written this way, the Scala compiler can optimize the resulting JVM bytecode such that the function requires only one stack frame, as opposed to one stack frame for each level of recursion.

Regular recursion creates lots of stack frames, and for algorithms that require deep levels of recursion, this causes a StackOverflowError (and crashes the program).

When the Scala compiler spots a tail-recursive function, it knows to optimize it by essentially turning it into a while loop. This means there are no more recursive calls and no more frames pushed onto the stack.
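As a rough picture of what that rewrite amounts to, a hand-written loop that walks a List while carrying an accumulator might look like the sketch below. The name getSumOfListWithWhile is hypothetical, and this is only an approximation in source form; the compiler performs the transformation on the compiled code rather than producing Scala like this:

```scala
// Hypothetical hand-written equivalent of a tail-recursive sum.
def getSumOfListWithWhile(numbers: List[Int]): Int = {
  var list = numbers   // plays the role of the recursive call's list argument
  var accumulator = 0  // plays the role of the accumulator argument
  while (list.nonEmpty) {
    accumulator += list.head // same work as one level of recursion
    list = list.tail         // "recurse" on the rest without a new frame
  }
  accumulator
}

assert(getSumOfListWithWhile((1 to 10).toList) == 55)
```

The loop reuses the same two variables on every iteration, which is why no additional stack frames are needed.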

Here is our same example of calculating the sum of a List using tail recursion:

def getSumOfListTailRecursive(numbers: List[Int]): Int = {
  def innerFunction(list: List[Int], accumulator: Int): Int = {
    list match {
      case Nil => accumulator
      case head :: tail => innerFunction(tail, head + accumulator)
    }
  }
  innerFunction(numbers, 0) // give an initial accumulator
}
// The true sum, 500000500000, doesn't fit in an Int, so the result wraps around to 1784293664
assert(getSumOfListTailRecursive((1 to 1000000).toList) == 1784293664)

With our tail-recursive function, we were able to perform recursion on a very large List. We could increase the size of the List and still be sure that we won’t run into a StackOverflowError.

If we compare this example with our previous non-tail-recursive function, we’ll see that in the tail-recursive function, the last action of the function was a call to itself, but in the former, it was:

head + getSumOfListRecursive(tail)

The last action there is the addition, which can only run after the recursive call returns. The Scala compiler couldn’t optimize that, as it wasn’t tail-recursive, and that led to a new stack frame for each level of recursion.

One way we can confirm that our function is tail-recursive is by adding this annotation directly above the function definition:

@scala.annotation.tailrec

This way, our code won’t compile if the function isn’t tail-recursive.
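As a sketch of where the annotation goes, here it is applied to the inner accumulator function of a sum like ours. The outer name getSumOfListChecked is hypothetical:

```scala
import scala.annotation.tailrec

// Hypothetical variant of the tail-recursive sum with the annotation applied.
def getSumOfListChecked(numbers: List[Int]): Int = {
  @tailrec // compilation fails if innerFunction is not tail-recursive
  def innerFunction(list: List[Int], accumulator: Int): Int =
    list match {
      case Nil => accumulator
      case head :: tail => innerFunction(tail, head + accumulator)
    }
  innerFunction(numbers, 0) // give an initial accumulator
}

assert(getSumOfListChecked((1 to 10).toList) == 55)
```

If we changed the last case to head + innerFunction(tail, accumulator), the recursive call would no longer be the last action, and the compiler would reject the annotated function instead of silently producing stack-hungry code.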

3.3. Stack Frame Usage in Recursion

To demonstrate that a tail-recursive function uses a single stack frame while a non-tail-recursive function uses many, we can print the difference between the number of stack frames when the function begins and when it reaches the last level of recursion, which in our case is when the List is empty (Nil).

Here’s an example of printing the number of stack frames used in a non-tail-recursive function:

val length = Thread.currentThread().getStackTrace.length + 1
def getSumOfListRecursive(list: List[Int]): Int = {
  list match {
    case Nil =>
      println(Thread.currentThread().getStackTrace.length - length) // prints 100 for 100 items
      0
    case h :: t =>
      h + getSumOfListRecursive(t)
  }
}
assert(getSumOfListRecursive((1 to 100).toList) == 5050)

By checking the difference between the number of stack frames created before the function starts and at the last level of recursion, we see the direct proportionality between the number of stack frames created and the number of levels of recursion for non-tail-recursive functions.

Here’s an example of printing the number of stack frames used in a tail-recursive function:

val length = Thread.currentThread().getStackTrace.length + 1
def getSumOfListTailRecursive(numbers: List[Int]): Int = {
  def innerFunction(list: List[Int], accumulator: Int): Int = {
    list match {
      case Nil =>
        println(Thread.currentThread().getStackTrace.length - length) // prints 1 for 1000000 items
        accumulator
      case head :: tail =>
        innerFunction(tail, head + accumulator)
    }
  }
  innerFunction(numbers, 0) // give an initial accumulator
}
assert(getSumOfListTailRecursive((1 to 1000000).toList) == 1784293664)

In our tail-recursive function, we see the creation of only 1 stack frame despite the fact that we’re dealing with a million items.

4. Conclusion

In this article, we’ve seen how Scala enables us to write loops both imperatively and functionally. We also looked into recursion as a functional approach to writing loops, as well as its drawbacks and how tail-recursive functions solve them.

Code snippets and examples can be found over on GitHub.
