# Continuation Passing Style

Originally by Sriram Sankaranarayanan <srirams@colorado>

Modified by Ravi Mangal <ravi.mangal@colostate>

Last Modified: Apr 2, 2025.

---

Thus far, we have built interpreters for various features in Lettuce. However, all of our interpreters depended
on recursive calls to the eval function. The use of recursion was very convenient for us to translate 
the semantics directly into a Scala program. However, this is not ideal: every recursive call pushes a new
activation record onto the stack, so evaluating a large program can cause the stack to overflow. 

In this lecture, we will revisit the theme of eliminating non-tail recursion. We have already done this using
an accumulator. However, accumulators are limited in their scope. We will now present a general scheme
that works without accumulators.

## Recap: Recursion, Tail Recursion and Eliminating the Non-Tail Recursion

We will take a few minutes to quickly recap recursion, tail recursion and the problem of eliminating
non-tail recursion.

- Recursion causes the activation records to grow on the stack, potentially causing stack overflow.
- Tail recursion is a benign case: the result of each recursive call is returned directly, without any further processing.
- Tail recursive calls can be implemented such that the activation records need not grow.
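To make the distinction concrete, here is a small sketch of our own (the names `sumTo` and `sumToTail` are not from the lecture). `sumTo` adds `n` after its recursive call returns, so it is not tail recursive; `sumToTail` carries the partial sum in an accumulator, so its recursive call is the last thing it does. The `@tailrec` annotation asks the Scala compiler to verify that the call can be compiled into a loop.

```scala
import scala.annotation.tailrec

// Not tail recursive: the addition `n + ...` happens after the recursive call
// returns, so every call keeps its own activation record on the stack.
def sumTo(n: Int): Long =
  if (n <= 0) 0L else n + sumTo(n - 1)

// Tail recursive: the recursive call is the last action, so the compiler can
// reuse the current frame. @tailrec makes compilation fail if it cannot.
@tailrec
def sumToTail(n: Int, acc: Long = 0L): Long =
  if (n <= 0) acc else sumToTail(n - 1, acc + n)
```

Both functions return 5050 for an argument of 100. For an argument of 1,000,000, `sumToTail` still runs in constant stack space, whereas `sumTo` will likely overflow the default JVM stack.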

In [1]:
/*-- Examples: As an exercise, classify these calls as tail recursive or not --*/

def rec_fun1(x: Int): Int = {
    if (x <= 0) {
        x
    } else {
        rec_fun1(x - 10)
    }
}

def rec_fun2(y: Int = 0, x: Int): Int = {
    if (x <= 10) {
        y
    } else {
        rec_fun2(y + 1, x - 10)
    }
}


def rec_fun3( x: Int): Int = {
    if (x <= 10) {
        x - 10
    } else {
        1 + rec_fun3( x - 10)
    }
}


def rec_fun4( x: Int): Int = {
    if (x <= 10) {
        x - 5
    } else {
        rec_fun1 ( rec_fun4( x - 10) )
    }
}

def foo(x: Int): Int = { x - 15}

def rec_fun5(x: Int): Int = {
    if (x <= 0){
        foo(x)
    } else {
        rec_fun5(foo(x))
    }
}


def rec_fun6(x: Int): Int = {
    if (x <= 0){
        foo(x)
    } else {
        foo(rec_fun6(x-5))
    }
}

defined function rec_fun1
defined function rec_fun2
defined function rec_fun3
defined function rec_fun4
defined function foo
defined function rec_fun5
defined function rec_fun6

We have seen examples of how to use accumulators to remove non-tail recursion. This can work in some but not all cases.


In [5]:
def factorial(n: Int): Int = {
    if (n <= 0){
        1
    } else {
        n * factorial(n-1)
    }
}

def factorial_tail(acc: Int = 1, n: Int): Int = {
    if (n <= 0) {
        acc
    } else {
        factorial_tail(acc * n , n-1)
    }
}

def fibonacci(n: Int): Int = {
    if (n < 2){
        1
    } else {
        fibonacci(n-1) + fibonacci(n-2)
    }    
}

def fibonacci_tail(n: Int, acc1: Int = 1, acc2: Int = 1): Int = {
    if (n <= 0) {
        acc1
    } else if (n == 1) {
        acc2
    } else {
        fibonacci_tail( n-1, acc2, acc1 + acc2)
    }
}





defined function factorial
defined function factorial_tail
defined function fibonacci
defined function fibonacci_tail

In [8]:
fibonacci(15)
fibonacci_tail(15)

res7_0: Int = 987
res7_1: Int = 987

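Notice that the accumulator method requires us to change the logic of the function in a non-trivial manner.
It is hard to argue that `fibonacci` and `fibonacci_tail` are the same algorithm. In fact, they are not: `fibonacci` makes two
recursive calls at each step, whereas `fibonacci_tail` makes one call and carries the last two Fibonacci numbers in its accumulators.
The change from `factorial` to `factorial_tail` also needs some insight into `factorial`: it regroups the product, which is valid only because multiplication is associative.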

Therefore, if our goal is to mechanically remove non-tail recursion and convert it into tail recursion, this does not help us.

## Continuations and Continuation Passing Style

Continuation passing style (CPS) is a "style" of programming wherein every function will have an extra argument
called the `continuation`. A continuation is a function that is passed in and specifies what the caller
wishes to do with the result that has been computed.

Take, for instance, a function `func` that takes in an integer and returns an integer.
~~~
def func(x: Int): Int = {
     // .. do some work to compute result .. 
     return result
}
~~~

In CPS, this function is now written as

~~~
def func_k(x: Int, k: Int => Int): Int = {
    //  .. do some work to compute result ..
    k(result) // Pass the result onto the continuation.
}
~~~

Note that `func_k` takes in an extra argument `k` called the continuation. It
is the function through which the caller specifies what they want done with
the result of the call. Rather than receiving the result and then operating
on it, the caller bundles up "what to do next" as a function and passes it in;
`func_k` then hands its result to that function instead of returning it.

Let us look at a concrete example. First, take a look at the three functions defined below.

In [15]:
def addUp(x: Int, y: Int, z: Int): Int = {
    x + y + z
}

def multiply(x: Int, y: Int): Int = {
    x * y
}

def madd(x: Int, y: Int, z:Int): Int = {
    val v1 = multiply(x, y)
    val v2 = addUp(v1, y, z)
    return v2;
}

defined function addUp
defined function multiply
defined function madd

In [20]:
println(madd(1,2,3))

7


Let us now create the CPS version of these functions.

In [37]:
def addUp_k(x: Int, y: Int, z:Int, k: Int => Int): Int = {
    k(x + y + z)
}

def multiply_k(x: Int, y: Int, k: Int => Int): Int = {
    k ( x * y)
}

def madd_k(x: Int, y: Int, z: Int, k: Int => Int): Int ={
    // Create a new continuation.
    // This continuation k1 is a closure that will be passed to multiply_k.
    // multiply_k calls it with the product x * y; k1 must then do the work that madd did after multiply returned.
    def k1(v1: Int): Int = addUp_k(v1, y, z, k) // Add up v1, y, z and ask addUp_k to run k on the result.
    multiply_k(x, y, k1)
}


defined function addUp_k
defined function multiply_k
defined function madd_k

In [39]:
madd_k(1, 2, 3, x => x) // the function x => x is the identity function that just returns the argument.

res38: Int = 7

A few things to notice: 
- First, the translation of `addUp` and `multiply` into `addUp_k` and `multiply_k` is straightforward. These functions get a new argument `k` for the continuation. They compute what they did originally, but instead of returning the result, they call `k` on it.
- The tricky function is `madd_k`. What did the `madd` function do? 
    - Called `multiply` on `x`, `y`.
    - Took the result and called the `addUp` function on it.
- Thus, we can write down what `madd_k` should do.
    - Call `multiply_k` on `x`, `y` and pass a continuation `k1` to `multiply_k`. What must this continuation do? 
    - The continuation `k1` should do the work `madd` would have done after the call to `multiply` returned:
       1. Call `addUp_k` on the product `v1`, `y` and `z`.
       2. Pass `k` along to `addUp_k`, so that the final result is handed on to `k`.
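Because the caller chooses the continuation, the same `madd_k` can be reused for different follow-up work without changing its body. The sketch below repeats the definitions from above so that it is self-contained; the continuations passed in at the end are arbitrary illustrations of our own.

```scala
def addUp_k(x: Int, y: Int, z: Int, k: Int => Int): Int = k(x + y + z)

def multiply_k(x: Int, y: Int, k: Int => Int): Int = k(x * y)

def madd_k(x: Int, y: Int, z: Int, k: Int => Int): Int = {
  // k1 receives the product from multiply_k and finishes madd's remaining work.
  def k1(v1: Int): Int = addUp_k(v1, y, z, k)
  multiply_k(x, y, k1)
}

// Identity continuation: just return the result of madd.
val plain = madd_k(1, 2, 3, v => v)
// A continuation that doubles the result: the caller decides what happens next.
val doubled = madd_k(1, 2, 3, v => 2 * v)
// A continuation that feeds the result into another CPS call.
val chained = madd_k(1, 2, 3, v => multiply_k(v, 10, w => w))
```

Here `plain` is 7, `doubled` is 14 and `chained` is 70.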

In [40]:
def f1(x: Int): Int =  {
    if (x <= 0){
        1
    } else {
        3 + f1(x - 10)
    }
}

def f1_k(x: Int, k: Int => Int): Int = {
    if (x <= 0){
        k(1)
    } else {
        def k1(v: Int): Int = k(3 + v) // tell what to do with the result of the call
        f1_k(x - 10, k1)
    }
}


defined function f1
defined function f1_k

In [42]:
println(f1(25))
println(f1_k(25, x => x))

10
10
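It can help to trace how `f1_k(25, x => x)` unfolds, writing $\mathit{id}$ for the identity continuation `x => x`. Each recursive call wraps the current continuation in a new one that adds 3; when the base case is reached, the stacked-up continuations are applied from the innermost outward:

$$\begin{array}{rll}
\mathtt{f1\_k}(25, \mathit{id}) & \to \mathtt{f1\_k}(15, k_1) & k_1(v) = \mathit{id}(3 + v) \\
 & \to \mathtt{f1\_k}(5, k_2) & k_2(v) = k_1(3 + v) \\
 & \to \mathtt{f1\_k}(-5, k_3) & k_3(v) = k_2(3 + v) \\
 & \to k_3(1) = k_2(4) = k_1(7) = \mathit{id}(10) = 10 &
\end{array}$$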


In [1]:
import scala.annotation.tailrec
def factorial(n: Int): Int = {
    if (n <= 0) {
        1
    } else {
        n * factorial(n-1)
    }
}

@tailrec
def factorial_k(n: Int, k: Int => Int): Int = {
    if (n <= 0) {
        k(1)
    } else {
        def k1(v: Int): Int = {
            k(v * n)
        }
        factorial_k(n-1, k1)
    }
}

import scala.annotation.tailrec
defined function factorial
defined function factorial_k

In [2]:
println(factorial(6))
println(factorial_k(6, x => x))

720
720


Note that the CPS style has important properties that you should check:
1. Every function `f_k` now has an extra argument called the continuation. 
2. The argument passed to the continuation is the result that the original function `f` would have returned.
3. The original (top-level) call passes a "terminal continuation", typically the identity `x => x`.
4. Each path through a function ends in a call (to another CPS function or to the continuation); no call happens in the middle of a computation that still needs its result.
5. __All calls are tail calls__.

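In theory, therefore, these tail calls can be optimized away by the compiler or interpreter so that the stack does not grow. In practice, it is not always that simple. For instance, the Scala compiler only eliminates *self*-recursive tail calls (a method calling itself, which `@tailrec` checks); tail calls to other functions, including calls to continuations, still push a new frame onto the JVM stack.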
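The following sketch (our own example, `sumList_k` is a hypothetical name) checks these properties on a list sum and also shows where the Scala limitation bites. Every branch ends either by calling the continuation or by a tail call to `sumList_k`, and the pending addition `head + v` is moved into the new continuation. `@tailrec` accepts the function, because its only recursive call is a self tail call; however, the continuation built up along the way is a chain of nested closures, and invoking it at the end uses one stack frame per list element.

```scala
import scala.annotation.tailrec

// CPS version of summing a list of integers.
@tailrec
def sumList_k(xs: List[Int], k: Int => Int): Int = xs match {
  case Nil => k(0) // base case: hand the result to the continuation
  case head :: tail =>
    // The work the direct-style version would do after the recursive call
    // (adding head) is moved into the new continuation.
    sumList_k(tail, v => k(head + v))
}
```

For example, `sumList_k(List(1, 2, 3, 4), v => v)` returns 10. For a list with hundreds of thousands of elements, the descent runs as a loop, but the final chain of continuation calls may still overflow the default JVM stack.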

In [47]:
def fibonacci(n: Int): Int = {
    if (n < 2){
        1
    } else {
        fibonacci(n-1) + fibonacci(n-2)
    }    
}


def fibonacci_k(n: Int, k: Int => Int): Int = {
    if (n < 2) {
        k(1)
    } else {
        // k1 is the continuation that will be passed to 
        // fibonacci(n-1). It instructs that call on what to do 
        // afterwards.
        def k1(v: Int): Int = {
            // k2 is the continuation that will be passed to 
            // fibonacci(n-2). It instructs that call on what to do 
            // afterwards.
            def k2(v2: Int): Int = {
                k (v + v2) // Just add up the two fibonacci results and pass it to k.
            }
            fibonacci_k(n-2, k2) // Call fibonacci on n-2 and execute function k2 on the result
        }
        // call fibonacci on n-1 and execute k1 on the result.
        fibonacci_k(n-1, k1)
    }
}

defined function fibonacci
defined function fibonacci_k

In [48]:
println(fibonacci(12))
println(fibonacci_k(12, x => x))

233
233


In [51]:
println(fibonacci_k(18, x => x))


In [36]:
import scala.annotation.tailrec

    
@tailrec
def fibonacci_k(n: Int, k: Int => Int): Int = {
    if (n < 2) {
        k(1)
    } else {
        fibonacci_k(n-1, {
            v1 => fibonacci_k(n-2, {
                v2 => k(v1+v2)
            })
        })
    }
}


cmd36.sc:10: could not optimize @tailrec annotated method fibonacci_k: it contains a recursive call not in tail position
            v1 => fibonacci_k(n-2, {
                             ^

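The compiler rejects the annotation because the call `fibonacci_k(n-2, ...)` sits inside the continuation closure: it is a tail call of that closure, not a self-recursive call of `fibonacci_k`, so Scala cannot turn it into a loop. One common workaround, not used in the rest of these notes, is *trampolining*. The sketch below uses the standard library's `scala.util.control.TailCalls`: instead of making each call directly, a step returns a `TailRec` value that describes either the next call (`tailcall`) or a final answer (`done`), and `.result` runs these steps in a loop. The continuation type changes from `Int => Int` to `Int => TailRec[Int]`.

```scala
import scala.util.control.TailCalls._

// Trampolined CPS Fibonacci, with the same convention as above: fib(0) = fib(1) = 1.
// Every call to fibonacci_k or to a continuation is wrapped in `tailcall`, so it is
// suspended and later resumed by the trampoline loop inside `.result`.
def fibonacci_k(n: Int, k: Int => TailRec[Int]): TailRec[Int] = {
  if (n < 2) {
    tailcall(k(1))
  } else {
    tailcall(fibonacci_k(n - 1, v1 =>
      tailcall(fibonacci_k(n - 2, v2 =>
        tailcall(k(v1 + v2))))))
  }
}

// The terminal continuation wraps the answer with `done`.
def fibonacci(n: Int): Int = fibonacci_k(n, v => done(v)).result
```

With this version, `fibonacci(18)` returns 4181 without growing the JVM stack; the price is an extra heap allocation for every suspended step.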

## Recipe for transforming to CPS.

- Add the continuation parameter to all the functions. What should the type of the continuation parameter be?
- Now transform each function to be in the CPS. 
 
### Case 1

If the function we are converting makes no other function calls, then there is not much to do. Just remember to call the continuation parameter `k` on the return value.


In [52]:
def simple_fun(x: Int): Int = {
    val y = x * x
    val z = y + y - 5 * x
    if (z <= 0)
        1
    else 
        z
}


def simple_fun_k(x: Int, k: Int => Int): Int = {
    val y = x * x
    val z = y + y - 5 * x
    if (z <= 0)
        k(1) // remember to call k on the return value 
    else 
        k(z) // remember to call k on the return value
}


defined function simple_fun
defined function simple_fun_k

### Case 2

If the function we are converting has just one function call in each branch and it is a tail call, then again, we simply convert the tail calls into their CPS version and remember to pass the continuation parameter to them.

In [54]:
def tail_call_fun(x: Int): Int = {
    if ( x >= 0) {
        simple_fun(x + 1)
    } else {
        val y = x * x - 2
        simple_fun(y)
    }
}

def tail_call_fun_k(x: Int, k: Int => Int): Int = {
    if ( x >= 0) {
        simple_fun_k(x + 1, k) // Convert to CPS version and remember to pass my own continuation in
    } else {
        val y = x * x - 2
        simple_fun_k(y, k) // Convert to CPS version and remember to pass my own continuation in
    }
}

defined function tail_call_fun
defined function tail_call_fun_k

### Case 3 

If the routine we are converting has a function call but the function call is not the last operation in some branch, then CPS transformation is a little more involved. Consider the example below:

~~~
def fun(...): R = {
  if (condA)
    block A 
  else if (condB) 
    block B
  else {
      
      block C
      
      val v = fun2(args2)
      ...
         code block D
         this block can involve v and other local variables.
      ...
      return x
  
  }

}
~~~

In this situation, we first create a closure for a new continuation and then pass this new continuation to the function being called.

~~~
def fun_k(..., k: R => R): R = {
  if (condA)
    CPS TRANSFORMED block A 
  else if (condB) 
    CPS TRANSFORMED block B
  else {
      CPS TRANSFORMED block C
      // CREATE the new continuation
      def k1(v:T): R =  { 
        ...
         CPS TRANSFORMED code block D
         this block can involve v and other local variables.
        ...
        return k(x)
      }
      fun2_k(args2, k1)
  }
}
~~~

Let us look at an example.

In [59]:
def fancy_function(x: Int, y: Int): Int = {
    if (x == 0)
        return 0
    else if (x > 0) {
        val s1 = 25
        val y1 = x * y + x - y
        s1 + y1
    } else {
        val y1 = tail_call_fun(x)
        y1 + y - 2 * x
    }
    
}

defined function fancy_function

In [58]:
def fancy_function_k(x: Int, y: Int, k: Int => Int): Int = {
    if (x == 0)
        return k(0)
    else if (x > 0) {
        val s1 = 25
        val y1 = x * y + x - y
        k(s1 + y1)
    } else {
        // Transform code after call 
        //  y1 + y - 2 * x
        def k1(y1: Int): Int = {
            k(y1 + y - 2 * x)
        }
        tail_call_fun_k(x,  k1)
    }
    
}

defined function fancy_function_k

In [60]:
def even_more_fancy(x: Int): Int = {
    val v1 = fancy_function(x, x - 2)
    val v2 = fancy_function(x-2, x)
    val v3 = tail_call_fun(v1)
    val v4 = v1 + v2 + v3
    fancy_function(v4, v3)
}

defined function even_more_fancy

In [61]:
def even_more_fancy_k(x: Int, k: Int => Int): Int = {
    def k1 (v1: Int) : Int = {
        /* CPS TRANSFORM OF 
        val v2 = fancy_function(x-2, x)
        val v3 = tail_call_fun(v1)
        val v4 = v1 + v2 + v3
        fancy_function(v4, v3)
        */
        def k2(v2: Int): Int = {
            /* CPS TRANSFORM 
             val v3 = tail_call_fun(v1)
             val v4 = v1 + v2 + v3
            fancy_function(v4, v3)
            */
            def k3(v3: Int): Int = {
                /*
                 val v4 = v1 + v2 + v3
                fancy_function(v4, v3)
                */
                val v4 = v1 + v2 + v3
                fancy_function_k(v4, v3, k) // the final call is also in CPS and receives k
            }
            tail_call_fun_k(v1, k3)
            
        }
        fancy_function_k(x-2, x, k2)
    }
    
    fancy_function_k(x, x-2, k1)  
}

defined function even_more_fancy_k

In [64]:
even_more_fancy(15)

res63: Int = 1124682442

In [65]:
even_more_fancy_k(15, x=> x)

res64: Int = 1124682442

## Polymorphic Continuations

Thus far, we have lived in the happy and lucky world where all functions had integer arguments and returned integers. Reality knocks (you down), and we have to contend with functions having many possible return types. Unfortunately, this means that the CPS transformation changes the return type of each function we transform: a CPS function returns whatever its continuation returns, and that type depends on the caller, so it is not known in advance. The following example motivates the need for polymorphic continuations. 

Consider the following example.

In [83]:
def utilityFunction(x: Int): Int = x + 2

def call1(x: String): String = (utilityFunction(x.toInt)).toString

def call2(x: Int): Float = utilityFunction(x).toFloat

def mainFunction(x: Int):String = {
    val v1 = call1(x.toString)
    val v2 = call2(x)
    v1 + v2.toString
}

defined function utilityFunction
defined function call1
defined function call2

We try to do a CPS transformation as we are used to, but sadly, we find that it does not compile.

In [83]:
def utilityFunction_k(x: Int, k: Int => Int): Int = k(x + 2)

def call1_k(x: String, k: String=> String): String = {
    utilityFunction_k(x.toInt, { v => k(v.toString)})
}

def call2_k(x: Int, k: Float => Float): Float = {
    utilityFunction_k(x, {f => k(f.toFloat)})
}


cmd83.sc:4: type mismatch;
 found   : String
 required: Int
    utilityFunction_k( x.toInt, { v => k(v.toString)})
                                        ^
cmd83.sc:8: type mismatch;
 found   : Float
 required: Int
    utilityFunction_k(x, {f => k(f.toFloat)})
                                ^


The reason is that `utilityFunction_k` is called from two different call sites, and the continuations at these sites return different types (`String` and `Float`). Therefore, we have to make `utilityFunction_k` more general.

In [87]:
def utilityFunction_k[T1](x: Int, k: Int => T1): T1 = k(x + 2)

def call1_k[T2](x: String, k: String=> T2): T2 = {
    utilityFunction_k[T2]( x.toInt, { v => k(v.toString)})
}

def call2_k[T3](x: Int, k: Float => T3): T3 = {
    utilityFunction_k[T3](x, {f => k(f.toFloat)})
}

def mainFunction_k(x: Int, k:String => String):String = {
    call1_k[String](x.toString, v1 => {
      call2_k[String](x, v2 => {
          k(v1 + v2.toString)
      })  
    })
}

mainFunction_k(25, x => x)


defined function utilityFunction_k
defined function call1_k
defined function call2_k
defined function mainFunction_k
res86_4: String = "2727.0"

In general, it is always a good idea to build the CPS transformation assuming that the continuation can have any return type.
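For instance, a polymorphic version of `factorial_k` from earlier (the type parameter `T` is our addition) lets each caller pick the answer type: one caller keeps the `Int`, another turns it into a `String`, and a third into a `Boolean`, all without touching `factorial_k` itself.

```scala
// CPS factorial whose continuation may return any type T.
def factorial_k[T](n: Int, k: Int => T): T = {
  if (n <= 0) {
    k(1)
  } else {
    // Multiply the result of the recursive call by n, then hand it to k.
    factorial_k[T](n - 1, v => k(v * n))
  }
}

// Answer type Int: the identity continuation.
val asInt: Int = factorial_k[Int](5, v => v)
// Answer type String: the continuation formats the result.
val asString: String = factorial_k[String](5, v => s"5! = $v")
// Answer type Boolean: the continuation tests the result.
val isBig: Boolean = factorial_k[Boolean](6, v => v > 500)
```

Here `asInt` is 120, `asString` is `"5! = 120"` and `isBig` is `true` (since 6! = 720).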

## CPS Interpreter For Lettuce

We are now ready (with trepidation) to write a CPS interpreter for the expression eval function. Let us go back to the simple interpreter for Lettuce with let bindings, function definitions and function calls.


$$\begin{array}{rcll}
\mathbf{Program} & \rightarrow & TopLevel(\mathbf{Expr}) \\[5pt]
\mathbf{Expr} & \rightarrow & Const(\mathbf{Number}) \\
 & | & Ident(\mathbf{Identifier}) \\
 & | & Plus(\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Minus(\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Mult (\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Geq (\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Eq (\mathbf{Expr}, \mathbf{Expr}) \\
 & | & IfThenElse(\mathbf{Expr}, \mathbf{Expr}, \mathbf{Expr}) & \text{if (expr) then expr else expr} \\
 & | & Let( \mathbf{Identifier}, \mathbf{Expr}, \mathbf{Expr}) & \text{let identifier = expr in expr} \\
 & | & FunDef( \mathbf{Identifier}, \mathbf{Expr}) & \text{function (identifier) expr } \\ 
 & | & FunCall(\mathbf{Expr}, \mathbf{Expr}) & \text{function call - identifier(expr)} \\[5pt]
\end{array}$$

In [66]:
sealed trait Program
sealed trait Expr

case class TopLevel(e: Expr) extends Program

case class Const(v: Double) extends Expr // Expr -> Const(v)
case class Ident(s: String) extends Expr // Expr -> Ident(s)

// Arithmetic Expressions
case class Plus(e1: Expr, e2: Expr) extends Expr // Expr -> Plus(Expr, Expr)
case class Minus(e1: Expr, e2: Expr) extends Expr // Expr -> Minus(Expr, Expr)
case class Mult(e1: Expr, e2: Expr) extends Expr // Expr -> Mult (Expr, Expr)

// Boolean Expressions
case class Geq(e1: Expr, e2:Expr) extends Expr
case class Eq(e1: Expr, e2: Expr) extends Expr


//If then else
case class IfThenElse(e: Expr, eIf: Expr, eElse: Expr) extends Expr

//Let bindings
case class Let(s: String, defExpr: Expr, bodyExpr: Expr) extends Expr

//Function definition
case class FunDef(param: String, bodyExpr: Expr) extends Expr

// Function call
case class FunCall(funCalled: Expr, argExpr: Expr) extends Expr

defined trait Program
defined trait Expr
defined class TopLevel
defined class Const
defined class Ident
defined class Plus
defined class Minus
defined class Mult
defined class Geq
defined class Eq
defined class IfThenElse
defined class Let
defined class FunDef
defined class FunCall

In [67]:
/* 1. Define the values */
sealed trait Value 
case class NumValue(d: Double) extends Value
case class BoolValue(b: Boolean) extends Value
/* -- Let us add Closure to the set of values --*/
case class Closure(x: String, e: Expr, pi: Map[String, Value]) extends Value
case object ErrorValue extends Value


/*2. Operators on values */

def valueToNumber(v: Value): Double = v match {
    case NumValue(d) => d
    case _ => throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a number")
}

def valueToBoolean(v: Value): Boolean = v match {
    case BoolValue(b) => b
    case _ => throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a boolean")
}

def valueToClosure(v: Value): Closure = v match {
    case Closure(x, e, pi) => Closure(x, e, pi)
    case _ =>  throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a closure")
}


defined trait Value
defined class NumValue
defined class BoolValue
defined class Closure
defined object ErrorValue
defined function valueToNumber
defined function valueToBoolean
defined function valueToClosure

In [68]:
def evalExpr(e: Expr, env: Map[String, Value]): Value =  {
    
    /* Method to deal with binary arithmetic operations */
    
    def applyArith2 (e1: Expr, e2: Expr) (fun: (Double , Double) => Double) = {
        val v1 = valueToNumber(evalExpr(e1, env))
        val v2 = valueToNumber(evalExpr(e2, env))
        val v3 = fun(v1, v2)
        NumValue(v3)
    }  /* -- We have deliberately curried the method --*/
    
    /* Helper method to deal with unary arithmetic */
    def applyArith1(e: Expr) (fun: Double => Double) = {
        val v = valueToNumber(evalExpr(e, env))
        val v1 = fun(v)
        NumValue(v1)
    }
    
    /* Helper method to deal with comparison operators */
    def applyComp(e1: Expr, e2: Expr) (fun: (Double, Double) => Boolean) = {
        val v1 = valueToNumber(evalExpr(e1, env))
        val v2 = valueToNumber(evalExpr(e2, env))
        val v3 = fun(v1, v2)
        BoolValue(v3)
    }
    
   
    e match {
        case Const(f) => NumValue(f)
        
        case Ident(x) => {
            if (env contains x) 
                env(x)
            else 
                throw new IllegalArgumentException(s"Undefined identifier $x")
        }
    
    
        case Plus(e1, e2) => applyArith2 (e1, e2) ( _ + _ )
    
        case Minus(e1, e2) => applyArith2(e1, e2) ( _ - _ )
    
        case Mult(e1, e2) =>  applyArith2(e1, e2) (_ * _)
    
    
    
        case Geq(e1, e2) => applyComp(e1, e2)(_ >= _)
    
        case Eq(e1, e2) => applyComp(e1, e2)(_ == _)
    
        
    
        case IfThenElse(e1, e2, e3) => {
            val v = evalExpr(e1, env)
            v match {
                case BoolValue(true) => evalExpr(e2, env)
                case BoolValue(false) => evalExpr(e3, env)
                case _ => throw new IllegalArgumentException(s"If-then-else condition expr: ${e1} is non-boolean -- evaluates to ${v}")
            }
        }
    
        case Let(x, e1, e2) => {
            val v1 = evalExpr(e1, env)  // eval e1
            val env2 = env + (x -> v1) // create a new extended env
            evalExpr(e2, env2) // eval e2 under that.
        }
    
        case FunDef(x, e) => {
            Closure(x, e, env) // Return a closure with the current environment.
        }
        
        case FunCall(e1, e2) => {
            val v1 = evalExpr(e1, env)
            val v2 = evalExpr(e2, env)
            v1 match {
                case Closure(x, closure_ex, closed_env) => {
                    // First extend closed_env by binding x to v2
                    val new_env = closed_env + ( x -> v2)
                    // Evaluate the body of the closure under the extended environment.
                    evalExpr(closure_ex, new_env)
                }
                case _ => throw new IllegalArgumentException(s"Function call error: expression $e1 does not evaluate to a closure")
            }
        }
    }
}

def evalProgram(p: Program) = {
    val m: Map[String, Value] = Map[String,Value]()
    p match { 
        case TopLevel(e) => evalExpr(e, m)
    }
}

defined function evalExpr
defined function evalProgram

Now we are ready to write a CPS-style interpreter. Be patient: this may not work out on the first try.

In [69]:
/*2. Operators on values */

def valueToNumberCPS[T](v: Value, k: Double => T): T = v match {
    case NumValue(d) => k(d)
    case _ => throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a number")
}

def valueToBooleanCPS[T](v: Value, k: Boolean => T): T = v match {
    case BoolValue(b) => k(b)
    case _ => throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a boolean")
}

def valueToClosureCPS[T](v: Value, k: Closure => T): T = v match {
    case Closure(x, e, pi) => k(Closure(x, e, pi))
    case _ =>  throw new IllegalArgumentException(s"Error: Asking me to convert Value: $v to a closure")
}

defined function valueToNumberCPS
defined function valueToBooleanCPS
defined function valueToClosureCPS

In [72]:
def evalExprCPS[T](e: Expr, env: Map[String, Value], k: Value => T): T =  {
    
    /* Method to deal with binary arithmetic operations */
    
    def applyArith2 (e1: Expr, e2: Expr) (fun: (Double , Double) => Double): T = {
       
        /* Direct-style code that we are transforming:
        val u1 = evalExpr(e1, env)
        val v1 = valueToNumber(u1)
        val u2 = evalExpr(e2, env)
        val v2 = valueToNumber(u2)
        val v3 = fun(v1, v2)
        NumValue(v3)
        */
        
        evalExprCPS[T] (e1,  env, {
            u1 => valueToNumberCPS[T](u1, {
                v1 => {
                    evalExprCPS[T](e2, env, {
                        u2 => {
                            valueToNumberCPS[T](u2,{
                                    v2 => {
                                        k(NumValue(fun(v1, v2)))
                                    }
                            })
                        }
                    })
                }  
            })
        })
    } 
   
    
    /* Helper method to deal with comparison operators */
    def applyComp(e1: Expr, e2: Expr) (fun: (Double, Double) => Boolean): T = {
        /* Direct-style code that we are transforming:
        val u1 = evalExpr(e1, env)
        val v1 = valueToNumber(u1)
        val u2 = evalExpr(e2, env)
        val v2 = valueToNumber(u2)
        val v3 = fun(v1, v2)
        BoolValue(v3) */
        evalExprCPS[T] (e1, env, {
            u1 => valueToNumberCPS[T](u1, { 
                v1 => evalExprCPS[T](e2, env, {
                    u2 => valueToNumberCPS[T]( u2, {
                        v2 => k(BoolValue(fun(v1, v2)))   
                    })
                })
            })
        })
    }
    
   
    e match {
        case Const(f) => k(NumValue(f))
        
        case Ident(x) => {
            if (env contains x) 
                k(env(x))
            else 
                throw new IllegalArgumentException(s"Undefined identifier $x")
        }
    
    
        case Plus(e1, e2) => applyArith2 (e1, e2) ( _ + _ )
    
        case Minus(e1, e2) => applyArith2(e1, e2) ( _ - _ )
    
        case Mult(e1, e2) =>  applyArith2(e1, e2) (_ * _)
    
    
    
        case Geq(e1, e2) => applyComp(e1, e2)(_ >= _)
    
        case Eq(e1, e2) => applyComp(e1, e2)(_ == _)
    
        
    
        case IfThenElse(e1, e2, e3) => {
            evalExprCPS(e1, env, {
                case BoolValue(true) => evalExprCPS(e2, env, k)
                case BoolValue(false) => evalExprCPS(e3, env, k)
                case _ => throw new IllegalArgumentException(s"If-then-else condition expr: ${e1} is non-boolean")
            })
        }
    
        case Let(x, e1, e2) => {
            evalExprCPS(e1, env, {
                v1 => {
                    val env2 = env + (x -> v1) // create a new extended env
                    evalExprCPS(e2, env2, k) // eval e2 under that.
                }})
        }
    
        case FunDef(x, e) => {
            k(Closure(x, e, env)) // Return a closure with the current environment.
        }
        
        case FunCall(e1, e2) => {
            evalExprCPS(e1, env, {
                v1 => {
                    evalExprCPS(e2, env, {
                        v2 => {
                            v1 match {
                                case Closure(x, closure_ex, closed_env) => {
                                    // First extend closed_env by binding x to v2
                                    val new_env = closed_env + ( x -> v2)
                                    // Evaluate the body of the closure under the extended environment.
                                    evalExprCPS(closure_ex, new_env, k)
                                }
                               case _ => throw new IllegalArgumentException(s"Function call error: expression $e1 does not evaluate to a closure")
                            }
                        }
                    })
                }               
        })
    }
}
}

def evalProgramCPS(p: Program) = {
    val m: Map[String, Value] = Map[String,Value]()
    p match { 
        case TopLevel(e) => evalExprCPS(e, m, x => x)
    }
}

defined function evalExprCPS
defined function evalProgramCPS
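The same pattern can be seen in isolation in a stripped-down sketch. It uses a hypothetical mini-language with only constants, `+` and `*` (named `AExpr`, `AConst`, `APlus`, `AMult` and `evalA` to avoid clashing with the definitions above). The evaluator is polymorphic in the answer type `T`, so the top-level continuation need not be the identity: here one caller keeps the number and another renders it as a `String`.

```scala
// Hypothetical mini-language: constants, addition and multiplication only.
sealed trait AExpr
case class AConst(v: Double) extends AExpr
case class APlus(e1: AExpr, e2: AExpr) extends AExpr
case class AMult(e1: AExpr, e2: AExpr) extends AExpr

// CPS evaluator: instead of returning a Double, pass it to the continuation k.
def evalA[T](e: AExpr, k: Double => T): T = e match {
  case AConst(v) => k(v)
  case APlus(e1, e2) =>
    // Evaluate e1, then e2, then hand their sum to k.
    evalA[T](e1, v1 => evalA[T](e2, v2 => k(v1 + v2)))
  case AMult(e1, e2) =>
    // Evaluate e1, then e2, then hand their product to k.
    evalA[T](e1, v1 => evalA[T](e2, v2 => k(v1 * v2)))
}

// (1 + 2) * 4
val expr = AMult(APlus(AConst(1.0), AConst(2.0)), AConst(4.0))
val asNumber: Double = evalA[Double](expr, v => v)
val asText: String = evalA[String](expr, v => s"result: $v")
```

Here `asNumber` is `12.0` and `asText` is `"result: 12.0"`.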

### Example 1
~~~
let square = function(x) 
                x * x
in 
    square(10) 
~~~

In [73]:
val p1 = TopLevel( 
    Let("square",                                // let square = 
         FunDef("x", Mult(Ident("x"), Ident("x"))),  //    function (x) x * x
         FunCall(Ident("square"), Const(10)) //     in  square(10)
       )
)

evalProgramCPS(p1)

p1: TopLevel = TopLevel(Let(square,FunDef(x,Mult(Ident(x),Ident(x))),FunCall(Ident(square),Const(10.0))))
res72_1: Value = NumValue(100.0)

### Example 2

~~~
let x = 10 in
    let y = 15 in 
        let sq1 = function (x)
                    function (y) 
                        x + y * y
        in 
            sq1(x)(y)
~~~

In [74]:
val x = Ident("x")
val y = Ident("y")
val fdef_inner = FunDef("y", Plus(x, Mult(y, y)))
val fdef_outer = FunDef("x", fdef_inner)
val call_expr = FunCall(FunCall(Ident("sq1"), x), y)
val sq1_call = Let("sq1", fdef_outer, call_expr)
val lety = Let("y", Const(15), sq1_call)
val letx = Let("x", Const(10), lety)
val p2 = TopLevel(letx)
evalProgramCPS(p2)

x: Ident = Ident("x")
y: Ident = Ident("y")
fdef_inner: FunDef = FunDef(y,Plus(Ident(x),Mult(Ident(y),Ident(y))))
fdef_outer: FunDef = FunDef(x,FunDef(y,Plus(Ident(x),Mult(Ident(y),Ident(y)))))
call_expr: FunCall = FunCall(FunCall(Ident(sq1),Ident(x)),Ident(y))
sq1_call: Let = Let(sq1,FunDef(x,FunDef(y,Plus(Ident(x),Mult(Ident(y),Ident(y))))),FunCall(FunCall(Ident(sq1),Ident(x)),Ident(y)))
lety: Let = Let(y,Const(15.0),Let(sq1,FunDef(x,FunDef(y,Plus(Ident(x),Mult(Ident(y),Ident(y))))),FunCall(FunCall(Ident(sq1),Ident(x)),Ident(y))))
letx: Let = Let(x,Const(10.0),Let(y,Const(15.0),Let(sq1,FunDef(x,FunDef(y,Plus(Ident(x),Mult(Ident(y),Ident(y))))),FunCall(FunCall(Ident(sq1),Ident(x)),Ident(y)))))
p2: TopLevel = TopLevel(Let(x,Const(10.0),Let(y,Const(15.0),Let(sq1,FunDef(x

### Example 3


~~~
let h = function(z) 
         z + z
in 
    let g = function(y)
                y * 2.0 + h(y * 1.5)
    in
        let f = function (x) 
                    1.0 - x + g(x)
        in 
            f(3.1415)
~~~

In [75]:
val x = Ident("x")
val y = Ident("y")
val z = Ident("z")

val fDef = FunDef("x", Plus(Minus(Const(1.0), x), FunCall(Ident("g"), x)) )
val gDef = FunDef("y", Plus(Mult(y, Const(2.0)), FunCall(Ident("h"), Mult(y, Const(1.5)))))
val hDef = FunDef("z", Plus(z, z))

val letf = Let("f", fDef, FunCall(Ident("f"), Const(3.1415)))
val letg = Let("g", gDef, letf)
val leth = Let("h", hDef, letg)

val p3 = TopLevel(leth)
evalProgramCPS(p3)

x: Ident = Ident("x")
y: Ident = Ident("y")
z: Ident = Ident("z")
fDef: FunDef = FunDef(x,Plus(Minus(Const(1.0),Ident(x)),FunCall(Ident(g),Ident(x))))
gDef: FunDef = FunDef(y,Plus(Mult(Ident(y),Const(2.0)),FunCall(Ident(h),Mult(Ident(y),Const(1.5)))))
hDef: FunDef = FunDef(z,Plus(Ident(z),Ident(z)))
letf: Let = Let(f,FunDef(x,Plus(Minus(Const(1.0),Ident(x)),FunCall(Ident(g),Ident(x)))),FunCall(Ident(f),Const(3.1415)))
letg: Let = Let(g,FunDef(y,Plus(Mult(Ident(y),Const(2.0)),FunCall(Ident(h),Mult(Ident(y),Const(1.5))))),Let(f,FunDef(x,Plus(Minus(Const(1.0),Ident(x)),FunCall(Ident(g),Ident(x)))),FunCall(Ident(f),Const(3.1415))))
leth: Let = Let(h,FunDef(z,Plus(Ident(z),Ident(z))),Let(g,FunDef(y,Plus(Mult(Ident(y),Const(2.0)),FunCall(Ident(h),Mult(Iden

### Example 4 (Bad)

~~~
let f = function (x) 
            if (0 >= x) 
                1
            else
                (x - 1)* f(x - 1 )
in 
    f(10)
~~~


In [76]:
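// This program throws an exception: `let` does not support recursion, so the identifier f is not
// bound inside the body of f (the closure captures the environment before f is added to it).
// It is interesting to look at the stack trace.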
val x = Ident("x")
val compX = Geq(Const(0), x)
val recExpr = Mult(Minus(x, Const(1.0)), FunCall(Ident("f"), Minus(x, Const(1.0))))
val f_defn = FunDef("x", IfThenElse(compX, Const(1.0), recExpr))
val letf = Let("f", f_defn, FunCall(Ident("f"), Const(10.0)))
val p4 = TopLevel(letf)
evalProgramCPS(p4)
