Skip to content

Commit

Permalink
[XLS] Ignore the costs of any node that only affects asserts/covers/traces
Browse files Browse the repository at this point in the history

Since asserts, covers, and traces will never be synthesized, we can essentially ignore the delay & register cost of anything that exclusively feeds into these operations. This lets us produce better schedules for these circuits regardless of any extra asserts that we or the user may have added, under the assumption that the synthesis toolchain will trim these nodes and the registers that have no other uses anyway.

It's worth noting that these operations can still prevent certain optimizations, since we can't remove their inputs entirely. There may be more work to be done to allow these optimizations to kick in anyway, potentially by preserving "shadow" values as needed.

PiperOrigin-RevId: 712540220
  • Loading branch information
ericastor authored and copybara-github committed Jan 6, 2025
1 parent 4e147e6 commit 67bf106
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 38 deletions.
9 changes: 5 additions & 4 deletions xls/codegen/side_effect_condition_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,14 +478,14 @@ top proc f(x: bits[32], init={0}) {

TEST_P(SideEffectConditionPassTest, AssertionInLastStageOfFunction) {
constexpr std::string_view ir_text = R"(package test
fn f(tkn: token, x: bits[32], y: bits[32]) -> (token, bits[32]) {
fn f(tkn: token, x: bits[32], y: bits[32]) -> (token, bits[32], bits[1]) {
xy: bits[32] = umul(x, y)
literal1: bits[32] = literal(value=1)
xy_plus_1: bits[32] = add(xy, literal1)
literal4: bits[32] = literal(value=4)
xy_plus_1_gt_4: bits[1] = ugt(xy_plus_1, literal4)
assertion: token = assert(tkn, xy_plus_1_gt_4, label="foo", message="bar")
ret out: (token, bits[32]) = tuple(assertion, xy_plus_1)
ret out: (token, bits[32], bits[1]) = tuple(assertion, xy_plus_1, xy_plus_1_gt_4)
}
)";
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> package,
Expand Down Expand Up @@ -569,15 +569,16 @@ proc g(x: bits[32], init={4}) {
// Now test proc 'g'.
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> package,
Parser::ParsePackage(ir_text));
XLS_ASSERT_OK(package->SetTop(package->GetFunctionBaseByName("g").value()));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("g"));
XLS_ASSERT_OK(package->SetTop(proc));
EXPECT_THAT(Run(package.get()), IsOkAndHolds(true));
XLS_ASSERT_OK_AND_ASSIGN(Block * block, package->GetBlock("g"));

XLS_ASSERT_OK_AND_ASSIGN(Node * assertion, block->GetNode("assertion"));
ASSERT_NE(assertion, nullptr);
Node* condition = assertion->As<xls::Assert>()->condition();
EXPECT_THAT(condition, m::Or(m::Not(m::Name(HasSubstr("_stage_done"))),
m::Name("xy_plus_1_gt_4")));
m::Name(HasSubstr("xy_plus_1_gt_4"))));

constexpr int64_t kNumCycles = 10;
std::vector<absl::flat_hash_map<std::string, Value>> inputs(
Expand Down
2 changes: 2 additions & 0 deletions xls/dev_tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,7 @@ cc_binary(
"//xls/passes:pass_base",
"//xls/passes:query_engine",
"//xls/scheduling:pipeline_schedule",
"//xls/scheduling:schedule_util",
"//xls/scheduling:scheduling_options",
"//xls/tools:codegen",
"//xls/tools:codegen_flags",
Expand All @@ -672,6 +673,7 @@ cc_binary(
"//xls/tools:scheduling_options_flags_cc_proto",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down
13 changes: 12 additions & 1 deletion xls/dev_tools/benchmark_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -83,6 +84,7 @@
#include "xls/passes/pass_base.h"
#include "xls/passes/query_engine.h"
#include "xls/scheduling/pipeline_schedule.h"
#include "xls/scheduling/schedule_util.h"
#include "xls/scheduling/scheduling_options.h"
#include "xls/tools/codegen.h"
#include "xls/tools/codegen_flags.h"
Expand Down Expand Up @@ -364,10 +366,19 @@ absl::Status PrintScheduleInfo(FunctionBase* f,
std::vector<int64_t> flops_per_stage(schedule.length());
std::vector<int64_t> duplicates_per_stage(schedule.length());
std::vector<int64_t> constants_per_stage(schedule.length());
absl::flat_hash_set<Node*> dead_after_synthesis =
GetDeadAfterSynthesisNodes(f);
for (int64_t i = 0; i < schedule.length(); ++i) {
absl::flat_hash_map<BddNodeIndex, std::pair<Node*, int64_t>> bdd_nodes;
for (Node* node : schedule.GetLiveOutOfCycle(i)) {
flops_per_stage[i] += node->GetType()->GetFlatBitCount();
if (!dead_after_synthesis.contains(node) &&
(f->HasImplicitUse(node) ||
absl::c_any_of(node->users(), [&](Node* user) {
return schedule.cycle(user) > i &&
dead_after_synthesis.contains(user);
}))) {
flops_per_stage[i] += node->GetType()->GetFlatBitCount();
}
if (node->GetType()->IsBits()) {
for (int64_t bit_index = 0; bit_index < node->BitCountOrDie();
++bit_index) {
Expand Down
1 change: 1 addition & 0 deletions xls/fdo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ cc_library(
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:op",
"//xls/scheduling:schedule_util",
"//xls/scheduling:scheduling_options",
"//xls/scheduling:sdc_scheduler",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
5 changes: 4 additions & 1 deletion xls/fdo/iterative_sdc_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
#include "xls/ir/proc.h"
#include "xls/scheduling/schedule_util.h"
#include "xls/scheduling/scheduling_options.h"
#include "ortools/math_opt/cpp/math_opt.h"

Expand Down Expand Up @@ -400,8 +401,10 @@ absl::StatusOr<ScheduleCycleMap> ScheduleByIterativeSDC(
ScheduleCycleMap cycle_map;
absl::flat_hash_set<NodeCut> evaluated_cuts;
std::mt19937_64 bit_gen;
absl::flat_hash_set<Node *> dead_after_synthesis =
GetDeadAfterSynthesisNodes(f);
for (int64_t i = 0; i < options.iteration_number; ++i) {
IterativeSDCSchedulingModel model(f, delay_manager);
IterativeSDCSchedulingModel model(f, dead_after_synthesis, delay_manager);

for (const SchedulingConstraint &constraint : constraints) {
XLS_RETURN_IF_ERROR(model.AddSchedulingConstraint(constraint));
Expand Down
6 changes: 5 additions & 1 deletion xls/fdo/iterative_sdc_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

#include <cstdint>
#include <optional>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -37,8 +39,10 @@ class IterativeSDCSchedulingModel : public SDCSchedulingModel {
// Delay map is no longer needed as the delay calculation is completely
// handled by the delay manager.
IterativeSDCSchedulingModel(FunctionBase* func,
absl::flat_hash_set<Node*> dead_after_synthesis,
const DelayManager& delay_manager)
: SDCSchedulingModel(func, DelayMap()), delay_manager_(delay_manager) {}
: SDCSchedulingModel(func, std::move(dead_after_synthesis), DelayMap()),
delay_manager_(delay_manager) {}

// Overrides the original timing constraints builder. This method directly
// call delay manager to extract the paths longer than the given clock period
Expand Down
20 changes: 20 additions & 0 deletions xls/scheduling/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ cc_library(
srcs = ["sdc_scheduler.cc"],
hdrs = ["sdc_scheduler.h"],
deps = [
":schedule_util",
":scheduling_options",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
Expand All @@ -166,12 +167,25 @@ cc_library(
],
)

cc_library(
name = "schedule_util",
srcs = ["schedule_util.cc"],
hdrs = ["schedule_util.h"],
deps = [
"//xls/ir",
"//xls/ir:op",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
],
)

cc_library(
name = "pipeline_schedule",
srcs = ["pipeline_schedule.cc"],
hdrs = ["pipeline_schedule.h"],
deps = [
":pipeline_schedule_cc_proto",
":schedule_util",
":scheduling_options",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
Expand All @@ -183,6 +197,7 @@ cc_library(
"//xls/ir:op",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
Expand All @@ -201,6 +216,7 @@ cc_library(
":min_cut_scheduler",
":pipeline_schedule",
":schedule_bounds",
":schedule_util",
":scheduling_options",
":sdc_scheduler",
"//xls/common/logging:log_lines",
Expand All @@ -219,6 +235,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/random:distributions",
Expand Down Expand Up @@ -275,6 +292,7 @@ cc_test(
"//xls/ir:function_builder",
"//xls/ir:ir_test_base",
"//xls/ir:op",
"//xls/ir:value",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest",
],
Expand Down Expand Up @@ -309,10 +327,12 @@ cc_library(
srcs = ["schedule_bounds.cc"],
hdrs = ["schedule_bounds.h"],
deps = [
":schedule_util",
"//xls/common/status:status_macros",
"//xls/estimators/delay_model:delay_estimator",
"//xls/ir",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
27 changes: 23 additions & 4 deletions xls/scheduling/pipeline_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
Expand All @@ -47,6 +48,7 @@
#include "xls/ir/proc.h"
#include "xls/ir/topo_sort.h"
#include "xls/scheduling/pipeline_schedule.pb.h"
#include "xls/scheduling/schedule_util.h"
#include "xls/scheduling/scheduling_options.h"

namespace xls {
Expand Down Expand Up @@ -268,9 +270,14 @@ absl::Status PipelineSchedule::Verify() const {

absl::Status PipelineSchedule::VerifyTiming(
int64_t clock_period_ps, const DelayEstimator& delay_estimator) const {
// The set of nodes that cannot affect anything that will be synthesized.
const absl::flat_hash_set<Node*> dead_after_synthesis =
GetDeadAfterSynthesisNodes(function_base_);

// Critical path from start of the cycle that a node is scheduled through the
// node itself. If the schedule meets timing, then this value should be less
// than or equal to clock_period_ps for every node.
// than or equal to clock_period_ps for every node (except those that cannot
// affect anything that will be synthesized).
absl::flat_hash_map<Node*, int64_t> node_cp;
// The predecessor (operand) of the node through which the critical-path from
// the start of the cycle extends.
Expand All @@ -284,15 +291,21 @@ absl::Status PipelineSchedule::VerifyTiming(
int64_t cp_to_node_start = 0;
cp_pred[node] = nullptr;
for (Node* operand : node->operands()) {
if (dead_after_synthesis.contains(operand)) {
continue;
}
if (cycle(operand) == cycle(node)) {
if (cp_to_node_start < node_cp.at(operand)) {
cp_to_node_start = node_cp.at(operand);
cp_pred[node] = operand;
}
}
}
XLS_ASSIGN_OR_RETURN(int64_t node_delay,
delay_estimator.GetOperationDelayInPs(node));
int64_t node_delay = 0;
if (!dead_after_synthesis.contains(node)) {
XLS_ASSIGN_OR_RETURN(node_delay,
delay_estimator.GetOperationDelayInPs(node));
}
node_cp[node] = cp_to_node_start + node_delay;
if (max_cp_node == nullptr || node_cp[node] > node_cp[max_cp_node]) {
max_cp_node = node;
Expand Down Expand Up @@ -322,10 +335,16 @@ absl::Status PipelineSchedule::VerifyTiming(

absl::Status PipelineSchedule::VerifyTiming(
int64_t clock_period_ps, const DelayManager& delay_manager) const {
const absl::flat_hash_set<Node*> dead_after_synthesis =
GetDeadAfterSynthesisNodes(function_base_);

PathExtractOptions options;
options.cycle_map = &cycle_map_;
XLS_ASSIGN_OR_RETURN(PathInfo critical_path,
delay_manager.GetLongestPath(options));
delay_manager.GetLongestPath(
options, /*except=*/[&](Node* source, Node* target) {
return dead_after_synthesis.contains(target);
}));

auto [delay, source, target] = critical_path;
if (delay > clock_period_ps) {
Expand Down
15 changes: 12 additions & 3 deletions xls/scheduling/run_pipeline_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/algorithm/container.h"
#include "absl/base/log_severity.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/random/distributions.h"
Expand Down Expand Up @@ -57,6 +58,7 @@
#include "xls/scheduling/min_cut_scheduler.h"
#include "xls/scheduling/pipeline_schedule.h"
#include "xls/scheduling/schedule_bounds.h"
#include "xls/scheduling/schedule_util.h"
#include "xls/scheduling/scheduling_options.h"
#include "xls/scheduling/sdc_scheduler.h"

Expand Down Expand Up @@ -159,12 +161,18 @@ absl::Status TightenBounds(sched::ScheduleBounds& bounds, FunctionBase* f,
return absl::OkStatus();
}

// Returns the critical path through the given nodes (ordered topologically).
// Returns the critical path through the given nodes (ordered topologically),
// ignoring nodes that will be dead after synthesis.
absl::StatusOr<int64_t> ComputeCriticalPath(
absl::Span<Node* const> topo_sort, const DelayEstimator& delay_estimator) {
absl::Span<Node* const> topo_sort,
const absl::flat_hash_set<Node*> dead_after_synthesis,
const DelayEstimator& delay_estimator) {
int64_t function_cp = 0;
absl::flat_hash_map<Node*, int64_t> node_cp;
for (Node* node : topo_sort) {
if (dead_after_synthesis.contains(node)) {
continue;
}
int64_t node_start = 0;
for (Node* operand : node->operands()) {
node_start = std::max(node_start, node_cp[operand]);
Expand All @@ -178,7 +186,8 @@ absl::StatusOr<int64_t> ComputeCriticalPath(
}
absl::StatusOr<int64_t> ComputeCriticalPath(
FunctionBase* f, const DelayEstimator& delay_estimator) {
return ComputeCriticalPath(TopoSort(f), delay_estimator);
return ComputeCriticalPath(TopoSort(f), GetDeadAfterSynthesisNodes(f),
delay_estimator);
}

// Returns the minimum clock period in picoseconds for which it is feasible to
Expand Down
Loading

0 comments on commit 67bf106

Please sign in to comment.