Skip to content

Commit

Permalink
Add support for serialization in Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
romac committed Jan 30, 2024
1 parent 39127af commit bb17309
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 37 deletions.
51 changes: 47 additions & 4 deletions core/src/main/scala/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ object Backend:
for inboxes <- LocalBackend.makeInboxes(locs)
yield LocalBackend(inboxes)

def http[M[_]: Concurrent](
locs: List[Loc]
): HTTPBackend[M] =
HTTPBackend(locs)

class LocalBackend[M[_]](inboxes: Map[Loc, Queue[M, Any]]):
val locs = inboxes.keys.toSeq

Expand All @@ -37,19 +42,19 @@ class LocalBackend[M[_]](inboxes: Map[Loc, Queue[M, Any]]):
case NetworkSig.Run(ma) =>
ma

case NetworkSig.Send(a, to) =>
case NetworkSig.Send(a, to, ser) =>
val inbox = inboxes.get(to).get
inbox.offer(a)

case NetworkSig.Recv(from) =>
case NetworkSig.Recv(from, ser) =>
val inbox = inboxes.get(at).get
inbox.take.map(_.asInstanceOf[A])

case NetworkSig.Broadcast(a) =>
case NetworkSig.Broadcast(a, ser) =>
locs
.filter(_ != at)
.traverse_ { to =>
run(at, inboxes)(NetworkSig.Send(a, to))
run(at, inboxes)(NetworkSig.Send(a, to, ser))
}

object LocalBackend:
Expand All @@ -63,3 +68,41 @@ object LocalBackend:
extension (backend: LocalBackend[M])
def runNetwork[A](at: Loc)(network: Network[M, A]): M[A] =
runNetwork(at)(network)

class HTTPBackend[M[_]](locs: List[Loc]):
def runNetwork[A](at: Loc)(
network: Network[M, A]
)(using M: Monad[M]): M[A] =
network.foldMap(run(at, locs).toFunctionK)

private[choreo] def run(
at: Loc,
locs: List[Loc]
)(using M: Monad[M]): [A] => NetworkSig[M, A] => M[A] = [A] =>
(na: NetworkSig[M, A]) =>
na match
case NetworkSig.Run(ma) =>
ma

case NetworkSig.Send(a, to, ser) =>
val encoded = ser.encode(a)
// TODO: send to network
M.pure(())

case NetworkSig.Recv(from, ser) =>
val encoded: ser.Encoding = ??? // TODO: receive from network
val value = ser.decode(encoded).get
M.pure(value)

case NetworkSig.Broadcast(a, ser) =>
locs
.filter(_ != at)
.traverse_ { to =>
run(at, locs)(NetworkSig.Send(a, to, ser))
}

object HTTPBackend:
given backend[M[_]: Monad]: Backend[HTTPBackend[M], M] with
extension (backend: HTTPBackend[M])
def runNetwork[A](at: Loc)(network: Network[M, A]): M[A] =
runNetwork(at)(network)
39 changes: 26 additions & 13 deletions core/src/main/scala/Choreo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,39 @@ enum ChoreoSig[M[_], A]:
case Local[M[_], A, L <: Loc](l: L, m: Unwrap[L] => M[A])
extends ChoreoSig[M, A @@ L]

case Comm[M[_], A, L0 <: Loc, L1 <: Loc](l0: L0, a: A @@ L0, l1: L1)
extends ChoreoSig[M, A @@ L1]

case Cond[M[_], A, B, L <: Loc](l: L, a: A @@ L, f: A => Choreo[M, B])
extends ChoreoSig[M, B]
case Comm[M[_], A, L0 <: Loc, L1 <: Loc](
l0: L0,
a: A @@ L0,
l1: L1,
s: Serialize[A]
) extends ChoreoSig[M, A @@ L1]

case Cond[M[_], A, B, L <: Loc](
l: L,
a: A @@ L,
f: A => Choreo[M, B],
s: Serialize[A]
) extends ChoreoSig[M, B]

extension [L <: Loc](l: L)
def locally[M[_], A](m: Unwrap[l.type] ?=> M[A]): Choreo[M, A @@ l.type] =
Free.liftF(ChoreoSig.Local[M, A, l.type](l, un => m(using un)))

def send[A](a: A @@ L): Sendable[A, L] = (l, a)
def send[A: Serialize](a: A @@ L): Sendable[A, L] = (l, a)

def cond[M[_], A, B](a: A @@ L)(f: A => Choreo[M, B]): Choreo[M, B] =
Free.liftF(ChoreoSig.Cond(l, a, f))
def cond[M[_], A, B](a: A @@ L)(
f: A => Choreo[M, B]
)(using s: Serialize[A]): Choreo[M, B] =
Free.liftF(ChoreoSig.Cond(l, a, f, s))

opaque type Sendable[A, L <: Loc] = (L, A @@ L)

extension [A, Src <: Loc](s: Sendable[A, Src])
def to[M[_], Dst <: Loc](dst: Dst): Choreo[M, A @@ dst.type] =
Free.liftF(ChoreoSig.Comm(s._1, s._2, dst))
extension [A, Src <: Loc](sendable: Sendable[A, Src])
def to[M[_], Dst <: Loc](dst: Dst)(using
s: Serialize[A]
): Choreo[M, A @@ dst.type] =
val (a, src) = sendable
Free.liftF(ChoreoSig.Comm(a, src, dst, s))

extension [M[_], A](c: Choreo[M, A])
def runLocal(using M: Monad[M]): M[A] =
Expand All @@ -55,8 +68,8 @@ extension [M[_], A](c: Choreo[M, A])
case ChoreoSig.Local(l, m) =>
m(unwrap).map(wrap(_).asInstanceOf)

case ChoreoSig.Comm(l0, a, l1) =>
case ChoreoSig.Comm(l0, a, l1, s) =>
M.pure(wrap(unwrap(a)).asInstanceOf)

case ChoreoSig.Cond(l, a, f) =>
case ChoreoSig.Cond(l, a, f, s) =>
f(unwrap(a)).runLocal
36 changes: 21 additions & 15 deletions core/src/main/scala/Network.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import choreo.utils.toFunctionK

enum NetworkSig[M[_], A]:
case Run(ma: M[A]) extends NetworkSig[M, A]
case Send(a: A, to: Loc) extends NetworkSig[M, Unit]
case Recv(from: Loc) extends NetworkSig[M, A]
case Broadcast(a: A) extends NetworkSig[M, Unit]
case Send(a: A, to: Loc, ser: Serialize[A]) extends NetworkSig[M, Unit]
case Recv(from: Loc, ser: Serialize[A]) extends NetworkSig[M, A]
case Broadcast(a: A, ser: Serialize[A]) extends NetworkSig[M, Unit]

type Network[M[_], A] = Free[[X] =>> NetworkSig[M, X], A]

Expand All @@ -23,14 +23,14 @@ object Network:
def run[M[_], A](ma: M[A]): Network[M, A] =
Free.liftF(NetworkSig.Run(ma))

def send[M[_], A](a: A, to: Loc): Network[M, Unit] =
Free.liftF(NetworkSig.Send(a, to))
def send[M[_], A](a: A, to: Loc)(using ser: Serialize[A]): Network[M, Unit] =
Free.liftF(NetworkSig.Send(a, to, ser))

def recv[M[_], A](from: Loc): Network[M, A] =
Free.liftF(NetworkSig.Recv(from))
def recv[M[_], A](from: Loc)(using ser: Serialize[A]): Network[M, A] =
Free.liftF(NetworkSig.Recv(from, ser))

def broadcast[M[_], A](a: A): Network[M, Unit] =
Free.liftF(NetworkSig.Broadcast(a))
def broadcast[M[_], A](a: A)(using ser: Serialize[A]): Network[M, Unit] =
Free.liftF(NetworkSig.Broadcast(a, ser))

def empty[M[_], A, L <: Loc]: Network[M, A @@ L] =
Network.pure(At.empty[A, L])
Expand All @@ -48,13 +48,19 @@ object Endpoint:
if at == loc then Network.run(m(unwrap)).map(wrap.asInstanceOf)
else Network.empty.asInstanceOf

case ChoreoSig.Comm(src, a, dst) =>
case ChoreoSig.Comm(src, a, dst, ser) =>
if at == src then
Network.send(unwrap(a), dst) *> Network.empty.asInstanceOf
else if at == dst then Network.recv(src).map(wrap.asInstanceOf)
Network.send(unwrap(a), dst)(using
ser
) *> Network.empty.asInstanceOf
else if at == dst then
Network.recv(src)(using ser).map(wrap.asInstanceOf)
else Network.empty[M, a.Value, a.Location]

case ChoreoSig.Cond(loc, a, f) =>
case ChoreoSig.Cond(loc, a, f, ser) =>
if at == loc then
Network.broadcast(unwrap(a)) *> project(f(unwrap(a)), at)
else Network.recv(loc).flatMap(a => project(f(a.asInstanceOf), at))
Network.broadcast(unwrap(a))(using ser) *> project(f(unwrap(a)), at)
else
Network
.recv(loc)(using ser)
.flatMap(a => project(f(a.asInstanceOf), at))
15 changes: 15 additions & 0 deletions core/src/main/scala/Serialize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package choreo

trait Serialize[A]:
type Encoding
def encode(a: A): Encoding
def decode(encoded: Encoding): Option[A]

object Serialize:
def identity[A] = new Serialize[A]:
type Encoding = A
def encode(a: A) = a
def decode(encoded: Encoding) = Some(encoded)

object identities:
given [A]: Serialize[A] = identity[A]
9 changes: 8 additions & 1 deletion examples/src/main/scala/Bookseller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ val buyer: "buyer" = "buyer"
val seller: "sender" = "sender"

def main: IO[Unit] =
import choreo.Serialize.identities.given

for
backend <- Backend.local(List(buyer, seller))

Expand All @@ -29,7 +31,12 @@ def main: IO[Unit] =
_ <- (sellerIO, buyerIO).parTupled
yield ()

def protocol: Choreo[IO, Option[Date @@ "buyer"]] =
def protocol(using
Serialize[Boolean],
Serialize[String],
Serialize[Double],
Serialize[Date]
): Choreo[IO, Option[Date @@ "buyer"]] =
for
titleB <- buyer.locally:
IO.print("Enter book title: ") *> IO.readLine
Expand Down
15 changes: 11 additions & 4 deletions examples/src/main/scala/KV.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@ val client: "client" = "client"
val server: "server" = "server"

def main: IO[Unit] =
import choreo.Serialize.identities.given

for
backend <- Backend.local(List(client, server))
clientTask = choreo.run(backend, client)
serverTask = choreo.run(backend, server)
clientTask = app.run(backend, client)
serverTask = app.run(backend, server)
_ <- (clientTask, serverTask).parTupled
yield ()

def choreo: Choreo[IO, Unit] =
def app(using Serialize[Request], Serialize[Response]): Choreo[IO, Unit] =
for
stateS <- server.locally(Ref.of[IO, State](Map.empty))
_ <- step(stateS).foreverM
yield ()

def step(stateS: Ref[IO, State] @@ "server"): Choreo[IO, Unit] =
def step(
stateS: Ref[IO, State] @@ "server"
)(using Serialize[Request], Serialize[Response]): Choreo[IO, Unit] =
for
reqC <- client.locally(readRequest)
resC <- kvs(reqC, stateS)
Expand All @@ -44,6 +48,9 @@ def step(stateS: Ref[IO, State] @@ "server"): Choreo[IO, Unit] =
def kvs(
reqC: Request @@ "client",
stateS: Ref[IO, State] @@ "server"
)(using
Serialize[Request],
Serialize[Response]
): Choreo[IO, Response @@ "client"] =
for
reqS <- client.send(reqC).to(server)
Expand Down

0 comments on commit bb17309

Please sign in to comment.