Skip to content

Commit

Permalink
[LoopScheduleToCalyx] deduplicate groups within a ParOp. (#8055)
Browse files Browse the repository at this point in the history
  • Loading branch information
cgyurgyik authored Jan 9, 2025
1 parent 76c562b commit 4a062a2
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 19 deletions.
17 changes: 17 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,23 @@ struct EliminateUnusedCombGroups : mlir::OpRewritePattern<calyx::CombGroupOp> {
PatternRewriter &rewriter) const override;
};

/// Removes duplicate EnableOps in parallel operations.
///
/// When several EnableOps directly inside the body of a calyx::ParOp name the
/// same group, only the first one is kept and the later ones are erased.
struct DeduplicateParallelOp : mlir::OpRewritePattern<calyx::ParOp> {
// Inherit the constructors of the base rewrite pattern.
using mlir::OpRewritePattern<calyx::ParOp>::OpRewritePattern;

/// Returns success if at least one duplicate EnableOp was erased from the
/// body of `parOp`, and failure otherwise.
LogicalResult matchAndRewrite(calyx::ParOp parOp,
PatternRewriter &rewriter) const override;
};

/// Removes duplicate EnableOps in static parallel operations.
///
/// When several EnableOps directly inside the body of a calyx::StaticParOp
/// name the same group, only the first one is kept and the later ones are
/// erased.
struct DeduplicateStaticParallelOp
: mlir::OpRewritePattern<calyx::StaticParOp> {
// Inherit the constructors of the base rewrite pattern.
using mlir::OpRewritePattern<calyx::StaticParOp>::OpRewritePattern;

/// Returns success if at least one duplicate EnableOp was erased from the
/// body of `parOp`, and failure otherwise.
LogicalResult matchAndRewrite(calyx::StaticParOp parOp,
PatternRewriter &rewriter) const override;
};

/// This pass recursively inlines use-def chains of combinational logic (from
/// non-stateful groups) into groups referenced in the control schedule.
class InlineCombGroups
Expand Down
57 changes: 38 additions & 19 deletions lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <type_traits>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface<Scheduleable> {
return pipelineRegs[stage];
}

/// Returns the pipeline register holding `value` if `value` is a result of a
/// LoopSchedulePipelineStageOp, and std::nullopt otherwise. Values without a
/// defining operation (e.g. block arguments) and stage results for which no
/// register has been recorded also yield std::nullopt.
std::optional<calyx::RegisterOp> getPipelineRegister(Value value) {
  // Value::getDefiningOp<OpTy>() tolerates a missing defining operation,
  // whereas dyn_cast on the raw (possibly null) Operation pointer does not.
  auto opStage = value.getDefiningOp<LoopSchedulePipelineStageOp>();
  if (!opStage)
    return std::nullopt;
  // The result number selects which of the stage's registers holds the value.
  unsigned int opNumber = cast<OpResult>(value).getResultNumber();
  auto &stageRegisters = getPipelineRegs(opStage);
  auto registerIt = stageRegisters.find(opNumber);
  // Never dereference the end iterator when the result has no register.
  if (registerIt == stageRegisters.end())
    return std::nullopt;
  return registerIt->second;
}

/// Add a stage's groups to the pipeline prologue.
void addPipelinePrologue(Operation *op, SmallVector<StringAttr> groupNames) {
pipelinePrologue[op].push_back(groupNames);
Expand Down Expand Up @@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// Create assignments to the inputs of the library op.
auto group = createGroupForOp<TGroupOp>(rewriter, op);
rewriter.setInsertionPointToEnd(group.getBodyBlock());
for (auto dstOp : enumerate(opInputPorts))
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));
for (auto dstOp : enumerate(opInputPorts)) {
Value srcOp = op->getOperand(dstOp.index());
std::optional<calyx::RegisterOp> pipelineRegister =
getState<ComponentLoweringState>().getPipelineRegister(srcOp);
if (pipelineRegister.has_value())
srcOp = pipelineRegister->getOut();
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp);
}

/// Replace the result values of the source operator with the new operator.
for (auto res : enumerate(opOutputPorts)) {
Expand Down Expand Up @@ -1055,22 +1074,17 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
Value value = operand.get();

// Get the pipeline register for that result.
auto pipelineRegister = pipelineRegisters[i];
calyx::RegisterOp pipelineRegister = pipelineRegisters[i];
if (std::optional<calyx::RegisterOp> pr =
state.getPipelineRegister(value)) {
value = pr->getOut();
}

calyx::GroupOp group;
// Get the evaluating group for that value.
std::optional<calyx::GroupInterface> evaluatingGroup =
state.findEvaluatingGroup(value);
if (!evaluatingGroup.has_value()) {
if (auto opStage =
dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp())) {
// The pipeline register for this input value needs to be discovered.
auto opResult = cast<OpResult>(value);
unsigned int opNumber = opResult.getResultNumber();
auto &stageRegisters = state.getPipelineRegs(opStage);
calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second;
value = opRegister.getOut(); // Pass the `out` wire of this register.
}
if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
// We add this for any unhandled cases.
llvm::errs() << "unexpected: input value: " << value << ", in stage "
Expand Down Expand Up @@ -1166,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
}
doneOp.getSrcMutable().assign(pipelineRegister.getDone());

// Remove the old register completely.
rewriter.eraseOp(tempReg);
// Remove the old register if it has no more uses.
if (tempReg->use_empty())
rewriter.eraseOp(tempReg);

return group;
}
Expand Down Expand Up @@ -1534,10 +1549,11 @@ class LoopScheduleToCalyxPass
if (runOnce)
config.maxIterations = 1;

/// Can't return applyPatternsGreedily. Root isn't
/// Can't return applyPatternsAndFoldGreedily. Root isn't
/// necessarily erased so it will always return failed(). Instead,
/// forward the 'succeeded' value from PartialLoweringPatternBase.
(void)applyPatternsGreedily(getOperation(), std::move(pattern), config);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
config);
return partialPatternRes;
}

Expand Down Expand Up @@ -1628,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() {
addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
*loweringState);

addGreedyPattern<calyx::DeduplicateParallelOp>(loweringPatterns);
addGreedyPattern<calyx::DeduplicateStaticParallelOp>(loweringPatterns);

/// This pattern performs various SSA replacements that must be done
/// after control generation.
addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
Expand Down Expand Up @@ -1665,8 +1684,8 @@ void LoopScheduleToCalyxPass::runOnOperation() {
RewritePatternSet cleanupPatterns(&getContext());
cleanupPatterns.add<calyx::MultipleGroupDonePattern,
calyx::NonTerminatingGroupDonePattern>(&getContext());
if (failed(
applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(cleanupPatterns)))) {
signalPassFailure();
return;
}
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,28 @@ using namespace mlir::arith;
namespace circt {
namespace calyx {

/// Erases every EnableOp directly inside the body of `parOp` whose group name
/// was already enabled by an earlier EnableOp in that body, keeping the first
/// occurrence. Returns success if at least one op was erased, and failure
/// otherwise (including when the body holds fewer than two operations).
template <typename OpTy>
static LogicalResult deduplicateParallelOperation(OpTy parOp,
                                                  PatternRewriter &rewriter) {
  auto *parBody = parOp.getBodyBlock();
  if (parBody->getOperations().size() < 2)
    return failure();

  bool erasedAny = false;
  SetVector<StringRef> seenGroups;
  // Early-increment iteration keeps the traversal valid while ops are erased.
  for (auto &nested : make_early_inc_range(*parBody)) {
    if (auto enable = dyn_cast<EnableOp>(&nested)) {
      // insert() returns false when the group name is already present.
      if (seenGroups.insert(enable.getGroupName()))
        continue;
      rewriter.eraseOp(enable);
      erasedAny = true;
    }
  }
  return success(erasedAny);
}

void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName,
Value memref, unsigned memoryID,
SmallVectorImpl<calyx::PortInfo> &inPorts,
Expand Down Expand Up @@ -609,6 +631,22 @@ EliminateUnusedCombGroups::matchAndRewrite(calyx::CombGroupOp combGroupOp,
return success();
}

//===----------------------------------------------------------------------===//
// DeduplicateParallelOperations
//===----------------------------------------------------------------------===//

/// Deduplicates the EnableOps in the body of `parOp` by delegating to the
/// shared helper; the result is whatever that helper reports.
LogicalResult
DeduplicateParallelOp::matchAndRewrite(calyx::ParOp parOp,
                                       PatternRewriter &rewriter) const {
  // The helper's OpTy is deduced as calyx::ParOp from the argument.
  return deduplicateParallelOperation(parOp, rewriter);
}

/// Deduplicates the EnableOps in the body of the static `parOp` by delegating
/// to the shared helper; the result is whatever that helper reports.
LogicalResult
DeduplicateStaticParallelOp::matchAndRewrite(calyx::StaticParOp parOp,
                                             PatternRewriter &rewriter) const {
  // The helper's OpTy is deduced as calyx::StaticParOp from the argument.
  return deduplicateParallelOperation(parOp, rewriter);
}

//===----------------------------------------------------------------------===//
// InlineCombGroups
//===----------------------------------------------------------------------===//
Expand Down
62 changes: 62 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s

// This will introduce duplicate groups; these should be subsequently removed during canonicalization.

// CHECK: calyx.while %std_lt_0.out with @bb0_0 {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @bb0_1
// CHECK-NEXT: }
// CHECK-NEXT: }
// Two-stage pipeline in which the second stage only re-registers the result
// of the first stage.
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 performs no computation of its own; it registers %S0 as-is.
%S1 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %S0: index
} : index
// %S0 is passed back as the next value of the %counter iter_arg.
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

// -----

// Stage pipeline registers passed directly to the next stage
// should also be updated when used in computations.

// CHECK: calyx.group @bb0_2 {
// CHECK-NEXT: calyx.assign %std_add_1.left = %while_0_arg0_reg.out : i32
// CHECK-NEXT: calyx.assign %std_add_1.right = %c1_i32 : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %std_add_1.out : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
// CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
// CHECK-NEXT: }
// Two-stage pipeline in which the second stage uses the result of the first
// stage as an operand of a computation.
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 consumes %S0 as the left operand of an addition and registers
// the sum rather than %S0 itself.
%S1 = loopschedule.pipeline.stage start = 1 {
%math = arith.addi %S0, %const : index
loopschedule.register %math : index
} : index
// %S0 is passed back as the next value of the %counter iter_arg.
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

0 comments on commit 4a062a2

Please sign in to comment.