Skip to content
This repository has been archived by the owner on May 5, 2024. It is now read-only.

Commit

Permalink
parameterize everything by wE, wF
Browse files Browse the repository at this point in the history
remove clock enable
create tests
  • Loading branch information
makslevental committed Aug 2, 2022
1 parent 32aca33 commit e1c02e6
Show file tree
Hide file tree
Showing 47 changed files with 1,758 additions and 616 deletions.
33 changes: 11 additions & 22 deletions bragghls/compiler/compile.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,22 @@
import argparse
import ast
import importlib.util
import io
import os
import shutil
from subprocess import Popen, PIPE

import astor

import bragghls.runner
import bragghls.state
from bragghls.parse import parse_mlir_module
from bragghls.ir.parse import parse_mlir_module
from bragghls.ir.transforms import transform_forward, rewrite_schedule_vals
from bragghls.rtl.emit_verilog import emit_verilog
from bragghls.runner import Forward, get_default_args
from bragghls.testbench.tb_runner import testbench_runner
from bragghls.transforms import transform_forward, rewrite_schedule_vals
from bragghls.util import import_module_from_fp, import_module_from_string
from scripts.hack_affine_scf import scf_to_affine


def import_module_from_string(name: str, source: str):
spec = importlib.util.spec_from_loader(name, loader=None)
module = importlib.util.module_from_spec(spec)
exec(source, module.__dict__)
return module


def import_module_from_fp(name: str, fp: str):
spec = importlib.util.spec_from_file_location(name, fp)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod


def translate(affine_mlir_str):
p = Popen(
[
Expand Down Expand Up @@ -106,7 +91,6 @@ def main(args):
with open(f"{artifacts_dir}/{name}_pythonized_mlir.py", "r") as f:
pythonized_mlir = f.read()

output_name = "UNKNOWN"
if args.rewrite:
rewritten_py_code = rewrite(pythonized_mlir)
if DEBUG:
Expand Down Expand Up @@ -158,7 +142,8 @@ def main(args):

verilog_file, input_wires, output_wires, max_fsm_stage = emit_verilog(
name,
args.wE + args.wF + 3,
args.wE,
args.wF,
op_id_data,
func_args,
returns,
Expand All @@ -177,11 +162,13 @@ def main(args):
if args.testbench:
testbench_runner(
proj_path=f"{artifacts_dir}",
module_fp=os.path.abspath(f"{artifacts_dir}/{name}_rewritten.py"),
sv_file_name=f"{name}.sv",
top_level=name,
py_module=f"{name}_tb",
max_fsm_stage=max_fsm_stage,
output_name=output_name,
wE=args.wE,
wF=args.wF,
)

os.remove(f"{artifacts_dir}/{name}_rewritten.mlir")
Expand All @@ -195,8 +182,10 @@ def main(args):
parser.add_argument("-r", "--rewrite", default=False, action="store_true")
parser.add_argument("-s", "--schedule", default=False, action="store_true")
parser.add_argument("-v", "--verilog", default=False, action="store_true")
parser.add_argument("-b", "--testbench", default=False, action="store_true")
parser.add_argument("--wE", default=4)
parser.add_argument("--wF", default=4)
parser.add_argument("-b", "--testbench", default=False, action="store_true")
args = parser.parse_args()
args.wE = int(args.wE)
args.wF = int(args.wF)
main(args)
8 changes: 3 additions & 5 deletions bragghls/flopoco/convert_flopoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
import flopoco_converter


def convert_float_to_flopoco_binary_str(f, width_exp=4, width_frac=4):
def convert_float_to_flopoco_binary_str(f, width_exp, width_frac):
s = flopoco_converter.fp2binstr(width_exp, width_frac, str(f))
assert len(s) == width_exp + width_frac + 2 + 1
return s


def convert_flopoco_binary_str_to_float(s, width_exp=4, width_frac=4):
def convert_flopoco_binary_str_to_float(s, width_exp, width_frac):
assert len(s) == width_exp + width_frac + 2 + 1
return float(flopoco_converter.bin2fp(width_exp, width_frac, s))


if __name__ == "__main__":
print(convert_flopoco_binary_str_to_float("01011010001"))
for i in range(100):
print("#", i, f"b{convert_float_to_flopoco_binary_str(i)}")
print("fmul_x", convert_flopoco_binary_str_to_float("0101000000000", 5, 5))
21 changes: 11 additions & 10 deletions bragghls/flopoco/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@
import numpy as np

import bragghls.state
from bragghls.ops import chunks
from bragghls.util import idx_to_str
from bragghls.util import idx_to_str, chunks

try:
from . import flopoco_converter
except:
import flopoco_converter

WE = int(os.getenv("WE", "4"))
WF = int(os.getenv("WE", "4"))
WE = int(os.getenv("WE"))
WF = int(os.getenv("WF"))

FPNUMBER = namedtuple("FPNUMBER", "pe_idx")(None)

Expand All @@ -37,8 +36,8 @@ def ReduceAdd(vals):
@dataclass(frozen=True)
class Val:
ieee: float
wE: int = WE
wF: int = WF
wE: int
wF: int
fp: flopoco_converter.FPNumber = None
name: str = None

Expand Down Expand Up @@ -111,6 +110,7 @@ def __init__(
self.registers = np.empty(shape, dtype=object)
self.input = input
self.output = output
assert wE is not None and wF is not None
self.wE = wE
self.wF = wF

Expand Down Expand Up @@ -186,6 +186,7 @@ def __init__(self, global_name, global_array: np.ndarray, wE=WE, wF=WF):
self.global_array = global_array
self.shape = global_array.shape
self.vals = np.empty(self.shape, dtype=object)
assert wE is not None and wF is not None
for idx, v in np.ndenumerate(global_array):
v = Val(v, wE, wF)
try:
Expand Down Expand Up @@ -220,12 +221,12 @@ def __repr__(self):


class FMAC:
wE = WE
wF = WF

def __init__(self, *pe_idx):
def __init__(self, *pe_idx, wE=WE, wF=WF):
assert wE is not None and wF is not None
assert pe_idx
self.pe_idx = pe_idx
self.wE = wE
self.wF = wF
self.result = Val(0, self.wE, self.wF)

def Add(self, a, b):
Expand Down
Empty file added bragghls/ir/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion bragghls/memref.py → bragghls/ir/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from bragghls import state
from bragghls.ops import make_constant, Val, ReduceAdd
from bragghls.ir.ops import Val, make_constant, ReduceAdd
from bragghls.util import idx_to_str

MemRefIndex = Tuple[int, ...]
Expand Down
File renamed without changes.
17 changes: 6 additions & 11 deletions bragghls/ops.py → bragghls/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np

from bragghls import state
from bragghls.state import DTYPE, CONSTANT
from bragghls.util import extend_idx
from bragghls.state import DTYPE, CONSTANT, ADD_LATENCY, MUL_LATENCY
from bragghls.util import extend_idx, chunks


def overload_op(type):
Expand Down Expand Up @@ -97,14 +97,14 @@ def emit(self):
return f'{self.res} = "{self.type.value}" ({args_str}) {{ {attrs_str} }} : ({", ".join([DTYPE] * len(self.args))}) -> {DTYPE}'


FMAC_LATENCY = lambda n_elements: 3 * n_elements + 2
FMAC_LATENCY = lambda n_elements: MUL_LATENCY + ADD_LATENCY * n_elements


class Latencies:
latencies = {
OpType.ADD: 3,
OpType.SUB: 3,
OpType.MUL: 2,
OpType.ADD: ADD_LATENCY,
OpType.SUB: ADD_LATENCY,
OpType.MUL: MUL_LATENCY,
OpType.DIV: 3,
OpType.GT: 1,
OpType.NEG: 1,
Expand Down Expand Up @@ -231,11 +231,6 @@ def Result(self, copy=True):
return op_res


def chunks(lst, n):
for i in range(0, len(lst), n):
yield lst[i : i + n]


def reducer(accum, val):
if isinstance(val[0], Val):
state.state.update_current_pe_idx(val=val[0])
Expand Down
2 changes: 1 addition & 1 deletion bragghls/parse.py → bragghls/ir/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from torch_mlir._mlir_libs._mlir.ir import Context, Module, OpView, FunctionType

from bragghls.ops import OPS, OpType, Op, LATENCIES
from bragghls.ir.ops import OpType, OPS, Op, LATENCIES


def traverse_op_region_block_iterators(op, handler):
Expand Down
2 changes: 1 addition & 1 deletion bragghls/transforms.py → bragghls/ir/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import astor

from bragghls.parse import parse_mlir_module, reg_idents
from bragghls.ir.parse import parse_mlir_module, reg_idents


class RemoveMAC(ast.NodeTransformer):
Expand Down
67 changes: 24 additions & 43 deletions bragghls/rtl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
from textwrap import dedent, indent

from bragghls.flopoco.convert_flopoco import convert_float_to_flopoco_binary_str
from bragghls.state import USING_FLOPOCO


@dataclass(frozen=True)
class Wire:
id: str
bit_width: int
signal_width: int

def __str__(self):
return f"{self.id}"

def instantiate(self):
if self.bit_width > 1:
return f"wire [{self.bit_width - 1}:0] {self};"
if self.signal_width > 1:
return f"wire [{self.signal_width - 1}:0] {self};"
else:
return f"wire {self};"

Expand All @@ -26,26 +27,30 @@ def __lt__(self, other):
@dataclass(frozen=True)
class Reg:
id: str
bit_width: int
signal_width: int

def __str__(self):
return f"{self.id}"

def instantiate(self):
if self.bit_width > 1:
return f"reg [{self.bit_width - 1}:0] {self};"
if self.signal_width > 1:
return f"reg [{self.signal_width - 1}:0] {self};"
else:
return f"reg {self};"


def make_constant(v, precision):
if v is None:
# return f"{precision}'d{random.randint(0, 2 ** precision - 1)}"
return f"{precision}'b01001110000"
def make_constant(v, width_exp, width_frac):
if USING_FLOPOCO:
signal_width = width_exp + width_frac + 3
else:
# %val_cst_00
assert isinstance(v, (float, int))
return f"{precision}'b{convert_float_to_flopoco_binary_str(v)}"
raise Exception("not using flopoco thus invalid signal width")

if v is None:
v = 0.0

# %val_cst_00
assert isinstance(v, (float, int))
return f"{signal_width}'b{convert_float_to_flopoco_binary_str(v, width_exp, width_frac)}"


class CombOrSeq(enum.Enum):
Expand Down Expand Up @@ -77,46 +82,22 @@ def make_always_branch(lefts, rights, cond, comb_or_seq=CombOrSeq.SEQ):
)


def make_fmac_branches(pe, fsm_states, init_val, args):
def make_fmac_branches(pe, fmul_states, fadd_states, init_val, args):
return indent(
dedent(
"\n".join(
[
f"""\
if (1'b1 == {fsm_states[0]}) begin
{pe.fmul.x} <= {args[0]};
{pe.fmul.y} <= {args[1]};
{pe.fmul.ce} <= 1;
end
if (1'b1 == {fsm_states[1]}) begin
{pe.fadd.x} <= {init_val};
{pe.fadd.y} <= {pe.fmul.r};
{pe.fadd.ce} <= 1;
end
"""
]
+ [
f"""\
if (1'b1 == {fsm_state}) begin
{pe.fmul.x} <= {args[2 * (i + 1)]};
{pe.fmul.y} <= {args[2 * (i + 1) + 1]};
{pe.fmul.ce} <= 1;
{pe.fadd.ce} <= 1;
{pe.fmul.x} <= {args[2 * i]};
{pe.fmul.y} <= {args[2 * i + 1]};
end
if (1'b1 == {fsm_states[2 * i + 2 + 1]}) begin
{pe.fadd.x} <= {pe.fadd.r};
if (1'b1 == {fadd_states[i]}) begin
{pe.fadd.x} <= {init_val if i == 0 else pe.fadd.r};
{pe.fadd.y} <= {pe.fmul.r};
{pe.fadd.ce} <= 1;
end
"""
for i, fsm_state in enumerate(fsm_states[2:-1:2])
]
+ [
f"""\
if (1'b1 == {fsm_states[-1]}) begin
{pe.fadd.ce} <= 1;
end
"""
for i, fsm_state in enumerate(fmul_states)
]
)
),
Expand Down
Loading

0 comments on commit e1c02e6

Please sign in to comment.