# Trampolines
Originally by Sriram Sankaranarayanan <srirams@colorado>

Modified by Ravi Mangal <ravi.mangal@colostate>

Last Modified: Apr 9, 2025.

---


Trampolines go hand in hand with the Continuation Passing Style (CPS) of writing programs. Let us quickly review the basic facts about CPS.
- We add an extra continuation argument to every function call in the program.
- We transform the program so that all function calls happen at the tail position.
- Finally, we __hope__ that the compiler/interpreter in all its goodness will optimize the tail call away.

## Fibonacci

We already saw the CPS version of Fibonacci:

In [2]:
def fibonacci_k[T] (n: Int, k: Int => T): T = {
    if (n < 2)
        k(1)
    else 
        fibonacci_k(n-1, v1 => fibonacci_k(n-2, v2 => {k(v1 + v2)} ))
}

defined function fibonacci_k

In [3]:
fibonacci_k(10, print)

89

In [4]:
fibonacci_k(15, print)

987

In [5]:
fibonacci_k(20, print)

java.lang.StackOverflowError: null

What just happened?
- Well, even though every call in the function is in tail position, Scala did not eliminate the tail calls, so each call still consumed a stack frame.

Scala runs on the Java Virtual Machine (JVM), which implements many powerful optimizations to make our code run faster. However, tail call optimization (TCO), a very basic optimization implemented in numerous languages, is not provided by the JVM. The Scala compiler implements a weak version of this optimization: it only rewrites a function's direct tail calls to itself into a loop, and it can fail to recognize or optimize other tail calls, such as calls to a continuation.
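
To make the limitation concrete, here is a minimal sketch of what the Scala compiler does and does not optimize. The functions `sumTo`, `isEven`, and `isOdd` are hypothetical examples introduced only for illustration.

```scala
import scala.annotation.tailrec

// Direct self-recursion in tail position: the compiler rewrites this into a loop,
// and @tailrec makes compilation fail if it cannot.
@tailrec
def sumTo(n: Int, acc: Long): Long =
  if (n <= 0) acc else sumTo(n - 1, acc + n)

// Mutual recursion: both calls are tail calls, but neither is a self-call,
// so the compiler does not eliminate them and each call uses a stack frame.
def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)

val big = sumTo(1000000, 0L) // runs in constant stack space
val small = isEven(10)       // fine for small inputs; large inputs may overflow the stack
```

The continuation calls in a CPS program resemble the second case: they are tail calls to a different function, so they are not optimized.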

In [6]:
import scala.annotation.tailrec

@tailrec
def factorialTCO_k[T] (n: Int, k: Int => T): T = {
    if (n < 1)
        k(1)
    else 
        factorialTCO_k(n-1, v1 => {k(n*v1)})
}

import scala.annotation.tailrec
defined function factorialTCO_k

In [6]:
import scala.annotation.tailrec

@tailrec
def fibonacciTCO_k[T] (n: Int, k: Int => T): T = {
    if (n < 2)
        k(1)
    else 
        fibonacciTCO_k(n-1, v1 => fibonacciTCO_k(n-2, v2 => {k(v1 + v2)} ))
}

cmd7.sc:8: could not optimize @tailrec annotated method fibonacciTCO_k: it contains a recursive call not in tail position
        fibonacciTCO_k(n-1, v1 => fibonacciTCO_k(n-2, v2 => {k(v1 + v2)} ))
                                                ^
Compilation Failed

As you can see, Scala was clever enough to optimize the tail call in factorial but not clever enough to optimize the one in fibonacci. This is because Scala is limited in its ability to recognize tail calls: it rejects the recursive call that appears inside the continuation. A more sophisticated analysis could understand that `v1 => fibonacciTCO_k(n-2, v2 => {k(v1 + v2)} )` is not a call made by the function body; evaluating it merely builds a closure.
But we have to work with the language we have and not the one we wish we had ;-)

## Trampolines

Trampolines are a trick that supports continuation passing style in languages that lack the ability to do TCO. You can view a trampoline as a manual approach to tail call optimization.

We make two simple changes: 
- Rather than making the tail call, the CPS program simply returns the tail call as a closure.
- The trampoline is simply a while loop that keeps calling the returned closures while making sure that the stack never grows.


Therefore, rather than making the tail call, the CPS transform returns an object of type `CPSResult[T]`.

A `CPSResult` can be one of two things:
- A _CALL_ object that encapsulates a function, written `Call(f: () => CPSResult[T])`. `Call` has a field `f` such that calling `f()` returns an object of type `CPSResult[T]`.
- A _DONE_ object that encapsulates a value of type `T` (the overall return value of our original function).


In [7]:
sealed trait CPSResult[T] // type of the message that comes back
case class Call[T](f: () => CPSResult[T]) extends CPSResult[T]
case class Done[T](v: T) extends CPSResult[T]

defined trait CPSResult
defined class Call
defined class Done

The main idea is as follows. Suppose the original CPS function is of the form 

~~~
def originalCPSFunction( x: ..., k: ... => T) : T = {
   if ( ... base case condition ... )
       return k ( .. base args ..)
   else 
       ...
       return tailCall( new_x, new_k )
}
~~~

the trampolined version is of the form

~~~
def trampolinedCPSFunction(x: .., k : ... => CPSResult[T]): CPSResult[T] = {
   if ( ... base case condition ... ) 
        return Call( () => k( .. base args .. ))
   else 
        ..... 
        return Call( () => tailCall(new_x, new_k_trampolined))
}
~~~

Notice that instead of calling `k` or calling `tailCall`, we return a `Call` object that
encapsulates a closure `() => <whatever we were calling originally>`. Wrapping the call in a
closure that takes no arguments delays the computation: Scala does not evaluate `k(..base args..)` or
`tailCall(...)` until the closure itself is called. Evaluating them right away would totally defeat the purpose.
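
The delaying effect of a zero-argument closure can be seen in a small sketch. The names `expensive`, `evaluated`, and `thunk` are hypothetical and used only for illustration.

```scala
var evaluated = false

// Stand-in for a call we want to postpone; it records that it has run.
def expensive(): Int = { evaluated = true; 42 }

// Wrapping the call in a zero-argument closure does not run it.
val thunk: () => Int = () => expensive()
val before = evaluated // still false: nothing has been called yet

// The computation happens only when someone applies the closure.
val result = thunk()
val after = evaluated  // now true
```

In the trampolined code, the "someone" who applies the closure is the trampoline loop.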



In [27]:
def factorial_k[T](n: Int, k: Int => T): T = {
    if (n <= 0)
        k(1) 
    else 
       factorial_k(n-1, v => {k(n * v)}) 
}

// The trampolined version with continuations
def tramp_factorial_k[T](n: Int, k: Int => CPSResult[T]): CPSResult[T] = {
    println("DEBUG: I am in tramp_factorial_k")
    if (n <= 0)
        Call( () => { k(1)} ) // Message "call k(1)"
    else 
       // Used to be factorial_k(n-1, v => {k(n *v)})
       Call( () => tramp_factorial_k(n-1, v => { // Used to be k(n*v) 
                                        Call( () =>  {  k(n * v) } ) 
                                        }) )
    // call "this complicated continuation"
}

defined function factorial_k
defined function tramp_factorial_k

Notice four main things:
- The trampolined function, when called, always returns a new function encapsulated inside a `Call` object.
- We replace every function call `f(.. args ..)` by `Call( () => f( .. args .. ) )`.
- A trampolined version should never call another function directly. It always returns an object of type `CPSResult[T]`.
- The type of the continuation, which used to be `k: Int => T`, is now `k: Int => CPSResult[T]`.

Now we write the trampoline for factorial.

In [28]:
def factorial(n: Int):Int = {
    // It is important that the  continuation passed to the very first call
    // return Done(value) to indicate that the computation is done when it is called.
    // Identity function that encapsulates its result in Done
    
    // x => Done(x) 
    def terminal_continuation (x: Int): CPSResult[Int] = {  Done(x) }
    
    // Now instead of recursion, we will use a while loop
    var call_res = tramp_factorial_k(n, terminal_continuation)
    var done = false
    while (!done ){
        println("DEBUG: I am in trampoline!")
        call_res match {
            // Here is where we will call f
            case Call(f) => { call_res = f() }
            case Done(v) => {done = true}
        }
    }
    print("DEBUG: Trampoline is done.")
    call_res match {
        case Call(f) => { throw new AssertionError("This should never happen, since the while loop must have kept iterating until we saw Done")
                        }
        case Done(v: Int) => {return v}
    }
    
}

defined function factorial

In [29]:
factorial(6)

DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in tramp_factorial_k
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: I am in trampoline!
DEBUG: Trampoline is done.

res29: Int = 720

Notice how the code jumps back and forth between the trampoline and the `tramp_factorial_k` function.

- Initial call to the trampoline.
- The trampoline calls the factorial function.
- The factorial function returns a new function to call.
- The trampoline calls the function that factorial returned.
- That call returns another function to the trampoline.
- The trampoline calls this function.
- ....
- Eventually, the returned function calls the terminal continuation.
- The trampoline calls it, receives `Done`, and stops.

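Note how the stack never holds more than a small, fixed number of calls:
- The call to the trampoline.
- The call to whatever the trampoline has called (a closure, which makes at most one further call before returning).

This is because whatever the trampoline calls just returns a closure, and the trampoline then calls that closure.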
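
The loop inside `factorial` does not depend on factorial itself, so it can be factored out into a reusable driver. The following is a minimal sketch under that observation, not part of the original notes: `runTrampoline` and `tramp_sum_k` are hypothetical names, and the `CPSResult` definitions are repeated only so that the block stands alone.

```scala
import scala.annotation.tailrec

// Same definitions as in the notes, repeated so that this cell is self-contained.
sealed trait CPSResult[T]
case class Call[T](f: () => CPSResult[T]) extends CPSResult[T]
case class Done[T](v: T) extends CPSResult[T]

// Hypothetical helper: a trampoline that works for any CPSResult[T].
// The recursive call is a direct self-call in tail position, so the
// compiler turns it into a loop (which @tailrec verifies).
@tailrec
def runTrampoline[T](res: CPSResult[T]): T = res match {
  case Call(f) => runTrampoline(f())
  case Done(v) => v
}

// Hypothetical example: trampolined CPS computation of 1 + 2 + ... + n.
def tramp_sum_k[T](n: Int, k: Long => CPSResult[T]): CPSResult[T] =
  if (n <= 0)
    Call(() => k(0L))
  else
    Call(() => tramp_sum_k[T](n - 1, (v: Long) => Call(() => k(n + v))))

// The terminal continuation wraps the final value in Done.
val total = runTrampoline(tramp_sum_k[Long](100000, x => Done(x)))
```

With this helper, a trampolined function only needs a terminal continuation of the form `x => Done(x)`; the recursion depth of 100000 here is handled by the loop rather than by the call stack.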

How does this work for Fibonacci?

In [11]:
//was: def fibonacci_k[T](n: Int, k: Int => T): T
def tramp_fibonacci_k[T](n: Int, k: Int => CPSResult[T]): CPSResult[T]  = {
    
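    // Note: the base case here is n <= 2, whereas fibonacci_k above used n < 2.
    // The results are therefore shifted by one index: fibonacci(11) below equals fibonacci_k(10).
    if (n <= 2)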
        {
            //was: return k(1)
            // since fibonacci should not call k, it returns a Call object to trampoline, which will call it.
            return Call( () => { k(1) } )
        }
    else 
        // was: fibonacci_k(n-1, v1 => fibonacci_k(n-2, v2 => k(v1+v2)))
        // make it into a call object
        // Do not forget to modify the continuation as well. 
        // Wherever you see a function being called, mechanically replace it by Call( () => fun-being-called)
        return Call( () => { tramp_fibonacci_k(n-1, 
                                         v1 => Call( () => { 
                                                        tramp_fibonacci_k(n-2, v2 => {
                                                               Call( () => { k(v1 + v2)} )
                                                        } ) } )
                                        )
                           }
                   )
}

defined function tramp_fibonacci_k

We build the trampoline in the usual way.

In [12]:
def fibonacci(n: Int): Int = {
    var done = false
    def k (t: Int) = Done(t)
    var res: CPSResult[Int] = tramp_fibonacci_k(n, k)
    while (!done){
        res match {
            case Call(f) => { res = f() } 
            case Done(v) => {done = true}
        }
    }
    res match {
            case Call(f) => throw new IllegalArgumentException("This should never happen -- since while loop above can only exit when done is true")
            case Done(v: Int) => {return v}
        }
}

defined function fibonacci

In [13]:
fibonacci(11)

res13: Int = 89

In [14]:
fibonacci(15)

res14: Int = 610

In [15]:
fibonacci(20)

res15: Int = 6765

In [16]:
fibonacci(25)

res16: Int = 75025

In [17]:
fibonacci(40)

res17: Int = 102334155
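
The hand-written trampoline has a counterpart in the Scala standard library, `scala.util.control.TailCalls`. As a rough correspondence, `tailcall` plays the role of `Call`, `done` plays the role of `Done`, and `result` runs the loop. The sketch below uses that library in monadic style rather than with explicit continuations; the name `fibTailCalls` is a hypothetical one chosen for this example.

```scala
import scala.util.control.TailCalls._

// Same base case as fibonacci_k at the top of these notes: n < 2 yields 1.
def fibTailCalls(n: Int): TailRec[Int] =
  if (n < 2) done(1)
  else
    for {
      v1 <- tailcall(fibTailCalls(n - 1)) // suspended, not called right away
      v2 <- tailcall(fibTailCalls(n - 2))
    } yield v1 + v2

// result drives the suspended computation to completion.
val f10 = fibTailCalls(10).result
val f20 = fibTailCalls(20).result
```

Here `f10` agrees with the value printed by `fibonacci_k(10, print)` earlier, and the call with `n = 20`, which overflowed the stack in plain CPS, completes.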

Let us try a simple example over trees.

In [18]:
sealed trait NumTree
case class Node(j: Int, child1: NumTree, child2: NumTree) extends NumTree
case object Leaf extends NumTree

defined trait NumTree
defined class Node
defined object Leaf

In [19]:
def heightOfTree(t: NumTree): Int = t match {
    case Leaf => 0
    case Node(j, c1, c2) => 1 + math.max(heightOfTree(c1), heightOfTree(c2))
}

def heightOfTree_k[T](t: NumTree, k: Int => T): T = t match {
    case Leaf => k(0)
    case Node(j, c1, c2) => heightOfTree_k(c1, v1 => { heightOfTree_k(c2, v2 => { k(1+ math.max(v1, v2)) } )})
}

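def tramp_heightOfTree_k[T](t: NumTree, k: Int => CPSResult[T]) : CPSResult[T] = t match {
    case Leaf => Call(() => k(0) )
    // The recursive calls must go to tramp_heightOfTree_k (not heightOfTree_k);
    // otherwise the recursion bypasses the trampoline and the stack grows again.
    case Node(j, c1, c2) => { Call(() => tramp_heightOfTree_k(c1, v1 => 
                                 { Call( () =>  tramp_heightOfTree_k(c2, v2 => { Call( () => k(1 + math.max(v1, v2)) ) }))})) }
}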

defined function heightOfTree
defined function heightOfTree_k
defined function tramp_heightOfTree_k

In [20]:
def heightOfTreeWithTrampoline(t: NumTree): Int = {
    def terminal_cont(j: Int) = Done(j)
    var res = tramp_heightOfTree_k(t, terminal_cont)
    var done = false
    while (!done){
        res match {
            case Done(j) => {done = true}
            case Call(f) => {res = f() }
        }
    }
    res match {
            case Call(f) => throw new IllegalArgumentException("This should never happen: the loop above exits only after seeing Done")
            case Done(v: Int) => {return v}
    }
}

defined function heightOfTreeWithTrampoline

In [21]:
val t1 = Node(18, Leaf, Leaf)
val t2 = Node(10,  Node(15, t1, Leaf), Node(28, Leaf, Leaf))
val t3 = Node(20, t1, t2)
val t4 = Node(19, t3, t3)

t1: Node = Node(j = 18, child1 = Leaf, child2 = Leaf)
t2: Node = Node(
  j = 10,
  child1 = Node(
    j = 15,
    child1 = Node(j = 18, child1 = Leaf, child2 = Leaf),
    child2 = Leaf
  ),
  child2 = Node(j = 28, child1 = Leaf, child2 = Leaf)
)
t3: Node = Node(
  j = 20,
  child1 = Node(j = 18, child1 = Leaf, child2 = Leaf),
  child2 = Node(
    j = 10,
    child1 = Node(
      j = 15,
      child1 = Node(j = 18, child1 = Leaf, child2 = Leaf),
      child2 = Leaf
    ),
    child2 = Node(j = 28, child1 = Leaf, child2 = Leaf)
  )
)
t4: Node = Node(
  j = 19,
  child1 = Node(
    j = 20,
    child1 = Node(j = 18, child1 = Leaf, child

In [22]:
heightOfTree(t4)

res22: Int = 5

In [23]:
heightOfTreeWithTrampoline(t4)

res23: Int = 5

In [24]:
heightOfTree_k(t4, x => x)

res24: Int = 5
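
The small trees above do not stress the stack, so all three versions agree. To see what the trampoline buys us, the following sketch measures the height of a very deep tree. It is a self-contained illustration under stated assumptions: the type definitions are repeated from the notes, the trampolined function recurses through itself, and `leftSpine` and `heightWithTrampoline` are hypothetical helper names.

```scala
sealed trait CPSResult[T]
case class Call[T](f: () => CPSResult[T]) extends CPSResult[T]
case class Done[T](v: T) extends CPSResult[T]

sealed trait NumTree
case class Node(j: Int, child1: NumTree, child2: NumTree) extends NumTree
case object Leaf extends NumTree

// Trampolined CPS height: every call, including each recursive one, is wrapped in Call.
def tramp_heightOfTree_k[T](t: NumTree, k: Int => CPSResult[T]): CPSResult[T] = t match {
  case Leaf => Call(() => k(0))
  case Node(_, c1, c2) =>
    Call(() => tramp_heightOfTree_k[T](c1, v1 =>
      Call(() => tramp_heightOfTree_k[T](c2, v2 =>
        Call(() => k(1 + math.max(v1, v2)))))))
}

// The usual trampoline loop, returning the value carried by Done.
def heightWithTrampoline(t: NumTree): Int = {
  var res: CPSResult[Int] = tramp_heightOfTree_k[Int](t, j => Done(j))
  var height = -1
  var done = false
  while (!done) {
    res match {
      case Call(f) => res = f()
      case Done(v) => { height = v; done = true }
    }
  }
  height
}

// Hypothetical helper: a tree that is one long chain of left children,
// built with a fold so that construction itself does not recurse.
def leftSpine(depth: Int): NumTree =
  (1 to depth).foldLeft(Leaf: NumTree)((t, i) => Node(i, t, Leaf))

val deepHeight = heightWithTrampoline(leftSpine(100000))
```

The deep tree is built inside a function call rather than bound to a `val`, because echoing such a value in a notebook would invoke its recursive `toString`.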