Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopScheduleToCalyx] deduplicate groups within a ParOp. #8055

Merged
merged 13 commits into from
Jan 9, 2025
19 changes: 15 additions & 4 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/JSON.h"

#include <optional>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -404,12 +405,13 @@ class ComponentLoweringStateInterface {
/// Put the name of the callee and the instance of the call into map.
void addInstance(StringRef calleeName, InstanceOp instanceOp);

/// Return the group which evaluates the value v. Optionally, caller may
/// specify the expected type of the group.
/// Returns the evaluating group or None if not found.
template <typename TGroupOp = calyx::GroupInterface>
TGroupOp getEvaluatingGroup(Value v) {
std::optional<TGroupOp> findEvaluatingGroup(Value v) {
auto it = valueGroupAssigns.find(v);
assert(it != valueGroupAssigns.end() && "No group evaluating value!");
if (it == valueGroupAssigns.end())
return std::nullopt;

if constexpr (std::is_same_v<TGroupOp, calyx::GroupInterface>)
return it->second;
else {
Expand All @@ -419,6 +421,15 @@ class ComponentLoweringStateInterface {
}
}

/// Return the group which evaluates the value v. Optionally, caller may
/// specify the expected type of the group.
template <typename TGroupOp = calyx::GroupInterface>
TGroupOp getEvaluatingGroup(Value v) {
std::optional<TGroupOp> group = findEvaluatingGroup<TGroupOp>(v);
assert(group.has_value() && "No group evaluating value!");
return *group;
}

template <typename T, typename = void>
struct IsFloatingPoint : std::false_type {};

Expand Down
75 changes: 50 additions & 25 deletions lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,11 +1019,11 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {

// Collect group names for the prologue or epilogue.
SmallVector<StringAttr> prologueGroups, epilogueGroups;
auto &state = getState<ComponentLoweringState>();

auto updatePrologueAndEpilogue = [&](calyx::GroupOp group) {
// Mark the group for scheduling in the pipeline's block.
getState<ComponentLoweringState>().addBlockScheduleable(stage->getBlock(),
group);
state.addBlockScheduleable(stage->getBlock(), group);

// Add the group to the prologue or epilogue for this stage as
// necessary. The goal is to fill the pipeline so it will be in steady
Expand All @@ -1046,8 +1046,7 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
// Covers the case where there are no values that need to be passed
// through to the next stage, e.g., some intermediary store.
for (auto &op : stage.getBodyBlock())
if (auto group = getState<ComponentLoweringState>()
.getNonPipelinedGroupFrom<calyx::GroupOp>(&op))
if (auto group = state.getNonPipelinedGroupFrom<calyx::GroupOp>(&op))
updatePrologueAndEpilogue(*group);
}

Expand All @@ -1058,26 +1057,51 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
// Get the pipeline register for that result.
auto pipelineRegister = pipelineRegisters[i];

// Get the evaluating group for that value.
calyx::GroupInterface evaluatingGroup =
getState<ComponentLoweringState>().getEvaluatingGroup(value);

// Remember the final group for this stage result.
calyx::GroupOp group;

// Stitch the register in, depending on whether the group was
// combinational or sequential.
if (auto combGroup =
dyn_cast<calyx::CombGroupOp>(evaluatingGroup.getOperation()))
group =
convertCombToSeqGroup(combGroup, pipelineRegister, value, rewriter);
else
group =
replaceGroupRegister(evaluatingGroup, pipelineRegister, rewriter);

// Replace the stage result uses with the register out.
stage.getResult(i).replaceAllUsesWith(pipelineRegister.getOut());

// Get the evaluating group for that value.
std::optional<calyx::GroupInterface> evaluatingGroup =
state.findEvaluatingGroup(value);
if (!evaluatingGroup.has_value()) {
if (auto opStage =
dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp())) {
// The pipeline register for this input value needs to be discovered.
auto opResult = cast<OpResult>(value);
unsigned int opNumber = opResult.getResultNumber();
auto &stageRegisters = state.getPipelineRegs(opStage);
calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second;
value = opRegister.getOut(); // Pass the `out` wire of this register.
}
if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
// We add this for any unhandled cases.
llvm::errs() << "unexpected: input value: " << value << ", in stage "
<< stage.getStageNumber() << " register " << i
<< " is not a register and was not previously "
"evaluated in a Calyx group. Please open an issue.\n";
return LogicalResult::failure();
}
// This is a register's `out` value being written to this pipeline
// register. We create a new group to build this assignment.
std::string groupName = state.getUniqueName(
loweringState().blockName(pipelineRegister->getBlock()));
group = calyx::createGroup<calyx::GroupOp>(
rewriter, state.getComponentOp(), pipelineRegister->getLoc(),
groupName);
calyx::buildAssignmentsForRegisterWrite(
rewriter, group, state.getComponentOp(), pipelineRegister, value);
} else {
// This was previously evaluated. Stitch the register in, depending on
// whether the group was combinational or sequential.
auto combGroup =
dyn_cast<calyx::CombGroupOp>(evaluatingGroup->getOperation());
group = combGroup == nullptr
? replaceGroupRegister(*evaluatingGroup, pipelineRegister,
rewriter)
: convertCombToSeqGroup(combGroup, pipelineRegister, value,
rewriter);

// Replace the stage result uses with the register out.
stage.getResult(i).replaceAllUsesWith(pipelineRegister.getOut());
}
updatePrologueAndEpilogue(group);
}

Expand Down Expand Up @@ -1142,8 +1166,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
}
doneOp.getSrcMutable().assign(pipelineRegister.getDone());

// Remove the old register completely.
rewriter.eraseOp(tempReg);
// Remove the old register if it has no more uses.
if (tempReg->use_empty())
rewriter.eraseOp(tempReg);

return group;
}
Expand Down
55 changes: 28 additions & 27 deletions lib/Dialect/Calyx/CalyxOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,31 @@ static LogicalResult collapseControl(OpTy controlOp,
return failure();
}

// Remove duplicates from a parallel operation. This may inadvertently occur
// during lowering.
template <typename OpTy>
static LogicalResult deduplicate(OpTy parOp, PatternRewriter &rewriter) {
static_assert(IsAny<OpTy, ParOp, StaticParOp>(),
"requires a parallel operation");
auto *body = parOp.getBodyBlock();
if (body->getOperations().size() < 2)
return failure();

LogicalResult result = LogicalResult::failure();
SetVector<StringRef> members;
for (auto &op : make_early_inc_range(*body)) {
auto enableOp = dyn_cast<EnableOp>(&op);
if (enableOp == nullptr)
continue;
bool inserted = members.insert(enableOp.getGroupName());
if (!inserted) {
rewriter.eraseOp(enableOp);
result = LogicalResult::success();
}
}
return result;
}

template <typename OpTy>
static LogicalResult emptyControl(OpTy controlOp, PatternRewriter &rewriter) {
if (controlOp.getBodyBlock()->empty()) {
Expand Down Expand Up @@ -942,24 +967,11 @@ void StaticSeqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// ParOp
//===----------------------------------------------------------------------===//

LogicalResult ParOp::verify() {
llvm::SmallSet<StringRef, 8> groupNames;

// Add loose requirement that the body of a ParOp may not enable the same
// Group more than once, e.g. calyx.par { calyx.enable @G calyx.enable @G }
for (EnableOp op : getBodyBlock()->getOps<EnableOp>()) {
StringRef groupName = op.getGroupName();
if (groupNames.count(groupName))
return emitOpError() << "cannot enable the same group: \"" << groupName
<< "\" more than once.";
groupNames.insert(groupName);
}

return success();
}
LogicalResult ParOp::verify() { return success(); }

void ParOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(deduplicate<ParOp>);
patterns.add(collapseControl<ParOp>);
patterns.add(emptyControl<ParOp>);
patterns.insert<CollapseUnaryControl<ParOp>>(context);
Expand All @@ -970,18 +982,6 @@ void ParOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//

LogicalResult StaticParOp::verify() {
llvm::SmallSet<StringRef, 8> groupNames;

// Add loose requirement that the body of a ParOp may not enable the same
// Group more than once, e.g. calyx.par { calyx.enable @G calyx.enable @G }
for (EnableOp op : getBodyBlock()->getOps<EnableOp>()) {
StringRef groupName = op.getGroupName();
if (groupNames.count(groupName))
return emitOpError() << "cannot enable the same group: \"" << groupName
<< "\" more than once.";
groupNames.insert(groupName);
}

// static par must only have static control in it
auto &ops = (*this).getBodyBlock()->getOperations();
for (Operation &op : ops) {
Expand All @@ -995,6 +995,7 @@ LogicalResult StaticParOp::verify() {

void StaticParOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(deduplicate<StaticParOp>);
patterns.add(collapseControl<StaticParOp>);
patterns.add(emptyControl<StaticParOp>);
patterns.insert<CollapseUnaryControl<StaticParOp>>(context);
Expand Down
44 changes: 44 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/convert_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,47 @@ module {
}
}

// -----

// Pipeline register to pipeline register writes are valid.

//CHECK: calyx.group @bb0_3 {
//CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %stage_0_register_0_reg.out : i32
//CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
//CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
//CHECK-NEXT: }
//CHECK-NEXT: calyx.group @bb0_4 {
//CHECK-NEXT: calyx.assign %stage_2_register_0_reg.in = %stage_1_register_0_reg.out : i32
//CHECK-NEXT: calyx.assign %stage_2_register_0_reg.write_en = %true : i1
//CHECK-NEXT: calyx.group_done %stage_2_register_0_reg.done : i1
//CHECK-NEXT: }
module {
func.func @foo(%arg0: i32, %arg1: memref<30x30xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
%alloca = memref.alloca() : memref<30x30xi32>
cf.br ^bb4
^bb4:
%0 = arith.constant 0 : index
%c1_5 = arith.constant 1 : index
loopschedule.pipeline II = 2 trip_count = 20 iter_args(%arg5 = %0) : (index) -> () {
%6 = arith.cmpi ult, %arg5, %0 : index
loopschedule.register %6 : i1
} do {
%6:2 = loopschedule.pipeline.stage start = 0 {
%10 = memref.load %arg1[%0, %0] : memref<30x30xi32>
%11 = arith.addi %arg5, %c1_5 : index
loopschedule.register %10, %11 : i32, index
} : i32, index
%7 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %6#0 : i32
} : i32
%8 = loopschedule.pipeline.stage start = 2 {
loopschedule.register %7 : i32
} : i32
loopschedule.terminator iter_args(%6#1), results() : (index) -> ()
}
cf.br ^bb6
^bb6:
return
}
}

28 changes: 28 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/remove_duplicates.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s

// This will introduce duplicate groups; these should be subsequently removed during canonicalization.

// CHECK: calyx.while %std_lt_0.out with @bb0_0 {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @bb0_1
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
%S1 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %S0: index
} : index
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}
23 changes: 0 additions & 23 deletions test/Dialect/Calyx/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,29 +478,6 @@ module attributes {calyx.entrypoint = "main"} {

// -----

module attributes {calyx.entrypoint = "main"} {
calyx.component @main(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%done: i1 {done}) {
%r.in, %r.write_en, %r.clk, %r.reset, %r.out, %r.done = calyx.register @r : i1, i1, i1, i1, i1, i1
%c1_1 = hw.constant 1 : i1
calyx.wires {
calyx.group @A {
calyx.assign %r.in = %c1_1 : i1
calyx.assign %r.write_en = %c1_1 : i1
calyx.group_done %r.done : i1
}
}
calyx.control {
// expected-error @+1 {{'calyx.par' op cannot enable the same group: "A" more than once.}}
calyx.par {
calyx.enable @A
calyx.enable @A
}
}
}
}

// -----

module attributes {calyx.entrypoint = "main"} {
calyx.component @A(%go: i1 {go}, %clk: i1 {clk}, %reset: i1 {reset}) -> (%out: i1, %done: i1 {done}) {
%c1_1 = hw.constant 1 : i1
Expand Down
Loading