Skip to content
This repository has been archived by the owner on May 5, 2024. It is now read-only.

Commit

Permalink
support memref::subview
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Mar 13, 2024
1 parent 86f0a37 commit 5d7666f
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 90 deletions.
8 changes: 7 additions & 1 deletion openhls/compiler/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import os
from textwrap import indent, dedent
from multiprocessing.pool import ThreadPool

import numpy as np

Expand Down Expand Up @@ -118,7 +119,8 @@ def parfor(**kwargs):
kwargs = tuple(tuple(zip_with_scalar(k, range(*v))) for k, v in kwargs.items())

def wrapper(body):
for args in itertools.product(*kwargs):
def worker(args):
print(f"executing {args=}")
idx = tuple(i for arg, i in args)
pe_idx = extend_idx(idx)
state.state.update_current_pe_idx(pe_idx=pe_idx)
Expand All @@ -134,6 +136,10 @@ def wrapper(body):
else:
body(**dict(args))

with ThreadPool(processes=64) as pool:
result = pool.map_async(worker, list(itertools.product(*kwargs)))
print(result.get())

return wrapper


Expand Down
126 changes: 76 additions & 50 deletions openhls/compiler/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

import networkx as nx
from contextlib import contextmanager
from threading import RLock

from openhls.config import VAL_PREFIX, DTYPE, DEBUG, INCLUDE_AUX_DEPS
from openhls.util import extend_idx
Expand All @@ -17,7 +17,6 @@
class State:
_var_count = 0
_op_call_count = 0
op_graph = nx.MultiDiGraph()
cst_map = {}
cst_count = 0
_pe_idx = (0,)
Expand All @@ -26,85 +25,104 @@ class State:
pe_idx_to_most_recent_op_id = {}
op_id_to_pe_idx = {}
pe_deps = set()
rlock = None

def __init__(self, output_file):
self.op_graph.add_nodes_from(
[INPUT_ARG, MEMREF_ARG, GLOBAL_MEMREF_ARG, CONSTANT]
)
self.output_file = output_file
self.rlock = RLock()

@contextmanager
def with_rlock(self):
    """Hold ``self.rlock`` for the duration of the ``with`` body.

    The lock is acquired before the body runs and released when the body
    exits, including when it exits by raising, so a failing caller cannot
    leave the lock held.

    ``self.rlock`` is assigned in ``__init__``; the class-level default is
    ``None``, so this must only be used on an initialised instance.
    """
    # Entering the lock as a context manager pairs the acquire with a
    # guaranteed release. A bare acquire / yield / release sequence skips
    # the release whenever the body raises.
    with self.rlock:
        yield

def incr_var(self):
self._var_count += 1
with self.with_rlock():
self._var_count += 1

@property
def curr_var_id(self):
return self._var_count
with self.with_rlock():
return self._var_count

def incr_op_id(self):
self._op_call_count += 1
with self.with_rlock():
self._op_call_count += 1

@property
def curr_op_id(self):
return self._op_call_count
with self.with_rlock():
return self._op_call_count

def emit(self, *args):
print(*args, file=self.output_file)
with self.with_rlock():
print(*args, file=self.output_file)

def debug_print(self, *args):
if DEBUG:
self.emit(*(["//"] + list(args)))
with self.with_rlock():
self.emit(*(["//"] + list(args)))

def add_val_source(self, v, src):
self.val_source[v] = src
with self.with_rlock():
self.val_source[v] = src

def add_global_memref_arg(self, v):
self.val_source[v] = GLOBAL_MEMREF_ARG
with self.with_rlock():
self.val_source[v] = GLOBAL_MEMREF_ARG

def add_memref_arg(self, v):
self.val_source[v] = MEMREF_ARG
with self.with_rlock():
self.val_source[v] = MEMREF_ARG

def add_constant(self, v):
self.val_source[v] = CONSTANT
with self.with_rlock():
self.val_source[v] = CONSTANT

def add_op_res(self, v, op):
self.val_source[v] = op
with self.with_rlock():
self.val_source[v] = op

def maybe_add_op(self, op):
if op not in self.op_graph.nodes:
self.op_graph.add_node(op)
pass

def add_edge(self, op, arg, out_v):
val_source = self.get_arg_src(arg)
self.op_graph.add_edge(val_source, op, input=arg, output=out_v, id=op.op_id)
pass

def update_most_recent_pe_idx(self, pe_idx, op):
self.pe_idx_to_most_recent_op_id[pe_idx] = op.op_id
with self.with_rlock():
self.pe_idx_to_most_recent_op_id[pe_idx] = op.op_id

def get_most_recent_op_id(self, pe_idx):
return self.pe_idx_to_most_recent_op_id[pe_idx]
with self.with_rlock():
return self.pe_idx_to_most_recent_op_id[pe_idx]

def maybe_add_aux_dep(self, pe_idx, op):
if pe_idx in self.pe_idx_to_most_recent_op_id:
prev_op_id = self.get_most_recent_op_id(pe_idx)
self.pe_deps.add((prev_op_id, op.op_id))
self.update_most_recent_pe_idx(pe_idx, op)
with self.with_rlock():
if pe_idx in self.pe_idx_to_most_recent_op_id:
prev_op_id = self.get_most_recent_op_id(pe_idx)
self.pe_deps.add((prev_op_id, op.op_id))
self.update_most_recent_pe_idx(pe_idx, op)

def get_arg_src(self, arg):
assert arg in self.val_source
return self.val_source[arg]
with self.with_rlock():
assert arg in self.val_source
return self.val_source[arg]

def update_current_pe_idx(self, *, pe_idx=None, val=None):
assert pe_idx is not None or val is not None
if val is not None:
src = self.get_arg_src(val)
if isinstance(src, str):
assert src in {INPUT_ARG, MEMREF_ARG, GLOBAL_MEMREF_ARG, CONSTANT}
if src in {MEMREF_ARG, GLOBAL_MEMREF_ARG}:
self.pe_idx = extend_idx(tuple(map(int, val.id.split("_"))))
with self.with_rlock():
assert pe_idx is not None or val is not None
if val is not None:
src = self.get_arg_src(val)
if isinstance(src, str):
assert src in {INPUT_ARG, MEMREF_ARG, GLOBAL_MEMREF_ARG, CONSTANT}
if src in {MEMREF_ARG, GLOBAL_MEMREF_ARG}:
self.pe_idx = extend_idx(tuple(map(int, val.id.split("_"))))
else:
self.pe_idx = src.pe_idx
else:
self.pe_idx = src.pe_idx
else:
self.pe_idx = pe_idx
self.pe_idx = pe_idx

@property
def dtype(self):
Expand All @@ -120,33 +138,41 @@ def val_prefix(self):

@property
def pe_idx(self):
return self._pe_idx
with self.with_rlock():
return self._pe_idx

@pe_idx.setter
def pe_idx(self, x):
self._pe_idx = x
with self.with_rlock():
self._pe_idx = x

def map_val_to_pe(self, v, pe_idx):
self.val_to_pe_idx[v] = pe_idx
with self.with_rlock():
self.val_to_pe_idx[v] = pe_idx

def get_val_pe(self, v):
return self.val_to_pe_idx[v]
with self.with_rlock():
return self.val_to_pe_idx[v]

def swap_output_file(self, new_file):
old_file = self.output_file
self.output_file = new_file
return old_file
with self.with_rlock():
old_file = self.output_file
self.output_file = new_file
return old_file

def read_output_file(self):
self.output_file.seek(0)
return self.output_file.read()
with self.with_rlock():
self.output_file.seek(0)
return self.output_file.read()

@property
def num_unique_pes(self):
return len(set(self.val_to_pe_idx.values()))
with self.with_rlock():
return len(set(self.val_to_pe_idx.values()))

def __del__(self):
self.output_file.close()
with self.with_rlock():
self.output_file.close()


state = None
14 changes: 12 additions & 2 deletions openhls/ir/memref.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from dataclasses import dataclass
from typing import Tuple

Expand Down Expand Up @@ -81,9 +82,18 @@ def reduce_add(self):
def reduce_max(self):
return ReduceMax(list(self.registers.flatten()))

def alias(self, other_memref):
def alias(self, other_memref, offsets=None, sizes=None, strides=None):
assert isinstance(other_memref, MemRef)
self.registers = other_memref.registers
if offsets is not None and sizes is not None and strides is not None:
subview = []
for o, si, st in zip(offsets, sizes, strides):
subview.append(slice(o, o + si, st))
print("subview", subview, file=sys.stderr)
print("before subview", self.registers.shape, file=sys.stderr)
self.registers = other_memref.registers[tuple(subview)]
print("aftier subview", self.registers.shape, file=sys.stderr)
else:
self.registers = other_memref.registers


class GlobalMemRef:
Expand Down
35 changes: 22 additions & 13 deletions openhls_translate/EmitHLSPy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class ModuleEmitter : public OpenHLSEmitterBase {
void emitLoad(memref::LoadOp op);
void emitStore(memref::StoreOp op);
void emitMemCpy(memref::CopyOp op);
void emitMemSubview(memref::SubViewOp op);
void emitGlobal(memref::GlobalOp op);
void emitGetGlobal(memref::GetGlobalOp op);
void emitTensorStore(memref::TensorStoreOp op);
Expand Down Expand Up @@ -420,6 +421,7 @@ class StmtVisitor : public HLSVisitorBase<StmtVisitor, bool> {
bool visitOp(memref::StoreOp op) { return emitter.emitStore(op), true; }
bool visitOp(memref::DeallocOp op) { return true; }
bool visitOp(memref::CopyOp op) { return emitter.emitMemCpy(op), true; }
bool visitOp(memref::SubViewOp op) { return emitter.emitMemSubview(op), true; }
bool visitOp(memref::GlobalOp op) { return emitter.emitGlobal(op), true; }
bool visitOp(memref::GetGlobalOp op) {
return emitter.emitGetGlobal(op), true;
Expand Down Expand Up @@ -1169,33 +1171,40 @@ void ModuleEmitter::emitStore(memref::StoreOp op) {
}

// Emits a memref copy as one output line of the form
// `<target>.alias(<source>)`: the target value, the literal ".alias(",
// the source value, a closing ")" and a newline.
// NOTE(review): the commented-out statements below look like leftovers from
// an earlier memcpy-style emission — confirm before deleting them.
void ModuleEmitter::emitMemCpy(memref::CopyOp op) {
// indent() << "memcpy(";
indent() << "";
// emitValue(op.target());
// os << " = ";
emitValue(op.target());
os << ".alias(";
emitValue(op.getSource());
os << ")";
// os << ", ";
os << "\n";
}

// auto type = op.target().getType().cast<MemRefType>();
// os << type.getNumElements() << " * sizeof(" << getTypeName(op.target())
// << "))";
// os << "\n";
// Emits a memref subview as two output lines: first whatever emitArrayDecl
// produces for the op's result, then a call of the form
// `<result>.alias(<source>, offsets=..., sizes=..., strides=...)`.
// NOTE(review): only the static offsets/sizes/strides accessors are printed;
// presumably dynamic operands are not supported here — confirm.
void ModuleEmitter::emitMemSubview(memref::SubViewOp op) {
  // Line 1: declaration for the subview result.
  indent() << "";
  emitArrayDecl(op.getResult());
  os << "\n";

  // Line 2: alias the result onto a slice of the source.
  indent() << "";
  emitValue(op.result());
  os << ".alias(";
  emitValue(op.getSource());
  os << ", offsets=" << op.getStaticOffsets()
     << ", sizes=" << op.getStaticSizes()
     << ", strides=" << op.getStaticStrides()
     << ")"
     << "\n";
}

void ModuleEmitter::emitGlobal(memref::GlobalOp op) {
auto initial_val = op.initial_value();
auto elem = initial_val->dyn_cast<DenseFPElementsAttr>();
os << op.sym_name().str() << " = np.array([";
for (const auto &item : elem.getValues<FloatAttr>())
os << item.getValueAsDouble() << ", ";
os << "]).reshape(";

os << op.sym_name().str() << " = np.full((";
for (const auto &item : elem.getType().getShape())
os << item << ", ";
os << "), ";
for (const auto &item : elem.getValues<FloatAttr>()) {
os << item.getValueAsDouble();
break;
}
os << ")\n";
}

Expand Down
3 changes: 2 additions & 1 deletion openhls_translate/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class HLSVisitorBase {
// Memref-related statements.
memref::AllocOp, memref::AllocaOp, memref::LoadOp, memref::StoreOp,
memref::GlobalOp, memref::GetGlobalOp,
memref::DeallocOp, memref::CopyOp, memref::TensorStoreOp,
memref::DeallocOp, memref::CopyOp, memref::SubViewOp, memref::TensorStoreOp,
tensor::ReshapeOp, memref::ReshapeOp, memref::CollapseShapeOp,
memref::ExpandShapeOp, memref::ReinterpretCastOp,
bufferization::ToMemrefOp, bufferization::ToTensorOp,
Expand Down Expand Up @@ -132,6 +132,7 @@ class HLSVisitorBase {
HANDLE(memref::GetGlobalOp);
HANDLE(memref::DeallocOp);
HANDLE(memref::CopyOp);
HANDLE(memref::SubViewOp);
HANDLE(memref::TensorStoreOp);
HANDLE(tensor::ReshapeOp);
HANDLE(memref::ReshapeOp);
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
requires = [
"setuptools>=42",
"wheel",
"cmake==3.21",
"cmake>=3.24",
# MLIR build depends.
"ninja",
"numpy==1.23.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ numpy
networkx
astor
jinja2
cocotb==1.6.2
cocotb
matplotlib
xeda
Loading

0 comments on commit 5d7666f

Please sign in to comment.