Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopScheduleToCalyx] deduplicate groups within a ParOp. #8055

Merged
merged 13 commits into from
Jan 9, 2025
36 changes: 32 additions & 4 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/JSON.h"

#include <optional>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -404,12 +405,13 @@ class ComponentLoweringStateInterface {
/// Put the name of the callee and the instance of the call into map.
void addInstance(StringRef calleeName, InstanceOp instanceOp);

/// Return the group which evaluates the value v. Optionally, caller may
/// specify the expected type of the group.
/// Returns the evaluating group or None if not found.
template <typename TGroupOp = calyx::GroupInterface>
TGroupOp getEvaluatingGroup(Value v) {
std::optional<TGroupOp> findEvaluatingGroup(Value v) {
auto it = valueGroupAssigns.find(v);
assert(it != valueGroupAssigns.end() && "No group evaluating value!");
if (it == valueGroupAssigns.end())
return std::nullopt;

if constexpr (std::is_same_v<TGroupOp, calyx::GroupInterface>)
return it->second;
else {
Expand All @@ -419,6 +421,15 @@ class ComponentLoweringStateInterface {
}
}

/// Looks up the group that evaluates `v` and returns it. `TGroupOp` lets the
/// caller state the group type it expects; it defaults to the generic
/// calyx::GroupInterface. A missing group is treated as a programming error
/// and trips the assertion below.
template <typename TGroupOp = calyx::GroupInterface>
TGroupOp getEvaluatingGroup(Value v) {
  auto maybeGroup = findEvaluatingGroup<TGroupOp>(v);
  assert(maybeGroup.has_value() && "No group evaluating value!");
  return *maybeGroup;
}

template <typename T, typename = void>
struct IsFloatingPoint : std::false_type {};

Expand Down Expand Up @@ -745,6 +756,23 @@ struct EliminateUnusedCombGroups : mlir::OpRewritePattern<calyx::CombGroupOp> {
PatternRewriter &rewriter) const override;
};

/// Rewrite pattern over calyx::ParOp that removes duplicate EnableOps from
/// the parallel operation it is applied to.
struct DeduplicateParallelOp : mlir::OpRewritePattern<calyx::ParOp> {
  // Inherit the base pattern's constructors.
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(calyx::ParOp parOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrite pattern over calyx::StaticParOp that removes duplicate EnableOps
/// from the static parallel operation it is applied to.
struct DeduplicateStaticParallelOp
    : mlir::OpRewritePattern<calyx::StaticParOp> {
  // Inherit the base pattern's constructors.
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(calyx::StaticParOp parOp,
                                PatternRewriter &rewriter) const override;
};

/// This pass recursively inlines use-def chains of combinational logic (from
/// non-stateful groups) into groups referenced in the control schedule.
class InlineCombGroups
Expand Down
100 changes: 71 additions & 29 deletions lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <type_traits>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface<Scheduleable> {
return pipelineRegs[stage];
}

/// Returns the pipeline register for this value if its defining operation is
/// a stage, and std::nullopt otherwise. Values without a defining operation
/// (i.e. block arguments) also yield std::nullopt.
std::optional<calyx::RegisterOp> getPipelineRegister(Value value) {
  // Use the typed `getDefiningOp<OpTy>()` accessor: it returns a null op both
  // when the value has no defining operation and when the defining operation
  // is of a different type. Passing the possibly-null result of the untyped
  // `getDefiningOp()` to `dyn_cast` would assert on block arguments.
  auto opStage = value.getDefiningOp<LoopSchedulePipelineStageOp>();
  if (!opStage)
    return std::nullopt;
  // The pipeline register for this input value needs to be discovered. The
  // value is known to be an operation result here, so the cast cannot fail.
  unsigned opNumber = cast<OpResult>(value).getResultNumber();
  auto &stageRegisters = getPipelineRegs(opStage);
  auto registerIt = stageRegisters.find(opNumber);
  assert(registerIt != stageRegisters.end() &&
         "No pipeline register recorded for this stage result!");
  return registerIt->second;
}

/// Add a stage's groups to the pipeline prologue.
void addPipelinePrologue(Operation *op, SmallVector<StringAttr> groupNames) {
pipelinePrologue[op].push_back(groupNames);
Expand Down Expand Up @@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// Create assignments to the inputs of the library op.
auto group = createGroupForOp<TGroupOp>(rewriter, op);
rewriter.setInsertionPointToEnd(group.getBodyBlock());
for (auto dstOp : enumerate(opInputPorts))
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));
for (auto dstOp : enumerate(opInputPorts)) {
Value srcOp = op->getOperand(dstOp.index());
std::optional<calyx::RegisterOp> pipelineRegister =
getState<ComponentLoweringState>().getPipelineRegister(srcOp);
if (pipelineRegister.has_value())
srcOp = pipelineRegister->getOut();
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp);
}

/// Replace the result values of the source operator with the new operator.
for (auto res : enumerate(opOutputPorts)) {
Expand Down Expand Up @@ -1019,11 +1038,11 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {

// Collect group names for the prologue or epilogue.
SmallVector<StringAttr> prologueGroups, epilogueGroups;
auto &state = getState<ComponentLoweringState>();

auto updatePrologueAndEpilogue = [&](calyx::GroupOp group) {
// Mark the group for scheduling in the pipeline's block.
getState<ComponentLoweringState>().addBlockScheduleable(stage->getBlock(),
group);
state.addBlockScheduleable(stage->getBlock(), group);

// Add the group to the prologue or epilogue for this stage as
// necessary. The goal is to fill the pipeline so it will be in steady
Expand All @@ -1046,8 +1065,7 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
// Covers the case where there are no values that need to be passed
// through to the next stage, e.g., some intermediary store.
for (auto &op : stage.getBodyBlock())
if (auto group = getState<ComponentLoweringState>()
.getNonPipelinedGroupFrom<calyx::GroupOp>(&op))
if (auto group = state.getNonPipelinedGroupFrom<calyx::GroupOp>(&op))
updatePrologueAndEpilogue(*group);
}

Expand All @@ -1056,28 +1074,48 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
Value value = operand.get();

// Get the pipeline register for that result.
auto pipelineRegister = pipelineRegisters[i];

// Get the evaluating group for that value.
calyx::GroupInterface evaluatingGroup =
getState<ComponentLoweringState>().getEvaluatingGroup(value);
calyx::RegisterOp pipelineRegister = pipelineRegisters[i];
if (std::optional<calyx::RegisterOp> pr =
state.getPipelineRegister(value)) {
value = pr->getOut();
}

// Remember the final group for this stage result.
calyx::GroupOp group;

// Stitch the register in, depending on whether the group was
// combinational or sequential.
if (auto combGroup =
dyn_cast<calyx::CombGroupOp>(evaluatingGroup.getOperation()))
group =
convertCombToSeqGroup(combGroup, pipelineRegister, value, rewriter);
else
group =
replaceGroupRegister(evaluatingGroup, pipelineRegister, rewriter);

// Replace the stage result uses with the register out.
stage.getResult(i).replaceAllUsesWith(pipelineRegister.getOut());

// Get the evaluating group for that value.
std::optional<calyx::GroupInterface> evaluatingGroup =
state.findEvaluatingGroup(value);
if (!evaluatingGroup.has_value()) {
if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
// We add this for any unhandled cases.
llvm::errs() << "unexpected: input value: " << value << ", in stage "
<< stage.getStageNumber() << " register " << i
<< " is not a register and was not previously "
"evaluated in a Calyx group. Please open an issue.\n";
return LogicalResult::failure();
}
// This is a register's `out` value being written to this pipeline
// register. We create a new group to build this assignment.
std::string groupName = state.getUniqueName(
loweringState().blockName(pipelineRegister->getBlock()));
group = calyx::createGroup<calyx::GroupOp>(
rewriter, state.getComponentOp(), pipelineRegister->getLoc(),
groupName);
calyx::buildAssignmentsForRegisterWrite(
rewriter, group, state.getComponentOp(), pipelineRegister, value);
} else {
// This was previously evaluated. Stitch the register in, depending on
// whether the group was combinational or sequential.
auto combGroup =
dyn_cast<calyx::CombGroupOp>(evaluatingGroup->getOperation());
group = combGroup == nullptr
? replaceGroupRegister(*evaluatingGroup, pipelineRegister,
rewriter)
: convertCombToSeqGroup(combGroup, pipelineRegister, value,
rewriter);

// Replace the stage result uses with the register out.
stage.getResult(i).replaceAllUsesWith(pipelineRegister.getOut());
}
updatePrologueAndEpilogue(group);
}

Expand Down Expand Up @@ -1142,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
}
doneOp.getSrcMutable().assign(pipelineRegister.getDone());

// Remove the old register completely.
rewriter.eraseOp(tempReg);
// Remove the old register if it has no more uses.
if (tempReg->use_empty())
rewriter.eraseOp(tempReg);

return group;
}
Expand Down Expand Up @@ -1605,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() {
addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
*loweringState);

addGreedyPattern<calyx::DeduplicateParallelOp>(loweringPatterns);
addGreedyPattern<calyx::DeduplicateStaticParallelOp>(loweringPatterns);

/// This pattern performs various SSA replacements that must be done
/// after control generation.
addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,28 @@ using namespace mlir::arith;
namespace circt {
namespace calyx {

/// Erases every EnableOp in the body of `parOp` whose group name has already
/// been enabled by an earlier EnableOp in the same body. Operations other
/// than EnableOp are left untouched. Returns success if at least one
/// EnableOp was erased and failure otherwise.
template <typename OpTy>
static LogicalResult deduplicateParallelOperation(OpTy parOp,
                                                  PatternRewriter &rewriter) {
  auto *block = parOp.getBodyBlock();
  // A body with fewer than two operations cannot hold a duplicate.
  if (block->getOperations().size() < 2)
    return failure();

  SetVector<StringRef> seenGroups;
  bool erasedAny = false;
  // Early-increment iteration keeps the walk valid while ops are erased.
  for (auto &nested : make_early_inc_range(*block)) {
    if (auto enable = dyn_cast<EnableOp>(&nested)) {
      // First occurrence of this group name: remember it and keep the op.
      if (seenGroups.insert(enable.getGroupName()))
        continue;
      rewriter.eraseOp(enable);
      erasedAny = true;
    }
  }
  return erasedAny ? LogicalResult::success() : LogicalResult::failure();
}

void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName,
Value memref, unsigned memoryID,
SmallVectorImpl<calyx::PortInfo> &inPorts,
Expand Down Expand Up @@ -609,6 +631,22 @@ EliminateUnusedCombGroups::matchAndRewrite(calyx::CombGroupOp combGroupOp,
return success();
}

//===----------------------------------------------------------------------===//
// DeduplicateParallelOperations
//===----------------------------------------------------------------------===//

/// Applies the shared deduplication helper to a calyx::ParOp and forwards
/// its result.
LogicalResult
DeduplicateParallelOp::matchAndRewrite(calyx::ParOp parOp,
                                       PatternRewriter &rewriter) const {
  // The operation type is deduced from `parOp`.
  LogicalResult status = deduplicateParallelOperation(parOp, rewriter);
  return status;
}

/// Applies the shared deduplication helper to a calyx::StaticParOp and
/// forwards its result.
LogicalResult
DeduplicateStaticParallelOp::matchAndRewrite(calyx::StaticParOp parOp,
                                             PatternRewriter &rewriter) const {
  // The operation type is deduced from `parOp`.
  LogicalResult status = deduplicateParallelOperation(parOp, rewriter);
  return status;
}

//===----------------------------------------------------------------------===//
// InlineCombGroups
//===----------------------------------------------------------------------===//
Expand Down
44 changes: 44 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/convert_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,47 @@ module {
}
}

// -----

// Pipeline register to pipeline register writes are valid.

//CHECK: calyx.group @bb0_3 {
//CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %stage_0_register_0_reg.out : i32
//CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
//CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
//CHECK-NEXT: }
//CHECK-NEXT: calyx.group @bb0_4 {
//CHECK-NEXT: calyx.assign %stage_2_register_0_reg.in = %stage_1_register_0_reg.out : i32
//CHECK-NEXT: calyx.assign %stage_2_register_0_reg.write_en = %true : i1
//CHECK-NEXT: calyx.group_done %stage_2_register_0_reg.done : i1
//CHECK-NEXT: }
module {
func.func @foo(%arg0: i32, %arg1: memref<30x30xi32>) attributes {llvm.linkage = #llvm.linkage<external>} {
%alloca = memref.alloca() : memref<30x30xi32>
cf.br ^bb4
^bb4:
%0 = arith.constant 0 : index
%c1_5 = arith.constant 1 : index
// Three-stage pipeline with a single index iteration argument.
loopschedule.pipeline II = 2 trip_count = 20 iter_args(%arg5 = %0) : (index) -> () {
%6 = arith.cmpi ult, %arg5, %0 : index
loopschedule.register %6 : i1
} do {
// Stage 0 registers a loaded value and the incremented iteration argument.
%6:2 = loopschedule.pipeline.stage start = 0 {
%10 = memref.load %arg1[%0, %0] : memref<30x30xi32>
%11 = arith.addi %arg5, %c1_5 : index
loopschedule.register %10, %11 : i32, index
} : i32, index
// Stage 1 performs no computation: it registers stage 0's first result as-is.
%7 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %6#0 : i32
} : i32
// Stage 2 likewise registers stage 1's result without modification.
%8 = loopschedule.pipeline.stage start = 2 {
loopschedule.register %7 : i32
} : i32
// Stage 0's second result is fed back as the next iteration argument.
loopschedule.terminator iter_args(%6#1), results() : (index) -> ()
}
cf.br ^bb6
^bb6:
return
}
}

62 changes: 62 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s

// This will introduce duplicate groups; these should be subsequently removed during canonicalization.

// CHECK: calyx.while %std_lt_0.out with @bb0_0 {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @bb0_1
// CHECK-NEXT: }
// CHECK-NEXT: }
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
// Stage 0 computes and registers the incremented counter.
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 only forwards stage 0's result into its own register.
%S1 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %S0: index
} : index
// Stage 0's result is also the next iteration argument.
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

// -----

// Stage pipeline registers passed directly to the next stage
// should also be updated when used in computations.

// CHECK: calyx.group @bb0_2 {
// CHECK-NEXT: calyx.assign %std_add_1.left = %while_0_arg0_reg.out : i32
// CHECK-NEXT: calyx.assign %std_add_1.right = %c1_i32 : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %std_add_1.out : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
// CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
// CHECK-NEXT: }
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
// Stage 0 computes and registers the incremented counter.
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 uses stage 0's result as an operand of a new addition.
%S1 = loopschedule.pipeline.stage start = 1 {
%math = arith.addi %S0, %const : index
loopschedule.register %math : index
} : index
// Stage 0's result is also the next iteration argument.
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

Loading