Implementing yield (yield return) using Scala continuations

Yang · Feb 4, 2010 · Viewed 7k times

How might one implement C#'s yield return using Scala continuations? I'd like to be able to write Scala Iterators in the same style. There is an attempt in the comments on this Scala news post, but it doesn't work (I tried it with the Scala 2.8.0 beta). Answers to a related question suggest this is possible, but although I've been playing with delimited continuations for a while, I can't quite wrap my head around how to do it.

Answer

Rich Dougherty · Feb 7, 2010

Before we introduce continuations we need to build some infrastructure. Below is a trampoline that operates on Iteration objects. An Iteration is a computation that either Yields a new value, together with a thunk that computes the next Iteration, or is Done.

sealed trait Iteration[+R]
case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
case object Done extends Iteration[Nothing]

def trampoline[R](body: => Iteration[R]): Iterator[R] = {
  // Run one step, emit its result, and defer the next step to the stream's lazy tail.
  def loop(thunk: () => Iteration[R]): Stream[R] = {
    thunk() match {
      case Yield(result, next) => Stream.cons(result, loop(next))
      case Done => Stream.empty
    }
  }
  loop(() => body).iterator
}

The trampoline uses an internal loop that turns the sequence of Iteration objects into a Stream. We then get an Iterator by calling iterator on the resulting stream. Because Stream.cons takes its tail by name, evaluation is lazy: the next iteration isn't run until it is needed.
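To make the laziness concrete, here is a small self-contained sketch that runs the trampoline over an iteration that never reaches Done. The names LazinessSketch and from are hypothetical, introduced only for this illustration, and the definitions above are repeated so the block compiles on its own.

```scala
object LazinessSketch {
  sealed trait Iteration[+R]
  case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
  case object Done extends Iteration[Nothing]

  // Same trampoline as above, repeated so the sketch is self-contained.
  // (On Scala 2.13 Stream is deprecated in favour of LazyList, but still available.)
  def trampoline[R](body: => Iteration[R]): Iterator[R] = {
    def loop(thunk: () => Iteration[R]): Stream[R] =
      thunk() match {
        case Yield(result, next) => Stream.cons(result, loop(next))
        case Done => Stream.empty
      }
    loop(() => body).iterator
  }

  // Hypothetical helper: an endless iteration n, n + 1, n + 2, ...
  def from(n: Int): Iteration[Int] = Yield(n, () => from(n + 1))

  // Terminates because later steps are only run when the iterator asks for them.
  val firstThree: List[Int] = trampoline(from(1)).take(3).toList
}
```

If the trampoline evaluated every step eagerly, building the iterator over from(1) would never return; here only a handful of steps are ever run.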

The trampoline can be used to build an iterator directly.

val itr1 = trampoline {
  Yield(1, () => Yield(2, () => Yield(3, () => Done)))
}

for (i <- itr1) { println(i) }

That's pretty horrible to write, so let's use delimited continuations to create our Iteration objects automatically.

We use the shift and reset operators to break the computation up into Iterations, then use trampoline to turn the Iterations into an Iterator.

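// Package and names as of the Scala 2.8.0 beta. In the 2.8.0 final release these
// live in scala.util.continuations, and the two-parameter annotation is cpsParam.
import scala.continuations._
import scala.continuations.ControlContext.{shift,reset}

// Runs body inside reset; when body finishes, the computation ends with Done.
def iterator[R](body: => Unit @cps[Iteration[R],Iteration[R]]): Iterator[R] =
  trampoline {
    reset[Iteration[R],Iteration[R]] { body ; Done }
  }

// Captures the rest of the computation as k and returns a Yield that resumes it on demand.
def yld[R](result: R): Unit @cps[Iteration[R],Iteration[R]] =
  shift((k: Unit => Iteration[R]) => Yield(result, () => k(())))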

Now we can rewrite our example.

val itr2 = iterator[Int] {
  yld(1)
  yld(2)
  yld(3)
}

for (i <- itr2) { println(i) }

Much better!
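If it helps to see what shift is capturing, the same three yields can be written with the continuation passed by hand and no compiler plugin. This is only an approximation of the shape the CPS transform gives the code, not the plugin's actual output. The names HandWrittenCps and yldCps are hypothetical, and the earlier definitions are repeated so the block stands alone.

```scala
object HandWrittenCps {
  sealed trait Iteration[+R]
  case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
  case object Done extends Iteration[Nothing]

  def trampoline[R](body: => Iteration[R]): Iterator[R] = {
    def loop(thunk: () => Iteration[R]): Stream[R] =
      thunk() match {
        case Yield(result, next) => Stream.cons(result, loop(next))
        case Done => Stream.empty
      }
    loop(() => body).iterator
  }

  // Hypothetical explicit-CPS counterpart of yld: the "rest of the computation"
  // is supplied by the caller as k instead of being captured by shift.
  def yldCps[R](result: R)(k: Unit => Iteration[R]): Iteration[R] =
    Yield(result, () => k(()))

  // Roughly the shape of `yld(1); yld(2); yld(3)` followed by Done inside reset.
  val values: List[Int] = trampoline[Int] {
    yldCps(1) { _ => yldCps(2) { _ => yldCps(3) { _ => Done } } }
  }.toList
}
```

Each nested function literal plays the role of the k that shift hands to yld, which is why every yld ends up as one Yield whose next thunk resumes the remainder of the block.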

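Now here's an example from the C# reference page for yield that shows some more advanced usage. The types can be a bit tricky to get used to: a helper that calls yld, such as the recursive loop below, has to declare the same @cps annotation on its return type. But it all works.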

def power(number: Int, exponent: Int): Iterator[Int] = iterator[Int] {
  def loop(result: Int, counter: Int): Unit @cps[Iteration[Int],Iteration[Int]] = {
    if (counter < exponent) {
      yld(result)
      loop(result * number, counter + 1)
    }
  }
  loop(number, 0)
}

for (i <- power(2, 8)) { println(i) }
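For comparison, here is a sketch of the same computation written directly against Iteration, without the continuations plugin. It shows what the yld-based loop unfolds into and can be run on its own. The names PowerByHand and powerByHand are hypothetical, and the earlier definitions are repeated so the block is self-contained.

```scala
object PowerByHand {
  sealed trait Iteration[+R]
  case class Yield[+R](result: R, next: () => Iteration[R]) extends Iteration[R]
  case object Done extends Iteration[Nothing]

  def trampoline[R](body: => Iteration[R]): Iterator[R] = {
    def loop(thunk: () => Iteration[R]): Stream[R] =
      thunk() match {
        case Yield(result, next) => Stream.cons(result, loop(next))
        case Done => Stream.empty
      }
    loop(() => body).iterator
  }

  // Hypothetical hand-written version of power: each step either yields the
  // current result and defers the next multiplication, or finishes with Done.
  def powerByHand(number: Int, exponent: Int): Iterator[Int] = {
    def loop(result: Int, counter: Int): Iteration[Int] =
      if (counter < exponent) Yield(result, () => loop(result * number, counter + 1))
      else Done
    trampoline(loop(number, 0))
  }

  val powersOfTwo: List[Int] = powerByHand(2, 8).toList
  // powersOfTwo == List(2, 4, 8, 16, 32, 64, 128, 256)
}
```

The missing else branch in the continuation version corresponds to the explicit Done here: when the counter reaches the exponent, the block simply ends and iterator supplies Done.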