Skip to content

Commit

Permalink
Complete ch05
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhijit Sarkar committed Dec 21, 2024
1 parent 72a9350 commit 3d7f461
Show file tree
Hide file tree
Showing 14 changed files with 433 additions and 0 deletions.
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ align.preset = more
maxColumn = 120
runner.dialect = scala3
assumeStandardLibraryStripMargin = true
# https://github.com/scalameta/scalameta/issues/4090
project.excludePaths = [
"glob:**/ch04/src/**.scala"
]
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ The older code is available in branches.

2. [Algebraic Data Types](ch02)
3. [Objects as Codata](ch03)
4. [Contextual Abstraction](ch04)
5. [Reified Interpreters](ch05)

## Running tests
```
Expand Down
48 changes: 48 additions & 0 deletions ch05/src/Expression.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package ch05

/*
The core of the interpreter strategy is a separation between description and action.
The description is the program, and the interpreter is the action that carries out the program.
This separation is allows for composition of programs, and managing effects by delaying
them till the time the program is run. We sometimes call this structure an algebra, with
constructs and combinators defining programs and destructors defining interpreters.
The interpreter is then a structural recursion over this ADT.
We saw that the straightforward implementation is not stack-safe, and which caused us to
to introduction the idea of tail recursion and continuations. We reified continuations
functions, and saw that we can convert any program into continuation-passing style which
has every method call in tail position. Due to Scala runtime limitations not all calls
in tail position can be converted to tail calls, so we reified calls and returns into
data structures used by a recursive loop called a trampoline.
*/
enum Expression:
case Literal(value: Double)
case Addition(left: Expression, right: Expression)
case Subtraction(left: Expression, right: Expression)
case Multiplication(left: Expression, right: Expression)
case Division(left: Expression, right: Expression)

def +(that: Expression): Expression =
Addition(this, that)

def -(that: Expression): Expression =
Subtraction(this, that)

def *(that: Expression): Expression =
Multiplication(this, that)

def /(that: Expression): Expression =
Division(this, that)

def eval: Double =
this match
case Literal(value) => value
case Addition(left, right) => left.eval + right.eval
case Subtraction(left, right) => left.eval - right.eval
case Multiplication(left, right) => left.eval * right.eval
case Division(left, right) => left.eval / right.eval

object Expression:
def apply(value: Double): Expression =
Literal(value)
42 changes: 42 additions & 0 deletions ch05/src/ExpressionC.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package ch05

// Continuation-Passing style.
enum ExpressionC:
case Literal(value: Double)
case Addition(left: ExpressionC, right: ExpressionC)
case Subtraction(left: ExpressionC, right: ExpressionC)
case Multiplication(left: ExpressionC, right: ExpressionC)
case Division(left: ExpressionC, right: ExpressionC)

def +(that: ExpressionC): ExpressionC =
Addition(this, that)

def -(that: ExpressionC): ExpressionC =
Subtraction(this, that)

def *(that: ExpressionC): ExpressionC =
Multiplication(this, that)

def /(that: ExpressionC): ExpressionC =
Division(this, that)

def eval: Double =
type Continuation = Double => Double

def loop(expr: ExpressionC, cont: Continuation): Double =
expr match
case Literal(value) => cont(value)
case Addition(left, right) =>
loop(left, l => loop(right, r => cont(l + r)))
case Subtraction(left, right) =>
loop(left, l => loop(right, r => cont(l - r)))
case Multiplication(left, right) =>
loop(left, l => loop(right, r => cont(l * r)))
case Division(left, right) =>
loop(left, l => loop(right, r => cont(l / r)))

loop(this, identity)

object ExpressionC:
def apply(value: Double): ExpressionC =
Literal(value)
55 changes: 55 additions & 0 deletions ch05/src/ExpressionT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package ch05

enum ExpressionT:
case Literal(value: Double)
case Addition(left: ExpressionT, right: ExpressionT)
case Subtraction(left: ExpressionT, right: ExpressionT)
case Multiplication(left: ExpressionT, right: ExpressionT)
case Division(left: ExpressionT, right: ExpressionT)

def eval: Double =
// Trampoline style.
type Continuation = Double => Call

enum Call:
case Continue(value: Double, k: Continuation)
case Loop(expr: ExpressionT, k: Continuation)
case Done(result: Double)

def loop2(left: ExpressionT, right: ExpressionT, cont: Continuation, op: (Double, Double) => Double): Call =
Call.Loop(
left,
l => Call.Loop(right, r => Call.Continue(op(l, r), cont))
)

def loop(expr: ExpressionT, cont: Continuation): Call =
expr match
case Literal(value) => Call.Continue(value, cont)
case Addition(left, right) => loop2(left, right, cont, _ + _)
case Subtraction(left, right) => loop2(left, right, cont, _ - _)
case Multiplication(left, right) => loop2(left, right, cont, _ * _)
case Division(left, right) => loop2(left, right, cont, _ / _)

def trampoline(call: Call): Double =
call match
case Call.Continue(value, k) => trampoline(k(value))
case Call.Loop(expr, k) => trampoline(loop(expr, k))
case Call.Done(result) => result

trampoline(loop(this, x => Call.Done(x)))

def +(that: ExpressionT): ExpressionT =
Addition(this, that)

def -(that: ExpressionT): ExpressionT =
Subtraction(this, that)

def *(that: ExpressionT): ExpressionT =
Multiplication(this, that)

def /(that: ExpressionT): ExpressionT =
Division(this, that)

object ExpressionT:
def apply(value: Double): ExpressionT =
Literal(value)
43 changes: 43 additions & 0 deletions ch05/src/Regexp.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ch05

enum Regexp:
case Append(left: Regexp, right: Regexp)
case OrElse(first: Regexp, second: Regexp)
case Repeat(source: Regexp)
case Apply(string: String)
case Empty

def ++(that: Regexp): Regexp =
Append(this, that)

def orElse(that: Regexp): Regexp =
OrElse(this, that)

def repeat: Regexp =
Repeat(this)

def `*`: Regexp = this.repeat

def matches(input: String): Boolean =
def loop(regexp: Regexp, idx: Int): Option[Int] =
regexp match
case Append(left, right) =>
loop(left, idx).flatMap(loop(right, _))
case OrElse(first, second) =>
loop(first, idx).orElse(loop(second, idx))
case Repeat(source) =>
loop(source, idx)
.flatMap(loop(regexp, _))
.orElse(Some(idx))
case Apply(string) =>
Option.when(input.startsWith(string, idx))(idx + string.size)
case Empty => None

// Check we matched the entire input
loop(this, 0).map(_ == input.size).getOrElse(false)

object Regexp:
val empty: Regexp = Empty

def apply(string: String): Regexp =
Apply(string)
64 changes: 64 additions & 0 deletions ch05/src/RegexpC.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package ch05

// Continuation-Passing style.
enum RegexpC:
case Append(left: RegexpC, right: RegexpC)
case OrElse(first: RegexpC, second: RegexpC)
case Repeat(source: RegexpC)
case Apply(string: String)
case Empty

def ++(that: RegexpC): RegexpC =
Append(this, that)

def orElse(that: RegexpC): RegexpC =
OrElse(this, that)

def repeat: RegexpC =
Repeat(this)

def `*`: RegexpC = this.repeat

def matches(input: String): Boolean =
// Define a type alias so we can easily write continuations.
type Continuation = Option[Int] => Option[Int]

def loop(
regexp: RegexpC,
idx: Int,
cont: Continuation
): Option[Int] =
regexp match
case Append(left, right) =>
val k: Continuation = _ match
case None => cont(None)
case Some(i) => loop(right, i, cont)
loop(left, idx, k)

case OrElse(first, second) =>
val k: Continuation = _ match
case None => loop(second, idx, cont)
case some => cont(some)
loop(first, idx, k)

case Repeat(source) =>
val k: Continuation =
_ match
case None => cont(Some(idx))
case Some(i) => loop(regexp, i, cont)
loop(source, idx, k)

case Apply(string) =>
cont(Option.when(input.startsWith(string, idx))(idx + string.size))

case Empty =>
cont(None)

// Check we matched the entire input
loop(this, 0, identity).map(_ == input.size).getOrElse(false)

object RegexpC:
val empty: RegexpC = Empty

def apply(string: String): RegexpC =
Apply(string)
85 changes: 85 additions & 0 deletions ch05/src/RegexpT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package ch05

enum RegexpT:
def ++(that: RegexpT): RegexpT =
Append(this, that)

def orElse(that: RegexpT): RegexpT =
OrElse(this, that)

def repeat: RegexpT =
Repeat(this)

def `*`: RegexpT = this.repeat

def matches(input: String): Boolean =
/*
Scala's runtimes don't support full tail calls, so calls from a continuation
to loop or from loop to a continuation will use a stack frame.
So, instead of making a call, we return a value that reifies the call we want to make.
This idea is the core of trampolining.
*/
// Define a type alias so we can easily write continuations.
type Continuation = Option[Int] => Call

enum Call:
case Loop(regexp: RegexpT, index: Int, continuation: Continuation)
case Continue(index: Option[Int], continuation: Continuation)
case Done(index: Option[Int])

def loop(regexp: RegexpT, idx: Int, cont: Continuation): Call =
regexp match
case Append(left, right) =>
val k: Continuation = _ match
case None => Call.Continue(None, cont)
case Some(i) => Call.Loop(right, i, cont)
Call.Loop(left, idx, k)

case OrElse(first, second) =>
val k: Continuation = _ match
case None => Call.Loop(second, idx, cont)
case some => Call.Continue(some, cont)
Call.Loop(first, idx, k)

case Repeat(source) =>
val k: Continuation =
_ match
case None => Call.Continue(Some(idx), cont)
case Some(i) => Call.Loop(regexp, i, cont)
Call.Loop(source, idx, k)

// The following could directly call 'cont' with the Option
// if Scala had support for full tail calls.
case Apply(string) =>
Call.Continue(
Option.when(input.startsWith(string, idx))(idx + string.size),
cont
)

case Empty =>
Call.Continue(None, cont)

def trampoline(next: Call): Option[Int] =
next match
case Call.Loop(regexp, index, continuation) =>
trampoline(loop(regexp, index, continuation))
case Call.Continue(index, continuation) =>
trampoline(continuation(index))
case Call.Done(index) => index

// Check we matched the entire input
trampoline(loop(this, 0, Call.Done(_)))
.map(_ == input.size)
.getOrElse(false)

case Append(left: RegexpT, right: RegexpT)
case OrElse(first: RegexpT, second: RegexpT)
case Repeat(source: RegexpT)
case Apply(string: String)
case Empty

object RegexpT:
val empty: RegexpT = Empty

def apply(string: String): RegexpT =
Apply(string)
11 changes: 11 additions & 0 deletions ch05/test/src/ExpressionCSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ch05

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.shouldBe

class ExpressionCSpec extends AnyFunSpec:
describe("ExpressionC"):
it("eval"):
val fortyTwo = ((ExpressionC(15.0) + ExpressionC(5.0)) * ExpressionC(2.0) + ExpressionC(2.0)) / ExpressionC(1.0)
fortyTwo.eval shouldBe 42.0d

11 changes: 11 additions & 0 deletions ch05/test/src/ExpressionSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ch05

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.shouldBe

class ExpressionSpec extends AnyFunSpec:
describe("Expression"):
it("eval"):
val fortyTwo = ((Expression(15.0) + Expression(5.0)) * Expression(2.0) + Expression(2.0)) / Expression(1.0)
fortyTwo.eval shouldBe 42.0d

11 changes: 11 additions & 0 deletions ch05/test/src/ExpressionTSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ch05

import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers.shouldBe

class ExpressionTSpec extends AnyFunSpec:
describe("ExpressionT"):
it("eval"):
val fortyTwo = ((ExpressionT(15.0) + ExpressionT(5.0)) * ExpressionT(2.0) + ExpressionT(2.0)) / ExpressionT(1.0)
fortyTwo.eval shouldBe 42.0d

Loading

0 comments on commit 3d7f461

Please sign in to comment.