Skip to content

Commit

Permalink
[ESI] Add hostmem write support to cosim
Browse files Browse the repository at this point in the history
  • Loading branch information
teqdruid committed Jan 9, 2025
1 parent 4e47877 commit 5fa406f
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 31 deletions.
47 changes: 47 additions & 0 deletions frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,52 @@ def construct(ports):
mem_data_ce.assign(hostmem_read_resp_valid)


class WriteMem(Module):
"""Writes a cycle count to host memory at address 0 in MMIO upon each MMIO
transaction."""
clk = Clock()
rst = Reset()

@generator
def construct(ports):
cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
resp_ready_wire = Wire(Bits(1))
cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
mmio_xact = cmd_valid & resp_ready_wire

write_loc_ce = mmio_xact & cmd.write & (cmd.offset == UInt(32)(0))
write_loc = Reg(UInt(64),
clk=ports.clk,
rst=ports.rst,
rst_value=0,
ce=write_loc_ce)
write_loc.assign(cmd.data.as_uint())

response_data = write_loc.as_bits()
response_chan, response_ready = Channel(Bits(64)).wrap(
response_data, cmd_valid)
resp_ready_wire.assign(response_ready)

mmio_rw = esi.MMIO.read_write(appid=AppID("WriteMem"))
mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
cmd_chan_wire.assign(mmio_rw_cmd_chan)

tag = Counter(8)(clk=ports.clk, rst=ports.rst, increment=mmio_xact)

cycle_counter = Counter(64)(clk=ports.clk,
rst=ports.rst,
increment=Bits(1)(1))

hostmem_write_req, _ = esi.HostMem.wrap_write_req(
write_loc,
cycle_counter.out.as_bits(),
tag.out,
valid=mmio_xact.reg(ports.clk, ports.rst))

hostmem_write_resp = esi.HostMem.write(appid=AppID("WriteMem_hostwrite"),
req=hostmem_write_req)


class EsiTesterTop(Module):
clk = Clock()
rst = Reset()
Expand All @@ -122,6 +168,7 @@ class EsiTesterTop(Module):
def construct(ports):
PrintfExample(clk=ports.clk, rst=ports.rst)
ReadMem(clk=ports.clk, rst=ports.rst)
WriteMem(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
80 changes: 73 additions & 7 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from __future__ import annotations

from ..common import Clock, Input, Output, Reset
from ..constructs import AssignableSignal, ControlReg, NamedWire, Wire
from ..constructs import AssignableSignal, NamedWire, Wire
from .. import esi
from ..module import Module, generator, modparams
from ..signals import BitsSignal, BundleSignal, ChannelSignal
from ..support import clog2
from ..types import (Array, Bits, Bundle, BundledChannel, Channel,
ChannelDirection, StructType, Type, UInt)
ChannelDirection, StructType, UInt)

from typing import Dict, List, Tuple
import typing
Expand Down Expand Up @@ -266,34 +266,100 @@ class ChannelHostMemImpl(esi.ServiceImplementation):
clk = Clock()
rst = Reset()

UpstreamReq = StructType([
UpstreamReadReq = StructType([
("address", UInt(64)),
("length", UInt(32)),
("tag", UInt(8)),
])
read = Output(
Bundle([
BundledChannel("req", ChannelDirection.TO, UpstreamReq),
BundledChannel("req", ChannelDirection.TO, UpstreamReadReq),
BundledChannel(
"resp", ChannelDirection.FROM,
StructType([
("tag", UInt(8)),
("data", Bits(read_width)),
])),
]))
UpstreamWriteReq = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Bits(write_width)),
])
write = Output(
Bundle([
BundledChannel("req", ChannelDirection.TO, UpstreamWriteReq),
BundledChannel("ackTag", ChannelDirection.FROM, UInt(8)),
]))

@generator
def generate(ports, bundles: esi._ServiceGeneratorBundles):
read_reqs = [req for req in bundles.to_client_reqs if req.port == 'read']
ports.read = ChannelHostMemImpl.build_tagged_read_mux(ports, read_reqs)
write_reqs = [
req for req in bundles.to_client_reqs if req.port == 'write'
]
ports.write = ChannelHostMemImpl.build_tagged_write_mux(ports, write_reqs)

@staticmethod
def build_tagged_write_mux(
ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal:
"""Build the write side of the HostMem service."""

# If there's no write clients, just return a no-op write bundle
if len(reqs) == 0:
req, _ = Channel(ChannelHostMemImpl.UpstreamWriteReq).wrap(
{
"address": 0,
"tag": 0,
"data": 0
}, 0)
write_bundle, _ = ChannelHostMemImpl.write.type.pack(req=req)
return write_bundle

# TODO: mux together multiple write clients.
assert len(reqs) == 1, "Only one write client supported for now."

# Build the write request channels and ack wires.
write_channels: List[ChannelSignal] = []
write_acks = []
for req in reqs:
# Get the request channel and its data type.
reqch = [c.channel for c in req.type.channels if c.name == 'req'][0]
data_type = reqch.inner_type.data
assert data_type == Bits(
write_width
), f"Gearboxing not yet supported. Client {req.client_name}"

# Write acks to be filled in later.
write_ack = Wire(Channel(UInt(8)))
write_acks.append(write_ack)

# Pack up the bundle and assign the request channel.
write_req_bundle_type = esi.HostMem.write_req_bundle_type(data_type)
bundle_sig, froms = write_req_bundle_type.pack(ackTag=write_ack)
tagged_client_req = froms["req"]
req.assign(bundle_sig)
write_channels.append(tagged_client_req)

# TODO: re-write the tags and store the client and client tag.

# Build a channel mux for the write requests.
tagged_write_channel = esi.ChannelMux(write_channels)
upstream_write_bundle, froms = ChannelHostMemImpl.write.type.pack(
req=tagged_write_channel)
ack_tag = froms["ackTag"]
# TODO: decode the ack tag and assign it to the correct client.
write_acks[0].assign(ack_tag)
return upstream_write_bundle

@staticmethod
def build_tagged_read_mux(
ports, reqs: List[esi._OutputBundleSetter]) -> BundleSignal:
"""Build the read side of the HostMem service."""

if len(reqs) == 0:
req, req_ready = Channel(ChannelHostMemImpl.UpstreamReq).wrap(
req, req_ready = Channel(ChannelHostMemImpl.UpstreamReadReq).wrap(
{
"tag": 0,
"length": 0,
Expand All @@ -305,7 +371,7 @@ def build_tagged_read_mux(
# TODO: mux together multiple read clients.
assert len(reqs) == 1, "Only one read client supported for now."

req = Wire(Channel(ChannelHostMemImpl.UpstreamReq))
req = Wire(Channel(ChannelHostMemImpl.UpstreamReadReq))
read_bundle, froms = ChannelHostMemImpl.read.type.pack(req=req)
resp_chan_ready = Wire(Bits(1))
resp_data, resp_valid = froms["resp"].unwrap(resp_chan_ready)
Expand Down Expand Up @@ -335,7 +401,7 @@ def build_tagged_read_mux(

# Assign the multiplexed read request to the upstream request.
req.assign(
client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReq({
client_req.transform(lambda r: ChannelHostMemImpl.UpstreamReadReq({
"address": r.address,
"length": 1,
"tag": r.tag
Expand Down
6 changes: 6 additions & 0 deletions frontends/PyCDE/src/pycde/bsp/cosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ def build(ports):
resp_wire.type)
resp_wire.assign(data)

ack_wire = Wire(Channel(UInt(8)))
write_req = hostmem.write.unpack(ackTag=ack_wire)['req']
ack_tag = esi.CallService.call(esi.AppID("__cosim_hostmem_write"),
write_req, UInt(8))
ack_wire.assign(ack_tag)

class ESI_Cosim_Top(Module):
clk = Clock()
rst = Input(Bits(1))
Expand Down
40 changes: 25 additions & 15 deletions frontends/PyCDE/src/pycde/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,24 +519,34 @@ class _HostMem(ServiceDecl):
("tag", UInt(8)),
])

WriteReqType = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Any()),
])

def __init__(self):
super().__init__(self.__class__)

def write_req_bundle_type(self, data_type: Type) -> Bundle:
"""Build a write request bundle type for the given data type."""
write_req_type = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data_type),
])
return Bundle([
BundledChannel("req", ChannelDirection.FROM, write_req_type),
BundledChannel("ackTag", ChannelDirection.TO, UInt(8))
])

def write_req_channel_type(self, data_type: Type) -> StructType:
"""Return a write request struct type for 'data_type'."""
return StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data_type),
])

def wrap_write_req(self, address: UIntSignal, data: Type, tag: UIntSignal,
valid: BitsSignal) -> Tuple[ChannelSignal, BitsSignal]:
"""Create the proper channel type for a write request and use it to wrap the
given request arguments. Returns the Channel signal and a ready bit."""
inner_type = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", data.type),
])
inner_type = self.write_req_channel_type(data.type)
return Channel(inner_type).wrap(
inner_type({
"address": address,
Expand All @@ -548,10 +558,10 @@ def write(self, appid: AppID, req: ChannelSignal) -> ChannelSignal:
"""Create a write request to the host memory out of a request channel."""
self._materialize_service_decl()

write_bundle_type = Bundle([
BundledChannel("req", ChannelDirection.FROM, _HostMem.WriteReqType),
BundledChannel("ackTag", ChannelDirection.TO, UInt(8))
])
# Extract the data type from the request channel and call the helper to get
# the write bundle type for the req channel.
req_data_type = req.type.inner_type.data
write_bundle_type = self.write_req_bundle_type(req_data_type)

bundle = cast(
BundleSignal,
Expand Down
5 changes: 3 additions & 2 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,8 +809,9 @@ def unpack(self, **kwargs: ChannelSignal) -> Dict[str, ChannelSignal]:
raise ValueError(
f"Missing channel values for {', '.join(from_channels.keys())}")

unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels],
self.value, operands)
with get_user_loc():
unpack_op = esi.UnpackBundleOp([bc.channel._type for bc in to_channels],
self.value, operands)

to_channels_results = unpack_op.toChannels
ret = {
Expand Down
7 changes: 4 additions & 3 deletions frontends/PyCDE/src/pycde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,9 +858,10 @@ def pack(
if len(to_channels) > 0:
raise ValueError(f"Missing channels: {', '.join(to_channels.keys())}")

pack_op = esi.PackBundleOp(self._type,
[bc.channel._type for bc in from_channels],
operands)
with get_user_loc():
pack_op = esi.PackBundleOp(self._type,
[bc.channel._type for bc in from_channels],
operands)

return BundleSignal(pack_op.bundle, self), Bundle.PackSignalResults(
[_FromCirctValue(c) for c in pack_op.fromChannels], self)
Expand Down
4 changes: 2 additions & 2 deletions frontends/PyCDE/test/test_esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ def build(ports):
# CHECK-NEXT: [[R5:%.+]] = hwarith.constant 0 : ui256
# CHECK-NEXT: [[R6:%.+]] = hw.struct_create ([[R0]], [[R4]], [[R5]]) : !hw.struct<address: ui64, tag: ui8, data: ui256>
# CHECK-NEXT: %chanOutput_0, %ready_1 = esi.wrap.vr [[R6]], %false : !hw.struct<address: ui64, tag: ui8, data: ui256>
# CHECK-NEXT: [[R7:%.+]] = esi.service.req <@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: !esi.any>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: !esi.any>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: [[R7:%.+]] = esi.service.req <@_HostMem::@write>(#esi.appid<"host_mem_write_req">) : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: ui256>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK-NEXT: %ackTag = esi.bundle.unpack %chanOutput_0 from [[R7]] : !esi.bundle<[!esi.channel<!hw.struct<address: ui64, tag: ui8, data: ui256>> from "req", !esi.channel<ui8> to "ackTag"]>
# CHECK: esi.service.std.hostmem @_HostMem
@unittestmodule(esi_sys=True)
class HostMemReq(Module):
Expand Down
56 changes: 56 additions & 0 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,14 @@ struct HostMemReadResp {
uint64_t data;
uint8_t tag;
};

struct HostMemWriteReq {
uint64_t data;
uint8_t tag;
uint64_t address;
};

using HostMemWriteResp = uint8_t;
#pragma pack(pop)

class CosimHostMem : public HostMem {
Expand Down Expand Up @@ -465,6 +473,34 @@ class CosimHostMem : public HostMem {
*readRespPort, *readReqPort));
read->connect([this](const MessageData &req) { return serviceRead(req); },
true);

// Setup the write side callback.
ChannelDesc writeArg, writeResp;
if (!rpcClient->getChannelDesc("__cosim_hostmem_write.arg", writeArg) ||
!rpcClient->getChannelDesc("__cosim_hostmem_write.result", writeResp))
throw std::runtime_error("Could not find HostMem channels");

const esi::Type *writeRespType =
getType(ctxt, new StructType(writeResp.type(),
{{"tag", new UIntType("ui8", 8)},
{"data", new BitsType("i64", 64)}}));
const esi::Type *writeReqType =
getType(ctxt, new StructType(writeArg.type(),
{{"address", new UIntType("ui64", 64)},
{"length", new UIntType("ui32", 32)},
{"tag", new UIntType("ui8", 8)}}));

// Get ports, create the function, then connect to it.
writeRespPort = std::make_unique<WriteCosimChannelPort>(
rpcClient->stub.get(), writeResp, writeRespType,
"__cosim_hostmem_write.result");
writeReqPort = std::make_unique<ReadCosimChannelPort>(
rpcClient->stub.get(), writeArg, writeReqType,
"__cosim_hostmem_write.arg");
write.reset(CallService::Callback::get(acc, AppID("__cosim_hostmem_write"),
*writeRespPort, *writeReqPort));
write->connect([this](const MessageData &req) { return serviceWrite(req); },
true);
}

// Service the read request as a callback. Simply reads the data from the
Expand All @@ -491,6 +527,23 @@ class CosimHostMem : public HostMem {
return MessageData::from(resp);
}

// Service a write request as a callback. Simply write the data to the
// location specified. TODO: check that the memory has been mapped.
MessageData serviceWrite(const MessageData &reqBytes) {
const HostMemWriteReq *req = reqBytes.as<HostMemWriteReq>();
acc.getLogger().debug(
[&](std::string &subsystem, std::string &msg,
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "HostMem";
msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" +
toHex(req->data) + " tag=" + std::to_string(req->tag);
});
uint64_t *dataPtr = reinterpret_cast<uint64_t *>(req->address);
*dataPtr = req->data;
HostMemWriteResp resp = req->tag;
return MessageData::from(resp);
}

struct CosimHostMemRegion : public HostMemRegion {
CosimHostMemRegion(std::size_t size) {
ptr = malloc(size);
Expand Down Expand Up @@ -530,6 +583,9 @@ class CosimHostMem : public HostMem {
std::unique_ptr<WriteCosimChannelPort> readRespPort;
std::unique_ptr<ReadCosimChannelPort> readReqPort;
std::unique_ptr<CallService::Callback> read;
std::unique_ptr<WriteCosimChannelPort> writeRespPort;
std::unique_ptr<ReadCosimChannelPort> writeReqPort;
std::unique_ptr<CallService::Callback> write;
};

} // namespace
Expand Down
Loading

0 comments on commit 5fa406f

Please sign in to comment.