diff --git a/include/circt/Dialect/Calyx/CalyxLoweringUtils.h b/include/circt/Dialect/Calyx/CalyxLoweringUtils.h index 8cb05f9c115f..220be105e4ee 100644 --- a/include/circt/Dialect/Calyx/CalyxLoweringUtils.h +++ b/include/circt/Dialect/Calyx/CalyxLoweringUtils.h @@ -756,6 +756,23 @@ struct EliminateUnusedCombGroups : mlir::OpRewritePattern { PatternRewriter &rewriter) const override; }; +/// Removes duplicate EnableOps in parallel operations. +struct DeduplicateParallelOp : mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(calyx::ParOp parOp, + PatternRewriter &rewriter) const override; +}; + +/// Removes duplicate EnableOps in static parallel operations. +struct DeduplicateStaticParallelOp + : mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(calyx::StaticParOp parOp, + PatternRewriter &rewriter) const override; +}; + /// This pass recursively inlines use-def chains of combinational logic (from /// non-stateful groups) into groups referenced in the control schedule. class InlineCombGroups diff --git a/lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp b/lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp index 2f51e1a42291..1763df012a2f 100644 --- a/lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp +++ b/lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp @@ -29,6 +29,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" +#include #include namespace circt { @@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface { return pipelineRegs[stage]; } + /// Returns the pipeline register for this value if its defining operation is + /// a stage, and std::nullopt otherwise. + std::optional getPipelineRegister(Value value) { + auto opStage = dyn_cast(value.getDefiningOp()); + if (opStage == nullptr) + return std::nullopt; + // The pipeline register for this input value needs to be discovered. + auto opResult = cast(value); + unsigned int opNumber = opResult.getResultNumber(); + auto &stageRegisters = getPipelineRegs(opStage); + return stageRegisters.find(opNumber)->second; + } + /// Add a stage's groups to the pipeline prologue. void addPipelinePrologue(Operation *op, SmallVector groupNames) { pipelinePrologue[op].push_back(groupNames); @@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern { /// Create assignments to the inputs of the library op. auto group = createGroupForOp(rewriter, op); rewriter.setInsertionPointToEnd(group.getBodyBlock()); - for (auto dstOp : enumerate(opInputPorts)) - rewriter.create(op.getLoc(), dstOp.value(), - op->getOperand(dstOp.index())); + for (auto dstOp : enumerate(opInputPorts)) { + Value srcOp = op->getOperand(dstOp.index()); + std::optional pipelineRegister = + getState().getPipelineRegister(srcOp); + if (pipelineRegister.has_value()) + srcOp = pipelineRegister->getOut(); + rewriter.create(op.getLoc(), dstOp.value(), srcOp); + } /// Replace the result values of the source operator with the new operator. for (auto res : enumerate(opOutputPorts)) { @@ -1055,22 +1074,17 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern { Value value = operand.get(); // Get the pipeline register for that result. - auto pipelineRegister = pipelineRegisters[i]; + calyx::RegisterOp pipelineRegister = pipelineRegisters[i]; + if (std::optional pr = + state.getPipelineRegister(value)) { + value = pr->getOut(); + } calyx::GroupOp group; // Get the evaluating group for that value. std::optional evaluatingGroup = state.findEvaluatingGroup(value); if (!evaluatingGroup.has_value()) { - if (auto opStage = - dyn_cast(value.getDefiningOp())) { - // The pipeline register for this input value needs to be discovered. - auto opResult = cast(value); - unsigned int opNumber = opResult.getResultNumber(); - auto &stageRegisters = state.getPipelineRegs(opStage); - calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second; - value = opRegister.getOut(); // Pass the `out` wire of this register. - } if (value.getDefiningOp() == nullptr) { // We add this for any unhandled cases. llvm::errs() << "unexpected: input value: " << value << ", in stage " @@ -1166,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern { } doneOp.getSrcMutable().assign(pipelineRegister.getDone()); - // Remove the old register completely. - rewriter.eraseOp(tempReg); + // Remove the old register if it has no more uses. + if (tempReg->use_empty()) + rewriter.eraseOp(tempReg); return group; } @@ -1534,10 +1549,11 @@ class LoopScheduleToCalyxPass if (runOnce) config.maxIterations = 1; - /// Can't return applyPatternsGreedily. Root isn't + /// Can't return applyPatternsAndFoldGreedily. Root isn't /// necessarily erased so it will always return failed(). Instead, /// forward the 'succeeded' value from PartialLoweringPatternBase. - (void)applyPatternsGreedily(getOperation(), std::move(pattern), config); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern), + config); return partialPatternRes; } @@ -1628,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() { addOncePattern(loweringPatterns, patternState, *loweringState); + addGreedyPattern(loweringPatterns); + addGreedyPattern(loweringPatterns); + /// This pattern performs various SSA replacements that must be done /// after control generation. addOncePattern(loweringPatterns, patternState, funcMap, @@ -1665,8 +1684,8 @@ void LoopScheduleToCalyxPass::runOnOperation() { RewritePatternSet cleanupPatterns(&getContext()); cleanupPatterns.add(&getContext()); - if (failed( - applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) { + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(cleanupPatterns)))) { signalPassFailure(); return; } diff --git a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp index 1507e9e1e82f..a8616ea2b1a0 100644 --- a/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp +++ b/lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp @@ -29,6 +29,28 @@ using namespace mlir::arith; namespace circt { namespace calyx { +template +static LogicalResult deduplicateParallelOperation(OpTy parOp, + PatternRewriter &rewriter) { + auto *body = parOp.getBodyBlock(); + if (body->getOperations().size() < 2) + return failure(); + + LogicalResult result = LogicalResult::failure(); + SetVector members; + for (auto &op : make_early_inc_range(*body)) { + auto enableOp = dyn_cast(&op); + if (enableOp == nullptr) + continue; + bool inserted = members.insert(enableOp.getGroupName()); + if (!inserted) { + rewriter.eraseOp(enableOp); + result = LogicalResult::success(); + } + } + return result; +} + void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName, Value memref, unsigned memoryID, SmallVectorImpl &inPorts, @@ -609,6 +631,22 @@ EliminateUnusedCombGroups::matchAndRewrite(calyx::CombGroupOp combGroupOp, return success(); } +//===----------------------------------------------------------------------===// +// DeduplicateParallelOperations +//===----------------------------------------------------------------------===// + +LogicalResult +DeduplicateParallelOp::matchAndRewrite(calyx::ParOp parOp, + PatternRewriter &rewriter) const { + return deduplicateParallelOperation(parOp, rewriter); +} + +LogicalResult +DeduplicateStaticParallelOp::matchAndRewrite(calyx::StaticParOp parOp, + PatternRewriter &rewriter) const { + return deduplicateParallelOperation(parOp, rewriter); +} + //===----------------------------------------------------------------------===// // InlineCombGroups //===----------------------------------------------------------------------===// diff --git a/test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir b/test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir new file mode 100644 index 000000000000..9598f75c448d --- /dev/null +++ b/test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir @@ -0,0 +1,62 @@ +// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s + +// This will introduce duplicate groups; these should be subsequently removed during canonicalization. + +// CHECK: calyx.while %std_lt_0.out with @bb0_0 { +// CHECK-NEXT: calyx.par { +// CHECK-NEXT: calyx.enable @bb0_1 +// CHECK-NEXT: } +// CHECK-NEXT: } +module { + func.func @foo() attributes {} { + %const = arith.constant 1 : index + loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () { + %latch = arith.cmpi ult, %counter, %const : index + loopschedule.register %latch : i1 + } do { + %S0 = loopschedule.pipeline.stage start = 0 { + %op = arith.addi %counter, %const : index + loopschedule.register %op : index + } : index + %S1 = loopschedule.pipeline.stage start = 1 { + loopschedule.register %S0: index + } : index + loopschedule.terminator iter_args(%S0), results() : (index) -> () + } + return + } +} + +// ----- + +// Stage pipeline registers passed directly to the next stage +// should also be updated when used in computations. + +// CHECK: calyx.group @bb0_2 { +// CHECK-NEXT: calyx.assign %std_add_1.left = %while_0_arg0_reg.out : i32 +// CHECK-NEXT: calyx.assign %std_add_1.right = %c1_i32 : i32 +// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %std_add_1.out : i32 +// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1 +// CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1 +// CHECK-NEXT: } +module { + func.func @foo() attributes {} { + %const = arith.constant 1 : index + loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () { + %latch = arith.cmpi ult, %counter, %const : index + loopschedule.register %latch : i1 + } do { + %S0 = loopschedule.pipeline.stage start = 0 { + %op = arith.addi %counter, %const : index + loopschedule.register %op : index + } : index + %S1 = loopschedule.pipeline.stage start = 1 { + %math = arith.addi %S0, %const : index + loopschedule.register %math : index + } : index + loopschedule.terminator iter_args(%S0), results() : (index) -> () + } + return + } +} +