Skip to content

Commit

Permalink
Support inference of parametrics from function arguments in type_syst…
Browse files Browse the repository at this point in the history
…em_v2.

PiperOrigin-RevId: 711561226
  • Loading branch information
richmckeever authored and copybara-github committed Jan 3, 2025
1 parent 6794913 commit 809279c
Show file tree
Hide file tree
Showing 13 changed files with 792 additions and 228 deletions.
62 changes: 35 additions & 27 deletions xls/dslx/frontend/zip_ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,19 @@
namespace xls::dslx {
namespace {

// Returns true if `node` is of type `T`; false otherwise.
template <typename T>
bool MatchType(const AstNode* node) {
return dynamic_cast<const T*>(node) != nullptr;
}

// A visitor intended to be invoked for both the LHS and RHS trees, to perform
// the work of `ZipAst` non-recursively. The caller must do the recursive
// descent itself, and invoke this visitor for each LHS/RHS counterpart node
// encountered, with the LHS first.
class ZipVisitor : public AstNodeVisitorWithDefault {
public:
ZipVisitor(AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor,
absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>
accept_mismatch_callback)
ZipAstOptions options)
: lhs_visitor_(lhs_visitor),
rhs_visitor_(rhs_visitor),
accept_mismatch_callback_(std::move(accept_mismatch_callback)) {}
options_(std::move(options)) {}

absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>&
accept_mismatch_callback() {
return accept_mismatch_callback_;
}
ZipAstOptions& options() { return options_; }

#define DECLARE_HANDLER(__type) \
absl::Status Handle##__type(const __type* n) override { \
Expand All @@ -65,28 +55,49 @@ class ZipVisitor : public AstNodeVisitorWithDefault {
absl::Status Handle(const T* node) {
if (!lhs_.has_value()) {
lhs_ = node;
match_fn_ = MatchType<T>;
match_fn_ = &ZipVisitor::MatchNode<T>;
return absl::OkStatus();
}
// `node` is the RHS if we get here.
if (match_fn_(node)) {
if ((this->*match_fn_)(*lhs_, node)) {
XLS_RETURN_IF_ERROR((*lhs_)->Accept(lhs_visitor_));
XLS_RETURN_IF_ERROR(node->Accept(rhs_visitor_));
} else {
XLS_RETURN_IF_ERROR(accept_mismatch_callback_(*lhs_, node));
XLS_RETURN_IF_ERROR(options_.accept_mismatch_callback(*lhs_, node));
}
lhs_ = std::nullopt;
match_fn_ = nullptr;
return absl::OkStatus();
}

template <typename T>
bool MatchContent(const T* lhs, const T* rhs) {
return true;
}

template <>
bool MatchContent(const NameRef* lhs, const NameRef* rhs) {
return !options_.check_defs_for_name_refs ||
ToAstNode(lhs->name_def()) == ToAstNode(rhs->name_def());
}

// Returns true if `node` is of type `T` and `MatchContent<T>` returns true.
template <typename T>
bool MatchNode(const AstNode* lhs, const AstNode* rhs) {
const T* casted_lhs = dynamic_cast<const T*>(lhs);
const T* casted_rhs = dynamic_cast<const T*>(rhs);
return casted_lhs != nullptr && casted_rhs != nullptr &&
MatchContent(casted_lhs, casted_rhs);
}

AstNodeVisitor* lhs_visitor_;
AstNodeVisitor* rhs_visitor_;
absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>
accept_mismatch_callback_;
ZipAstOptions options_;

std::optional<const AstNode*> lhs_;
absl::AnyInvocable<bool(const AstNode*)> match_fn_ = nullptr;

using MatchFn = bool (ZipVisitor::*)(const AstNode*, const AstNode*);
MatchFn match_fn_ = nullptr;
};

// Helper for `ZipAst` which runs recursively and invokes the same `visitor` for
Expand All @@ -98,7 +109,7 @@ absl::Status ZipInternal(ZipVisitor* visitor, const AstNode* lhs,
std::vector<AstNode*> lhs_children = lhs->GetChildren(/*want_types=*/true);
std::vector<AstNode*> rhs_children = rhs->GetChildren(/*want_types=*/true);
if (lhs_children.size() != rhs_children.size()) {
return visitor->accept_mismatch_callback()(lhs, rhs);
return visitor->options().accept_mismatch_callback(lhs, rhs);
}
for (int i = 0; i < lhs_children.size(); i++) {
XLS_RETURN_IF_ERROR(ZipInternal(visitor, lhs_children[i], rhs_children[i]));
Expand All @@ -108,13 +119,10 @@ absl::Status ZipInternal(ZipVisitor* visitor, const AstNode* lhs,

} // namespace

absl::Status ZipAst(
const AstNode* lhs, const AstNode* rhs, AstNodeVisitor* lhs_visitor,
AstNodeVisitor* rhs_visitor,
absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>
accept_mismatch_callback) {
ZipVisitor visitor(lhs_visitor, rhs_visitor,
std::move(accept_mismatch_callback));
absl::Status ZipAst(const AstNode* lhs, const AstNode* rhs,
AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor,
ZipAstOptions options) {
ZipVisitor visitor(lhs_visitor, rhs_visitor, std::move(options));
return ZipInternal(&visitor, lhs, rhs);
}

Expand Down
32 changes: 23 additions & 9 deletions xls/dslx/frontend/zip_ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,40 @@

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "xls/dslx/frontend/ast.h"

namespace xls::dslx {

// Options for the behavior of `ZipAst`.
struct ZipAstOptions {
// Whether to consider it a mismatch if an LHS `NameRef` refers to a different
// def than the RHS one.
bool check_defs_for_name_refs = false;

// The callback for handling mismatches. By default, this generates an error.
absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>
accept_mismatch_callback =
[](const AstNode* x, const AstNode* y) -> absl::Status {
return absl::InvalidArgumentError(
absl::Substitute("Mismatch: $0 vs. $1", x->ToString(), y->ToString()));
};
};

// Traverses `lhs` and `rhs`, invoking `lhs_visitor` and then `rhs_visitor` for
// each corresponding node pair.
//
// The expectation is that `lhs` and `rhs` are structurally equivalent, meaning
// each corresponding node is of the same class and has the same number of
// children.
//
// If a structural mismatch is encountered, then the `accept_mismatch_callback`
// is invoked, and if it errors, the error is propagated out of `ZipAst`. On
// success of `accept_mismatch_callback`, the mismatching subtree is ignored,
// and `ZipAst` proceeds.
absl::Status ZipAst(
const AstNode* lhs, const AstNode* rhs, AstNodeVisitor* lhs_visitor,
AstNodeVisitor* rhs_visitor,
absl::AnyInvocable<absl::Status(const AstNode*, const AstNode*)>
accept_mismatch_callback);
// If a structural mismatch is encountered, then
// `options.accept_mismatch_callback` is invoked, and if it errors, the error is
// propagated out of `ZipAst`. On success of `options.accept_mismatch_callback`,
// the mismatching subtree is ignored, and `ZipAst` proceeds.
absl::Status ZipAst(const AstNode* lhs, const AstNode* rhs,
AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor,
ZipAstOptions options = ZipAstOptions{});

} // namespace xls::dslx

Expand Down
76 changes: 59 additions & 17 deletions xls/dslx/frontend/zip_ast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace {
using absl_testing::StatusIs;
using ::testing::Contains;
using ::testing::ElementsAre;
using ::testing::HasSubstr;
using ::testing::Not;

class ZipAstTest : public ::testing::Test {
Expand Down Expand Up @@ -85,10 +86,11 @@ TEST_F(ZipAstTest, ZipEmptyModules) {
Collector collector;
XLS_EXPECT_OK(
ZipAst(module1.get(), module2.get(), &collector, &collector,
[](const AstNode*, const AstNode*) {
ZipAstOptions{.accept_mismatch_callback = [](const AstNode*,
const AstNode*) {
return absl::FailedPreconditionError(
"Should not invoke accept mismatch callback here.");
}));
}}));
EXPECT_THAT(collector.nodes(), ElementsAre(module1.get(), module2.get()));
}

Expand All @@ -104,13 +106,49 @@ fn muladd<S: bool, N: u32>(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] {
Collector collector2;
XLS_EXPECT_OK(
ZipAst(module1.get(), module2.get(), &collector1, &collector2,
[](const AstNode* a, const AstNode* b) {
ZipAstOptions{.accept_mismatch_callback = [](const AstNode* a,
const AstNode* b) {
return absl::FailedPreconditionError(
"Should not invoke accept mismatch callback here.");
}));
}}));
EXPECT_EQ(collector1.GetNodeStrings(), collector2.GetNodeStrings());
}

TEST_F(ZipAstTest, ZipStructurallyMatchingExprsWithoutNameRefChecking) {
constexpr std::string_view kProgram = R"(
const X = u32:4;
const Y = X + 1;
const Z = Y + 1;
)";
XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(kProgram));
XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* y,
module->GetMemberOrError<ConstantDef>("Y"));
XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* z,
module->GetMemberOrError<ConstantDef>("Z"));
Collector collector1;
Collector collector2;
XLS_EXPECT_OK(ZipAst(y, z, &collector1, &collector2));
}

TEST_F(ZipAstTest, ZipStructurallyMatchingExprsWithNameRefChecking) {
constexpr std::string_view kProgram = R"(
const X = u32:4;
const Y = X + 1;
const Z = Y + 1;
)";
XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(kProgram));
XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* y,
module->GetMemberOrError<ConstantDef>("Y"));
XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* z,
module->GetMemberOrError<ConstantDef>("Z"));
Collector collector1;
Collector collector2;
EXPECT_THAT(
ZipAst(y, z, &collector1, &collector2,
ZipAstOptions{.check_defs_for_name_refs = true}),
StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("X vs. Y")));
}

TEST_F(ZipAstTest, ZipWithMismatchAccepted) {
XLS_ASSERT_OK_AND_ASSIGN(auto module1, Parse(R"(
fn muladd<S: bool, N: u32>(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] {
Expand All @@ -124,12 +162,14 @@ fn muladd<S: bool, N: u32>(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] {
)"));
Collector collector1;
Collector collector2;
XLS_EXPECT_OK(ZipAst(module1.get(), module2.get(), &collector1, &collector2,
[](const AstNode* a, const AstNode* b) {
EXPECT_EQ(a->ToString(), "b");
EXPECT_EQ(b->ToString(), "u32:42");
return absl::OkStatus();
}));
XLS_EXPECT_OK(
ZipAst(module1.get(), module2.get(), &collector1, &collector2,
ZipAstOptions{.accept_mismatch_callback = [](const AstNode* a,
const AstNode* b) {
EXPECT_EQ(a->ToString(), "b");
EXPECT_EQ(b->ToString(), "u32:42");
return absl::OkStatus();
}}));
EXPECT_EQ(collector1.nodes().size(), collector2.nodes().size());
EXPECT_THAT(collector1.GetNodeStrings(), Contains("(a + u32:1)"));
EXPECT_THAT(collector2.GetNodeStrings(), Contains("(a + u32:1)"));
Expand All @@ -152,13 +192,15 @@ fn muladd<S: bool, N: u32>(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] {
)"));
Collector collector1;
Collector collector2;
EXPECT_THAT(ZipAst(module1.get(), module2.get(), &collector1, &collector2,
[](const AstNode* a, const AstNode* b) {
EXPECT_EQ(a->ToString(), "b");
EXPECT_EQ(b->ToString(), "u32:42");
return absl::InvalidArgumentError("rejected");
}),
StatusIs(absl::StatusCode::kInvalidArgument));
EXPECT_THAT(
ZipAst(module1.get(), module2.get(), &collector1, &collector2,
ZipAstOptions{.accept_mismatch_callback =
[](const AstNode* a, const AstNode* b) {
EXPECT_EQ(a->ToString(), "b");
EXPECT_EQ(b->ToString(), "u32:42");
return absl::InvalidArgumentError("rejected");
}}),
StatusIs(absl::StatusCode::kInvalidArgument));
EXPECT_EQ(collector1.nodes().size(), collector2.nodes().size());
EXPECT_THAT(collector1.GetNodeStrings(), Contains("(a + u32:1)"));
EXPECT_THAT(collector2.GetNodeStrings(), Contains("(a + u32:1)"));
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/type_system_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ cc_library(
hdrs = ["inference_table_to_type_info.h"],
deps = [
":inference_table",
":solve_for_parametrics",
":type_annotation_utils",
"//xls/common:visitor",
"//xls/common/status:status_macros",
Expand All @@ -59,6 +60,7 @@ cc_library(
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:ast_cloner",
"//xls/dslx/frontend:ast_node_visitor_with_default",
"//xls/dslx/frontend:ast_utils",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:deduce_utils",
Expand All @@ -74,6 +76,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
],
)
Expand Down Expand Up @@ -228,6 +231,7 @@ cc_library(
"//xls/dslx/frontend:ast",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"//xls/dslx/type_system:type_info",
"//xls/ir:bits",
"//xls/ir:number_parser",
"@com_google_absl//absl/status",
Expand Down
25 changes: 17 additions & 8 deletions xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,19 +221,15 @@ class InferenceTableImpl : public InferenceTable {
variable,
InvocationScopedExpr(caller_invocation, binding->type_annotation(),
std::get<Expr*>(value)));
} else if (binding->expr() == nullptr) {
return absl::UnimplementedError(absl::StrCat(
"Type inference version 2 is a work in progress and doesn't yet "
"support inferring parametrics from function arguments: ",
invocation->ToString()));
} else {
} else if (binding->expr() != nullptr) {
values.emplace(
variable,
InvocationScopedExpr(invocation.get(), binding->type_annotation(),
binding->expr()));
}
}
const ParametricInvocation* result = invocation.get();
node_to_parametric_invocation_.emplace(&node, result);
parametric_invocations_.push_back(std::move(invocation));
parametric_values_by_invocation_.emplace(result, std::move(values));
return result;
Expand All @@ -249,11 +245,22 @@ class InferenceTableImpl : public InferenceTable {
return result;
}

InvocationScopedExpr GetParametricValue(
std::optional<const ParametricInvocation*> GetParametricInvocation(
const Invocation* node) const override {
const auto it = node_to_parametric_invocation_.find(node);
return it == node_to_parametric_invocation_.end()
? std::nullopt
: std::make_optional(it->second);
}

std::optional<InvocationScopedExpr> GetParametricValue(
const NameDef& binding_name_def,
const ParametricInvocation& invocation) const override {
const InferenceVariable* variable = variables_.at(&binding_name_def).get();
return parametric_values_by_invocation_.at(&invocation).at(variable);
const absl::flat_hash_map<const InferenceVariable*, InvocationScopedExpr>&
values = parametric_values_by_invocation_.at(&invocation);
const auto it = values.find(variable);
return it == values.end() ? std::nullopt : std::make_optional(it->second);
}

absl::Status SetTypeAnnotation(const AstNode* node,
Expand Down Expand Up @@ -371,6 +378,8 @@ class InferenceTableImpl : public InferenceTable {
// Parametric invocations and the corresponding information about parametric
// variables.
std::vector<std::unique_ptr<ParametricInvocation>> parametric_invocations_;
absl::flat_hash_map<const Invocation*, const ParametricInvocation*>
node_to_parametric_invocation_;
absl::flat_hash_map<
const ParametricInvocation*,
absl::flat_hash_map<const InferenceVariable*, InvocationScopedExpr>>
Expand Down
13 changes: 10 additions & 3 deletions xls/dslx/type_system_v2/inference_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,17 @@ class InferenceTable {
virtual std::vector<const ParametricInvocation*> GetParametricInvocations()
const = 0;

// Retrieves the `ParametricInvocation` associated with `node` if there is
// one.
virtual std::optional<const ParametricInvocation*> GetParametricInvocation(
const Invocation* node) const = 0;

// Returns the expression for the value of the given parametric in the given
// invocation. Note that the return value may be scoped to either `invocation`
// or its caller, depending on where the value expression originated from.
virtual InvocationScopedExpr GetParametricValue(
// invocation, if the parametric has an explicit or default expression. If it
// is implicit, then this returns `nullopt`. Note that the return value may be
// scoped to either `invocation` or its caller, depending on where the value
// expression originated from.
virtual std::optional<InvocationScopedExpr> GetParametricValue(
const NameDef& binding_name_def,
const ParametricInvocation& invocation) const = 0;

Expand Down
Loading

0 comments on commit 809279c

Please sign in to comment.