Skip to content

Commit

Permalink
Complete chapter 4: Monads
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhijit Sarkar committed Jan 15, 2024
1 parent 5e3a9b8 commit 132b29d
Show file tree
Hide file tree
Showing 16 changed files with 491 additions and 18 deletions.
16 changes: 16 additions & 0 deletions src/main/scala/ch04/Eval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package ch04

import cats.Eval as CatsEval

/*
4.6.5 Exercise: Safer Folding using Eval
The naive implementation of foldRight below is not stack safe. Make it so using Eval.
*/
object Eval:
def foldRight[A, B](as: List[A], acc: B)(fn: (A, B) => B): B =
def foldR(xs: List[A]): CatsEval[B] =
xs match
case head :: tail => CatsEval.defer(foldR(tail).map(fn(head, _)))
case Nil => CatsEval.now(acc)

foldR(as).value
28 changes: 28 additions & 0 deletions src/main/scala/ch04/Monad.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package ch04

trait Monad[F[_]]:
def pure[A](a: A): F[A]

def flatMap[A, B](value: F[A])(f: A => F[B]): F[B]

/*
4.1.2 Exercise: Getting Func-y
Every monad is also a functor. We can define map in the same way
for every monad using the existing methods, flatMap and pure.
Try defining map yourself now.
*/
def map[A, B](value: F[A])(f: A => B): F[B] =
flatMap(value)(a => pure(f(a)))

object MonadInstances:
type Id[A] = A

/*
4.3.1 Exercise: Monadic Secret Identities
Implement pure, map, and flatMap for Id!
*/
given idMonad: Monad[Id] with
def pure[A](a: A): Id[A] = a

def flatMap[A, B](value: Id[A])(f: A => Id[B]): Id[B] =
f(value)
19 changes: 19 additions & 0 deletions src/main/scala/ch04/MonadError.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package ch04

import cats.{Monad, MonadError as CatsMonadError}
import cats.syntax.applicativeError.catsSyntaxApplicativeErrorId

/*
4.5.4 Exercise: Abstracting
Implement a method validateAdult with the following signature
def validateAdult[F[_]](age: Int)(implicit me: MonadError[F, Throwable]): F[Int]
When passed an age greater than or equal to 18 it should return that value as a success.
Otherwise it should return a error represented as an IllegalArgumentException.
*/
object MonadError:
def validateAdult[F[_]](age: Int)(implicit me: CatsMonadError[F, Throwable]): F[Int] =
if age >= 18
then Monad[F].pure(age)
else new IllegalArgumentException("Age must be greater than or equal to 18").raiseError[F, Int]
40 changes: 40 additions & 0 deletions src/main/scala/ch04/Reader.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ch04

import cats.data.Reader as CatsReader
import cats.syntax.applicative.catsSyntaxApplicativeId

/*
4.8.3 Exercise: Hacking on Readers
The classic use of Readers is to build programs that accept a configuration as a parameter.
Let's ground this with a complete example of a simple login system.
Our configuration will consist of two databases: a list of valid users and a list of their passwords.
Start by creating a type alias DbReader for a Reader that consumes a Db as input.
Now create methods that generate DbReaders to look up the username for an Int user ID,
and look up the password for a String username.
Finally create a checkLogin method to check the password for a given user ID.
*/
object Reader:
final case class Db(
usernames: Map[Int, String],
passwords: Map[String, String]
)

type DbReader[A] = CatsReader[Db, A]

def findUsername(userId: Int): DbReader[Option[String]] =
CatsReader(_.usernames.get(userId))

def checkPassword(username: String, password: String): DbReader[Boolean] =
CatsReader(_.passwords.get(username).contains(password))

def checkLogin(userId: Int, password: String): DbReader[Boolean] =
for {
username <- findUsername(userId)
passwordOk <- username
.map(checkPassword(_, password))
.getOrElse(false.pure[DbReader])
} yield passwordOk
58 changes: 58 additions & 0 deletions src/main/scala/ch04/State.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package ch04

import cats.data.State as CatsState
import cats.syntax.applicative.catsSyntaxApplicativeId
import cats.syntax.apply.catsSyntaxApplyOps

/*
4.9.3 Exercise: Post-Order Calculator
Let's write an interpreter for post-order expressions.
We can parse each symbol into a State instance representing
a transformation on the stack and an intermediate result.
Start by writing a function evalOne that parses a single symbol into an instance of State.
If the stack is in the wrong configuration, it's OK to throw an exception.
*/
object State:
type Stack = List[Int]
type CalcState[A] = CatsState[Stack, A]

def eval(sym: String, s: Stack): Stack =
s match
case x :: y :: s1 =>
sym match
case "+" => x + y :: s1
case "-" => y - x :: s1
case "*" => x * y :: s1
case "/" if x != 0 => y / x :: s1
case "/" => sys.error("divide by zero")
case _ => sys.error("bad expression")

def evalOne(sym: String): CalcState[Int] =
for
s <- CatsState.get[Stack]
s1 = sym match
case x if x.forall(Character.isDigit) => x.toInt :: s
case x => eval(sym, s)

_ <- CatsState.set[Stack](s1)
yield s1.head

/*
Generalise this example by writing an evalAll method that computes the result of a List[String].
Use evalOne to process each symbol, and thread the resulting State monads together using flatMap.
*/
def evalAll(input: List[String]): CalcState[Int] =
input.foldLeft(0.pure[CalcState]) { (s, x) =>
// We discard the value, but must use the previous
// state for the next computation.
// Simply invoking evalOne will create a new state.
s *> evalOne(x)
}

/*
Complete the exercise by implementing an evalInput function that splits an input String into symbols,
calls evalAll, and runs the result with an initial stack.
*/
def evalInput(input: String): Int =
evalAll(input.split(" ").toList).runA(Nil).value
43 changes: 43 additions & 0 deletions src/main/scala/ch04/Tree.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ch04

import cats.Monad as CatsMonad

/*
4.10.1 Exercise: Branching out Further with Monads
Let's write a Monad for the Tree data type given below.
Verify that the code works on instances of Branch and Leaf,
and that the Monad provides Functor-like behaviour for free.
Also verify that having a Monad in scope allows us to use for comprehensions,
despite the fact that we haven’t directly implemented flatMap or map on Tree.
Don't feel you have to make tailRecM tail-recursive. Doing so is quite difficult.
*/
sealed trait Tree[+A]

final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

final case class Leaf[A](value: A) extends Tree[A]

def branch[A](left: Tree[A], right: Tree[A]): Tree[A] =
Branch(left, right)

def leaf[A](value: A): Tree[A] =
Leaf(value)

given CatsMonad[Tree] with
override def pure[A](x: A): Tree[A] =
Leaf(x)

override def flatMap[A, B](t: Tree[A])(f: A => Tree[B]): Tree[B] =
t match
case Leaf(x) => f(x)
case Branch(l, r) => Branch(flatMap(l)(f), flatMap(r)(f))

// Not stack-safe!
override def tailRecM[A, B](a: A)(f: A => Tree[Either[A, B]]): Tree[B] =
flatMap(f(a)):
case Left(value) => tailRecM(value)(f)
case Right(value) => Leaf(value)
27 changes: 27 additions & 0 deletions src/main/scala/ch04/Writer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package ch04

import cats.data.Writer as CatsWriter
import cats.syntax.applicative.catsSyntaxApplicativeId
import cats.syntax.writer.catsSyntaxWriterId

/*
4.7.3 Exercise: Show Your Working
Rewrite factorial so it captures the log messages in a Writer.
Demonstrate that this allows us to reliably separate the logs for concurrent computations.
*/
object Writer:
def slowly[A](body: => A): A =
try body
finally Thread.sleep(100)

type Logged[A] = CatsWriter[Vector[String], A]

def factorial(n: Int): Logged[Int] =
for
ans <-
if (n == 0)
then 1.pure[Logged]
else slowly(factorial(n - 1).map(_ * n))
_ <- Vector(s"fact $n $ans").tell
yield ans
146 changes: 146 additions & 0 deletions src/main/scala/ch04/ch04.worksheet.sc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import cats.Eval
import cats.data.{Reader, Writer, State}
import cats.syntax.applicative.catsSyntaxApplicativeId
import cats.syntax.writer.catsSyntaxWriterId

// ----------------------------------------------------------------------------
// Eval
// ----------------------------------------------------------------------------

// call-by-value which is eager and memoized
val now = Eval.now(math.random + 1000)
// call-by-name which is lazy and not memoized
val always = Eval.always(math.random + 3000)
// call-by-need which is lazy and memoized
val later = Eval.later(math.random + 2000)

now.value
always.value
later.value

val greeting = Eval
.always{ println("Step 1"); "Hello" }
.map{ str => println("Step 2"); s"$str world" }

greeting.value
// Step 1
// Step 2
// res16: String = "Hello world"

val ans = for {
a <- Eval.now{ println("Calculating A"); 40 }
b <- Eval.always{ println("Calculating B"); 2 }
} yield {
println("Adding A and B")
a + b
}

ans.value // first access
// Calculating B
// Adding A and B
// res17: Int = 42 // first access
ans.value // second access
// Calculating B
// Adding A and B
// res18: Int = 42

val saying = Eval
.always{ println("Step 1"); "The cat" }
.map{ str => println("Step 2"); s"$str sat on" }
.memoize
.map{ str => println("Step 3"); s"$str the mat" }

saying.value // first access
// Step 1
// Step 2
// Step 3
// res19: String = "The cat sat on the mat" // first access
saying.value // second access
// Step 3
// res20: String = "The cat sat on the mat"

// stack-safe
def factorial(n: BigInt): Eval[BigInt] =
if(n == 1) {
Eval.now(n)
} else {
Eval.defer(factorial(n - 1).map(_ * n))
}

factorial(50000).value

// stack-safe foldRight
ch04.Eval.foldRight((1 to 100000).toList, 0L)(_ + _)

// ----------------------------------------------------------------------------
// Writer
// ----------------------------------------------------------------------------
type Logged[A] = Writer[Vector[String], A]

123.pure[Logged]

Vector("msg1", "msg2", "msg3").tell

val b = 123.writer(Vector("msg1", "msg2", "msg3"))

val writer1 = for {
a <- 10.pure[Logged]
_ <- Vector("a", "b", "c").tell
b <- 32.writer(Vector("x", "y", "z"))
} yield a + b

writer1.run

val writer2 = writer1.mapWritten(_.map(_.toUpperCase))

writer2.run

val writer3 = writer1.bimap(
log => log.map(_.toUpperCase),
res => res * 100
)

writer3.run

val writer5 = writer1.reset

writer5.run

// ----------------------------------------------------------------------------
// Reader
// ----------------------------------------------------------------------------
final case class Cat(name: String, favoriteFood: String)

val catName: Reader[Cat, String] =
Reader(cat => cat.name)

val greetKitty: Reader[Cat, String] =
catName.map(name => s"Hello ${name}")

val feedKitty: Reader[Cat, String] =
Reader(cat => s"Have a nice bowl of ${cat.favoriteFood}")

val greetAndFeed: Reader[Cat, String] =
for {
greet <- greetKitty
feed <- feedKitty
} yield s"$greet. $feed."

greetAndFeed(Cat("Garfield", "lasagne"))

// ----------------------------------------------------------------------------
// State
// ----------------------------------------------------------------------------

val a = State[Int, String] { state =>
(state, s"The state is $state")
}

// Get the state and the result
val (state, result) = a.run(10).value

// Get the state, ignore the result
val justTheState = a.runS(10).value

// Get the result, ignore the state
val justTheResult = a.runA(10).value
Loading

0 comments on commit 132b29d

Please sign in to comment.