diff --git a/xls/dslx/frontend/zip_ast.cc b/xls/dslx/frontend/zip_ast.cc index 5aab39ae77..b38e37c504 100644 --- a/xls/dslx/frontend/zip_ast.cc +++ b/xls/dslx/frontend/zip_ast.cc @@ -27,12 +27,6 @@ namespace xls::dslx { namespace { -// Returns true if `node` is of type `T`; false otherwise. -template -bool MatchType(const AstNode* node) { - return dynamic_cast(node) != nullptr; -} - // A visitor intended to be invoked for both the LHS and RHS trees, to perform // the work of `ZipAst` non-recursively. The caller must do the recursive // descent itself, and invoke this visitor for each LHS/RHS counterpart node @@ -40,16 +34,12 @@ bool MatchType(const AstNode* node) { class ZipVisitor : public AstNodeVisitorWithDefault { public: ZipVisitor(AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor, - absl::AnyInvocable - accept_mismatch_callback) + ZipAstOptions options) : lhs_visitor_(lhs_visitor), rhs_visitor_(rhs_visitor), - accept_mismatch_callback_(std::move(accept_mismatch_callback)) {} + options_(std::move(options)) {} - absl::AnyInvocable& - accept_mismatch_callback() { - return accept_mismatch_callback_; - } + ZipAstOptions& options() { return options_; } #define DECLARE_HANDLER(__type) \ absl::Status Handle##__type(const __type* n) override { \ @@ -65,28 +55,49 @@ class ZipVisitor : public AstNodeVisitorWithDefault { absl::Status Handle(const T* node) { if (!lhs_.has_value()) { lhs_ = node; - match_fn_ = MatchType; + match_fn_ = &ZipVisitor::MatchNode; return absl::OkStatus(); } // `node` is the RHS if we get here. - if (match_fn_(node)) { + if ((this->*match_fn_)(*lhs_, node)) { XLS_RETURN_IF_ERROR((*lhs_)->Accept(lhs_visitor_)); XLS_RETURN_IF_ERROR(node->Accept(rhs_visitor_)); } else { - XLS_RETURN_IF_ERROR(accept_mismatch_callback_(*lhs_, node)); + XLS_RETURN_IF_ERROR(options_.accept_mismatch_callback(*lhs_, node)); } lhs_ = std::nullopt; match_fn_ = nullptr; return absl::OkStatus(); } + template + bool MatchContent(const T* lhs, const T* rhs) { + return true; + } + + template <> + bool MatchContent(const NameRef* lhs, const NameRef* rhs) { + return !options_.check_defs_for_name_refs || + ToAstNode(lhs->name_def()) == ToAstNode(rhs->name_def()); + } + + // Returns true if `node` is of type `T` and `MatchContent` returns true. + template + bool MatchNode(const AstNode* lhs, const AstNode* rhs) { + const T* casted_lhs = dynamic_cast(lhs); + const T* casted_rhs = dynamic_cast(rhs); + return casted_lhs != nullptr && casted_rhs != nullptr && + MatchContent(casted_lhs, casted_rhs); + } + AstNodeVisitor* lhs_visitor_; AstNodeVisitor* rhs_visitor_; - absl::AnyInvocable - accept_mismatch_callback_; + ZipAstOptions options_; std::optional lhs_; - absl::AnyInvocable match_fn_ = nullptr; + + using MatchFn = bool (ZipVisitor::*)(const AstNode*, const AstNode*); + MatchFn match_fn_ = nullptr; }; // Helper for `ZipAst` which runs recursively and invokes the same `visitor` for @@ -98,7 +109,7 @@ absl::Status ZipInternal(ZipVisitor* visitor, const AstNode* lhs, std::vector lhs_children = lhs->GetChildren(/*want_types=*/true); std::vector rhs_children = rhs->GetChildren(/*want_types=*/true); if (lhs_children.size() != rhs_children.size()) { - return visitor->accept_mismatch_callback()(lhs, rhs); + return visitor->options().accept_mismatch_callback(lhs, rhs); } for (int i = 0; i < lhs_children.size(); i++) { XLS_RETURN_IF_ERROR(ZipInternal(visitor, lhs_children[i], rhs_children[i])); @@ -108,13 +119,10 @@ absl::Status ZipInternal(ZipVisitor* visitor, const AstNode* lhs, } // namespace -absl::Status ZipAst( - const AstNode* lhs, const AstNode* rhs, AstNodeVisitor* lhs_visitor, - AstNodeVisitor* rhs_visitor, - absl::AnyInvocable - accept_mismatch_callback) { - ZipVisitor visitor(lhs_visitor, rhs_visitor, - std::move(accept_mismatch_callback)); +absl::Status ZipAst(const AstNode* lhs, const AstNode* rhs, + AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor, + ZipAstOptions options) { + ZipVisitor visitor(lhs_visitor, rhs_visitor, std::move(options)); return ZipInternal(&visitor, lhs, rhs); } diff --git a/xls/dslx/frontend/zip_ast.h b/xls/dslx/frontend/zip_ast.h index 2bd7770f60..bf193866fb 100644 --- a/xls/dslx/frontend/zip_ast.h +++ b/xls/dslx/frontend/zip_ast.h @@ -17,10 +17,26 @@ #include "absl/functional/any_invocable.h" #include "absl/status/status.h" +#include "absl/strings/substitute.h" #include "xls/dslx/frontend/ast.h" namespace xls::dslx { +// Options for the behavior of `ZipAst`. +struct ZipAstOptions { + // Whether to consider it a mismatch if an LHS `NameRef` refers to a different + // def than the RHS one. + bool check_defs_for_name_refs = false; + + // The callback for handling mismatches. By default, this generates an error. + absl::AnyInvocable + accept_mismatch_callback = + [](const AstNode* x, const AstNode* y) -> absl::Status { + return absl::InvalidArgumentError( + absl::Substitute("Mismatch: $0 vs. $1", x->ToString(), y->ToString())); + }; +}; + // Traverses `lhs` and `rhs`, invoking `lhs_visitor` and then `rhs_visitor` for // each corresponding node pair. // @@ -28,15 +44,13 @@ namespace xls::dslx { // each corresponding node is of the same class and has the same number of // children. // -// If a structural mismatch is encountered, then the `accept_mismatch_callback` -// is invoked, and if it errors, the error is propagated out of `ZipAst`. On -// success of `accept_mismatch_callback`, the mismatching subtree is ignored, -// and `ZipAst` proceeds. -absl::Status ZipAst( - const AstNode* lhs, const AstNode* rhs, AstNodeVisitor* lhs_visitor, - AstNodeVisitor* rhs_visitor, - absl::AnyInvocable - accept_mismatch_callback); +// If a structural mismatch is encountered, then +// `options.accept_mismatch_callback` is invoked, and if it errors, the error is +// propagated out of `ZipAst`. On success of `options.accept_mismatch_callback`, +// the mismatching subtree is ignored, and `ZipAst` proceeds. +absl::Status ZipAst(const AstNode* lhs, const AstNode* rhs, + AstNodeVisitor* lhs_visitor, AstNodeVisitor* rhs_visitor, + ZipAstOptions options = ZipAstOptions{}); } // namespace xls::dslx diff --git a/xls/dslx/frontend/zip_ast_test.cc b/xls/dslx/frontend/zip_ast_test.cc index fa5db8cfd4..13a334dd36 100644 --- a/xls/dslx/frontend/zip_ast_test.cc +++ b/xls/dslx/frontend/zip_ast_test.cc @@ -41,6 +41,7 @@ namespace { using absl_testing::StatusIs; using ::testing::Contains; using ::testing::ElementsAre; +using ::testing::HasSubstr; using ::testing::Not; class ZipAstTest : public ::testing::Test { @@ -85,10 +86,11 @@ TEST_F(ZipAstTest, ZipEmptyModules) { Collector collector; XLS_EXPECT_OK( ZipAst(module1.get(), module2.get(), &collector, &collector, - [](const AstNode*, const AstNode*) { + ZipAstOptions{.accept_mismatch_callback = [](const AstNode*, + const AstNode*) { return absl::FailedPreconditionError( "Should not invoke accept mismatch callback here."); - })); + }})); EXPECT_THAT(collector.nodes(), ElementsAre(module1.get(), module2.get())); } @@ -104,13 +106,49 @@ fn muladd(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] { Collector collector2; XLS_EXPECT_OK( ZipAst(module1.get(), module2.get(), &collector1, &collector2, - [](const AstNode* a, const AstNode* b) { + ZipAstOptions{.accept_mismatch_callback = [](const AstNode* a, + const AstNode* b) { return absl::FailedPreconditionError( "Should not invoke accept mismatch callback here."); - })); + }})); EXPECT_EQ(collector1.GetNodeStrings(), collector2.GetNodeStrings()); } +TEST_F(ZipAstTest, ZipStructurallyMatchingExprsWithoutNameRefChecking) { + constexpr std::string_view kProgram = R"( +const X = u32:4; +const Y = X + 1; +const Z = Y + 1; + )"; + XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(kProgram)); + XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* y, + module->GetMemberOrError("Y")); + XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* z, + module->GetMemberOrError("Z")); + Collector collector1; + Collector collector2; + XLS_EXPECT_OK(ZipAst(y, z, &collector1, &collector2)); +} + +TEST_F(ZipAstTest, ZipStructurallyMatchingExprsWithNameRefChecking) { + constexpr std::string_view kProgram = R"( +const X = u32:4; +const Y = X + 1; +const Z = Y + 1; + )"; + XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(kProgram)); + XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* y, + module->GetMemberOrError("Y")); + XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* z, + module->GetMemberOrError("Z")); + Collector collector1; + Collector collector2; + EXPECT_THAT( + ZipAst(y, z, &collector1, &collector2, + ZipAstOptions{.check_defs_for_name_refs = true}), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("X vs. Y"))); +} + TEST_F(ZipAstTest, ZipWithMismatchAccepted) { XLS_ASSERT_OK_AND_ASSIGN(auto module1, Parse(R"( fn muladd(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] { @@ -124,12 +162,14 @@ fn muladd(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] { )")); Collector collector1; Collector collector2; - XLS_EXPECT_OK(ZipAst(module1.get(), module2.get(), &collector1, &collector2, - [](const AstNode* a, const AstNode* b) { - EXPECT_EQ(a->ToString(), "b"); - EXPECT_EQ(b->ToString(), "u32:42"); - return absl::OkStatus(); - })); + XLS_EXPECT_OK( + ZipAst(module1.get(), module2.get(), &collector1, &collector2, + ZipAstOptions{.accept_mismatch_callback = [](const AstNode* a, + const AstNode* b) { + EXPECT_EQ(a->ToString(), "b"); + EXPECT_EQ(b->ToString(), "u32:42"); + return absl::OkStatus(); + }})); EXPECT_EQ(collector1.nodes().size(), collector2.nodes().size()); EXPECT_THAT(collector1.GetNodeStrings(), Contains("(a + u32:1)")); EXPECT_THAT(collector2.GetNodeStrings(), Contains("(a + u32:1)")); @@ -152,13 +192,15 @@ fn muladd(a: xN[S][N], b: xN[S][N], c: xN[S][N]) -> xN[S][N] { )")); Collector collector1; Collector collector2; - EXPECT_THAT(ZipAst(module1.get(), module2.get(), &collector1, &collector2, - [](const AstNode* a, const AstNode* b) { - EXPECT_EQ(a->ToString(), "b"); - EXPECT_EQ(b->ToString(), "u32:42"); - return absl::InvalidArgumentError("rejected"); - }), - StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT( + ZipAst(module1.get(), module2.get(), &collector1, &collector2, + ZipAstOptions{.accept_mismatch_callback = + [](const AstNode* a, const AstNode* b) { + EXPECT_EQ(a->ToString(), "b"); + EXPECT_EQ(b->ToString(), "u32:42"); + return absl::InvalidArgumentError("rejected"); + }}), + StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_EQ(collector1.nodes().size(), collector2.nodes().size()); EXPECT_THAT(collector1.GetNodeStrings(), Contains("(a + u32:1)")); EXPECT_THAT(collector2.GetNodeStrings(), Contains("(a + u32:1)")); diff --git a/xls/dslx/type_system_v2/BUILD b/xls/dslx/type_system_v2/BUILD index b759e75cd0..8b8b23cd0c 100644 --- a/xls/dslx/type_system_v2/BUILD +++ b/xls/dslx/type_system_v2/BUILD @@ -48,6 +48,7 @@ cc_library( hdrs = ["inference_table_to_type_info.h"], deps = [ ":inference_table", + ":solve_for_parametrics", ":type_annotation_utils", "//xls/common:visitor", "//xls/common/status:status_macros", @@ -59,6 +60,7 @@ cc_library( "//xls/dslx/frontend:ast", "//xls/dslx/frontend:ast_cloner", "//xls/dslx/frontend:ast_node_visitor_with_default", + "//xls/dslx/frontend:ast_utils", "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", "//xls/dslx/type_system:deduce_utils", @@ -74,6 +76,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], ) @@ -228,6 +231,7 @@ cc_library( "//xls/dslx/frontend:ast", "//xls/dslx/frontend:module", "//xls/dslx/frontend:pos", + "//xls/dslx/type_system:type_info", "//xls/ir:bits", "//xls/ir:number_parser", "@com_google_absl//absl/status", diff --git a/xls/dslx/type_system_v2/inference_table.cc b/xls/dslx/type_system_v2/inference_table.cc index 79221314f1..57dd68ca05 100644 --- a/xls/dslx/type_system_v2/inference_table.cc +++ b/xls/dslx/type_system_v2/inference_table.cc @@ -221,12 +221,7 @@ class InferenceTableImpl : public InferenceTable { variable, InvocationScopedExpr(caller_invocation, binding->type_annotation(), std::get(value))); - } else if (binding->expr() == nullptr) { - return absl::UnimplementedError(absl::StrCat( - "Type inference version 2 is a work in progress and doesn't yet " - "support inferring parametrics from function arguments: ", - invocation->ToString())); - } else { + } else if (binding->expr() != nullptr) { values.emplace( variable, InvocationScopedExpr(invocation.get(), binding->type_annotation(), @@ -234,6 +229,7 @@ class InferenceTableImpl : public InferenceTable { } } const ParametricInvocation* result = invocation.get(); + node_to_parametric_invocation_.emplace(&node, result); parametric_invocations_.push_back(std::move(invocation)); parametric_values_by_invocation_.emplace(result, std::move(values)); return result; @@ -249,11 +245,22 @@ class InferenceTableImpl : public InferenceTable { return result; } - InvocationScopedExpr GetParametricValue( + std::optional GetParametricInvocation( + const Invocation* node) const override { + const auto it = node_to_parametric_invocation_.find(node); + return it == node_to_parametric_invocation_.end() + ? std::nullopt + : std::make_optional(it->second); + } + + std::optional GetParametricValue( const NameDef& binding_name_def, const ParametricInvocation& invocation) const override { const InferenceVariable* variable = variables_.at(&binding_name_def).get(); - return parametric_values_by_invocation_.at(&invocation).at(variable); + const absl::flat_hash_map& + values = parametric_values_by_invocation_.at(&invocation); + const auto it = values.find(variable); + return it == values.end() ? std::nullopt : std::make_optional(it->second); } absl::Status SetTypeAnnotation(const AstNode* node, @@ -371,6 +378,8 @@ class InferenceTableImpl : public InferenceTable { // Parametric invocations and the corresponding information about parametric // variables. std::vector> parametric_invocations_; + absl::flat_hash_map + node_to_parametric_invocation_; absl::flat_hash_map< const ParametricInvocation*, absl::flat_hash_map> diff --git a/xls/dslx/type_system_v2/inference_table.h b/xls/dslx/type_system_v2/inference_table.h index 096b892810..74b15d49fa 100644 --- a/xls/dslx/type_system_v2/inference_table.h +++ b/xls/dslx/type_system_v2/inference_table.h @@ -186,10 +186,17 @@ class InferenceTable { virtual std::vector GetParametricInvocations() const = 0; + // Retrieves the `ParametricInvocation` associated with `node` if there is + // one. + virtual std::optional GetParametricInvocation( + const Invocation* node) const = 0; + // Returns the expression for the value of the given parametric in the given - // invocation. Note that the return value may be scoped to either `invocation` - // or its caller, depending on where the value expression originated from. - virtual InvocationScopedExpr GetParametricValue( + // invocation, if the parametric has an explicit or default expression. If it + // is implicit, then this returns `nullopt`. Note that the return value may be + // scoped to either `invocation` or its caller, depending on where the value + // expression originated from. + virtual std::optional GetParametricValue( const NameDef& binding_name_def, const ParametricInvocation& invocation) const = 0; diff --git a/xls/dslx/type_system_v2/inference_table_test.cc b/xls/dslx/type_system_v2/inference_table_test.cc index da701bed84..be16855c2a 100644 --- a/xls/dslx/type_system_v2/inference_table_test.cc +++ b/xls/dslx/type_system_v2/inference_table_test.cc @@ -223,16 +223,18 @@ TEST_F(InferenceTableTest, ParametricVariable) { EXPECT_THAT(table_->GetParametricInvocations(), ElementsAre(parametric_invocation1, parametric_invocation2)); - InvocationScopedExpr parametric_inv1_n_value = + std::optional parametric_inv1_n_value = table_->GetParametricValue(*n, *parametric_invocation1); - InvocationScopedExpr parametric_inv2_n_value = + std::optional parametric_inv2_n_value = table_->GetParametricValue(*n, *parametric_invocation2); + ASSERT_TRUE(parametric_inv1_n_value.has_value()); + ASSERT_TRUE(parametric_inv2_n_value.has_value()); // These exprs are scoped to `nullopt` invocation because they reside in the // non-parametric calling context. - EXPECT_EQ(parametric_inv1_n_value.invocation(), std::nullopt); - EXPECT_EQ(parametric_inv2_n_value.invocation(), std::nullopt); - EXPECT_EQ(parametric_inv1_n_value.expr()->ToString(), "u32:4"); - EXPECT_EQ(parametric_inv2_n_value.expr()->ToString(), "u32:5"); + EXPECT_EQ(parametric_inv1_n_value->invocation(), std::nullopt); + EXPECT_EQ(parametric_inv2_n_value->invocation(), std::nullopt); + EXPECT_EQ(parametric_inv1_n_value->expr()->ToString(), "u32:4"); + EXPECT_EQ(parametric_inv2_n_value->expr()->ToString(), "u32:5"); } TEST_F(InferenceTableTest, ParametricVariableWithDefault) { @@ -280,24 +282,28 @@ TEST_F(InferenceTableTest, ParametricVariableWithDefault) { EXPECT_THAT(table_->GetParametricInvocations(), ElementsAre(parametric_invocation1, parametric_invocation2)); - InvocationScopedExpr parametric_inv1_m_value = + std::optional parametric_inv1_m_value = table_->GetParametricValue(*m, *parametric_invocation1); - InvocationScopedExpr parametric_inv1_n_value = + std::optional parametric_inv1_n_value = table_->GetParametricValue(*n, *parametric_invocation1); - InvocationScopedExpr parametric_inv2_m_value = + std::optional parametric_inv2_m_value = table_->GetParametricValue(*m, *parametric_invocation2); - InvocationScopedExpr parametric_inv2_n_value = + std::optional parametric_inv2_n_value = table_->GetParametricValue(*n, *parametric_invocation2); + ASSERT_TRUE(parametric_inv1_m_value.has_value()); + ASSERT_TRUE(parametric_inv1_n_value.has_value()); + ASSERT_TRUE(parametric_inv2_m_value.has_value()); + ASSERT_TRUE(parametric_inv2_n_value.has_value()); // Exprs that reside in the callee are scoped to the callee invocation. - EXPECT_EQ(parametric_inv1_m_value.invocation(), parametric_invocation1); - EXPECT_EQ(parametric_inv2_m_value.invocation(), std::nullopt); - EXPECT_EQ(parametric_inv1_n_value.invocation(), parametric_invocation1); - EXPECT_EQ(parametric_inv2_n_value.invocation(), parametric_invocation2); - EXPECT_EQ(parametric_inv1_m_value.expr()->ToString(), "u32:4"); - EXPECT_EQ(parametric_inv2_m_value.expr()->ToString(), "u32:5"); - EXPECT_EQ(parametric_inv1_n_value.expr()->ToString(), "M * M"); - EXPECT_EQ(parametric_inv2_n_value.expr()->ToString(), "M * M"); + EXPECT_EQ(parametric_inv1_m_value->invocation(), parametric_invocation1); + EXPECT_EQ(parametric_inv2_m_value->invocation(), std::nullopt); + EXPECT_EQ(parametric_inv1_n_value->invocation(), parametric_invocation1); + EXPECT_EQ(parametric_inv2_n_value->invocation(), parametric_invocation2); + EXPECT_EQ(parametric_inv1_m_value->expr()->ToString(), "u32:4"); + EXPECT_EQ(parametric_inv2_m_value->expr()->ToString(), "u32:5"); + EXPECT_EQ(parametric_inv1_n_value->expr()->ToString(), "M * M"); + EXPECT_EQ(parametric_inv2_n_value->expr()->ToString(), "M * M"); } TEST_F(InferenceTableTest, ParametricVariableWithArrayAnnotation) { @@ -331,10 +337,11 @@ TEST_F(InferenceTableTest, ParametricVariableWithArrayAnnotation) { table_->AddParametricInvocation(*invocation, *foo, bar, /*caller_invocation=*/std::nullopt)); - InvocationScopedExpr parametric_inv_m_value = + std::optional parametric_inv_m_value = table_->GetParametricValue(*m, *parametric_invocation); - EXPECT_EQ(parametric_inv_m_value.invocation(), std::nullopt); - EXPECT_EQ(parametric_inv_m_value.expr()->ToString(), "u32:5"); + ASSERT_TRUE(parametric_inv_m_value.has_value()); + EXPECT_EQ(parametric_inv_m_value->invocation(), std::nullopt); + EXPECT_EQ(parametric_inv_m_value->expr()->ToString(), "u32:5"); } TEST_F(InferenceTableTest, ParametricVariableWithUnsupportedAnnotation) { diff --git a/xls/dslx/type_system_v2/inference_table_to_type_info.cc b/xls/dslx/type_system_v2/inference_table_to_type_info.cc index 13468bc83a..7f4c8e3ff7 100644 --- a/xls/dslx/type_system_v2/inference_table_to_type_info.cc +++ b/xls/dslx/type_system_v2/inference_table_to_type_info.cc @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include #include @@ -32,7 +34,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/substitute.h" +#include "absl/types/span.h" #include "absl/types/variant.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" @@ -41,6 +45,7 @@ #include "xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/ast_cloner.h" #include "xls/dslx/frontend/ast_node_visitor_with_default.h" +#include "xls/dslx/frontend/ast_utils.h" #include "xls/dslx/frontend/module.h" #include "xls/dslx/frontend/pos.h" #include "xls/dslx/import_data.h" @@ -51,6 +56,7 @@ #include "xls/dslx/type_system/type_info.h" #include "xls/dslx/type_system/unwrap_meta_type.h" #include "xls/dslx/type_system_v2/inference_table.h" +#include "xls/dslx/type_system_v2/solve_for_parametrics.h" #include "xls/dslx/type_system_v2/type_annotation_utils.h" #include "xls/dslx/warning_collector.h" @@ -184,14 +190,13 @@ class ConversionOrderVisitor : public AstNodeVisitorWithDefault { private: absl::StatusOr> GetRelatedInvocations(const AstNode* node) { - std::vector result; + std::vector referenced_invocations; if (std::optional annotation = table_.GetTypeAnnotation(node); annotation.has_value()) { const auto it = invocation_scoped_annotations_.find(*annotation); - if (it != invocation_scoped_annotations_.end() && - !handled_parametric_invocations_.contains(it->second)) { - result.push_back(it->second); + if (it != invocation_scoped_annotations_.end()) { + referenced_invocations.push_back(it->second); } } if (std::optional variable = table_.GetTypeVariable(node); @@ -200,12 +205,42 @@ class ConversionOrderVisitor : public AstNodeVisitorWithDefault { table_.GetTypeAnnotationsForTypeVariable(*variable)); for (const TypeAnnotation* annotation : annotations) { const auto it = invocation_scoped_annotations_.find(annotation); - if (it != invocation_scoped_annotations_.end() && - !handled_parametric_invocations_.contains(it->second)) { - result.push_back(it->second); + if (it != invocation_scoped_annotations_.end()) { + referenced_invocations.push_back(it->second); + } + } + } + // In a case like `foo(foo(...))` where `foo` is a parametric function, the + // implicit parametrics of the outer invocation depend on the inference of + // the inner invocation. + std::list nested_invocations; + for (const ParametricInvocation* invocation : referenced_invocations) { + XLS_ASSIGN_OR_RETURN( + std::vector descendants, + CollectUnder(&invocation->node(), /*want_types=*/false)); + for (const AstNode* descendant : descendants) { + if (descendant == &invocation->node()) { + continue; + } + if (const auto* descendant_invocation = + dynamic_cast(descendant)) { + std::optional + descendant_parametric_invocation = + table_.GetParametricInvocation(descendant_invocation); + if (descendant_parametric_invocation.has_value()) { + nested_invocations.push_front(*descendant_parametric_invocation); + } } } } + std::vector result; + auto needs_handling = [&](const ParametricInvocation* invocation) { + return !handled_parametric_invocations_.contains(invocation); + }; + absl::c_copy_if(nested_invocations, std::back_inserter(result), + needs_handling); + absl::c_copy_if(referenced_invocations, std::back_inserter(result), + needs_handling); return result; } @@ -503,20 +538,132 @@ class InferenceTableConverter { return it->second; } absl::flat_hash_map values; + absl::flat_hash_set implicit_parametrics; + auto infer_pending_implicit_parametrics = [&]() -> absl::Status { + if (implicit_parametrics.empty()) { + return absl::OkStatus(); + } + absl::flat_hash_map new_values; + XLS_ASSIGN_OR_RETURN(new_values, InferImplicitFunctionParametrics( + invocation, implicit_parametrics)); + implicit_parametrics.clear(); + values.merge(std::move(new_values)); + return absl::OkStatus(); + }; for (const ParametricBinding* binding : invocation->callee().parametric_bindings()) { - InvocationScopedExpr expr = + std::optional expr = table_.GetParametricValue(*binding->name_def(), *invocation); - XLS_ASSIGN_OR_RETURN(InterpValue value, Evaluate(expr)); - invocation_type_info_.at(invocation) - ->NoteConstExpr(binding->name_def(), value); - values.emplace(binding->name_def()->identifier(), value); + if (expr.has_value()) { + // The expr may be a default expr which may use the inferred values of + // any parametrics preceding it, so let's resolve any pending implicit + // ones now. + XLS_RETURN_IF_ERROR(infer_pending_implicit_parametrics()); + // Now evaluate the expr. + XLS_ASSIGN_OR_RETURN(InterpValue value, Evaluate(*expr)); + invocation_type_info_.at(invocation) + ->NoteConstExpr(binding->name_def(), value); + values.emplace(binding->name_def()->identifier(), value); + } else { + implicit_parametrics.insert(binding); + } } + // Resolve any implicit ones that are at the end of the list. + XLS_RETURN_IF_ERROR(infer_pending_implicit_parametrics()); ParametricEnv env(values); converted_parametric_envs_.emplace(invocation, env); return env; } + // Attempts to infer the values of the specified implicit parametrics in an + // invocation, using the types of the regular arguments being passed. If not + // all of `implicit_parametrics` can be determined, this function returns an + // error. + absl::StatusOr> + InferImplicitFunctionParametrics( + const ParametricInvocation* invocation, + absl::flat_hash_set implicit_parametrics) { + VLOG(5) << "Inferring " << implicit_parametrics.size() + << " implicit parametrics for invocation: " << ToString(invocation); + absl::flat_hash_map values; + const absl::Span formal_args = invocation->callee().params(); + const absl::Span actual_args = invocation->node().args(); + TypeInfo* ti = invocation_type_info_.at(invocation); + for (int i = 0; i < formal_args.size(); i++) { + std::optional actual_arg_type_var = + table_.GetTypeVariable(actual_args[i]); + if (!actual_arg_type_var.has_value()) { + VLOG(5) << "The actual argument: `" << actual_args[i]->ToString() + << "` has no type variable."; + continue; + } + VLOG(5) << "Using type variable: " << (*actual_arg_type_var)->ToString(); + XLS_ASSIGN_OR_RETURN( + std::vector actual_arg_annotations, + table_.GetTypeAnnotationsForTypeVariable(*actual_arg_type_var)); + XLS_RETURN_IF_ERROR( + ResolveVariableTypeAnnotations(invocation, actual_arg_annotations)); + TypeInfo* actual_arg_ti = base_type_info_; + if (invocation->caller_invocation().has_value()) { + actual_arg_ti = + invocation_type_info_.at(*invocation->caller_invocation()); + } + // The type variable for the actual argument should have at least one + // annotation associated with it that came from the formal argument and is + // therefore dependent on the parametric we are solving for. Let's unify + // just the independent annotations(s) for the purposes of solving for the + // variable. + RemoveAnnotationsReferringToNamesWithoutTypeInfo(actual_arg_ti, + actual_arg_annotations); + if (actual_arg_annotations.empty()) { + VLOG(5) << "The actual argument type variable: " + << (*actual_arg_type_var)->ToString() + << " has no independent type annotations."; + continue; + } + XLS_ASSIGN_OR_RETURN( + const TypeAnnotation* actual_arg_type, + UnifyTypeAnnotations(invocation, actual_arg_annotations, + actual_args[i]->span())); + std::optional effective_invocation = + GetEffectiveParametricInvocation(invocation->caller_invocation(), + actual_arg_type); + absl::flat_hash_map resolved; + VLOG(5) << "Infer using actual type: " << actual_arg_type->ToString() + << " with effective invocation: " + << ToString(effective_invocation); + XLS_ASSIGN_OR_RETURN( + resolved, + SolveForParametrics( + actual_arg_type, formal_args[i]->type_annotation(), + implicit_parametrics, + [&](const TypeAnnotation* expected_type, const Expr* expr) { + return Evaluate(InvocationScopedExpr(effective_invocation, + expected_type, expr)); + })); + for (auto& [binding, value] : resolved) { + VLOG(5) << "Inferred implicit parametric value: " << value.ToString() + << " for binding: " << binding->identifier() + << " using function argument: `" << actual_args[i]->ToString() + << "` of actual type: " << actual_arg_type->ToString(); + ti->NoteConstExpr(binding->name_def(), value); + implicit_parametrics.erase(binding); + values.emplace(binding->identifier(), std::move(value)); + } + } + if (!implicit_parametrics.empty()) { + std::vector binding_names; + binding_names.reserve(implicit_parametrics.size()); + for (const ParametricBinding* binding : implicit_parametrics) { + binding_names.push_back(binding->identifier()); + } + return absl::InvalidArgumentError( + absl::StrCat("Could not infer parametric(s): ", + absl::StrJoin(binding_names, ", "))); + } + return values; + } + absl::StatusOr EvaluateBoolOrExpr( std::optional parametric_invocation, std::variant value_or_expr) { @@ -579,11 +726,8 @@ class InferenceTableConverter { return absl::InvalidArgumentError( "Failed to unify because there are no type annotations."); } - for (int i = 0; i < annotations.size(); i++) { - XLS_ASSIGN_OR_RETURN(annotations[i], - ResolveVariableTypeAnnotations(parametric_invocation, - annotations[i])); - } + XLS_RETURN_IF_ERROR( + ResolveVariableTypeAnnotations(parametric_invocation, annotations)); if (annotations.size() == 1 && !invocation_scoped_type_annotations_.contains(annotations[0])) { // This is here mainly for preservation of shorthand annotations appearing @@ -796,6 +940,19 @@ class InferenceTableConverter { return result; } + // Variant that deeply resolves all `TypeVariableTypeAnnotation`s within a + // vector of annotations. + absl::Status ResolveVariableTypeAnnotations( + std::optional parametric_invocation, + std::vector& annotations) { + for (int i = 0; i < annotations.size(); i++) { + XLS_ASSIGN_OR_RETURN(annotations[i], + ResolveVariableTypeAnnotations(parametric_invocation, + annotations[i])); + } + return absl::OkStatus(); + } + // Checks if the given concrete type ultimately makes sense for the given // node, based on the intrinsic properties of the node, like being an add // operation or containing an embedded literal. @@ -933,6 +1090,38 @@ class InferenceTableConverter { return context_invocation; } + // Removes any annotations in the given vector that contain any `NameRef` + // whose type info has not (yet) been generated. The effective `TypeInfo` for + // each annotation is either `default_ti`; or, for invocation-scoped + // annotations, the `TypeInfo` for the relevant parametric invocation. + void RemoveAnnotationsReferringToNamesWithoutTypeInfo( + TypeInfo* default_ti, std::vector& annotations) { + annotations.erase( + std::remove_if( + annotations.begin(), annotations.end(), + [&](const TypeAnnotation* annotation) { + TypeInfo* ti = default_ti; + const auto it = + invocation_scoped_type_annotations_.find(annotation); + if (it != invocation_scoped_type_annotations_.end()) { + ti = invocation_type_info_.at(it->second); + } + FreeVariables vars = + GetFreeVariablesByLambda(annotation, [&](const NameRef& ref) { + if (!std::holds_alternative( + ref.name_def())) { + return false; + } + const NameDef* name_def = + std::get(ref.name_def()); + return !ti->GetItem(name_def).has_value() && + !ti->IsKnownConstExpr(name_def); + }); + return vars.GetFreeVariableCount() > 0; + }), + annotations.end()); + } + const InferenceTable& table_; Module& module_; ImportData& import_data_; diff --git a/xls/dslx/type_system_v2/solve_for_parametrics.cc b/xls/dslx/type_system_v2/solve_for_parametrics.cc index a7c30ff985..d5762c05ba 100644 --- a/xls/dslx/type_system_v2/solve_for_parametrics.cc +++ b/xls/dslx/type_system_v2/solve_for_parametrics.cc @@ -95,7 +95,8 @@ class Resolver { Resolver( Visitor* resolvable_visitor, Visitor* dependent_visitor, const absl::flat_hash_set& bindings_to_resolve, - absl::AnyInvocable(const Expr*)> + absl::AnyInvocable( + const TypeAnnotation* expected_type, const Expr*)> expr_evaluator) : resolvable_visitor_(resolvable_visitor), dependent_visitor_(dependent_visitor), @@ -132,25 +133,12 @@ class Resolver { *resolvable_visitor_->last_signedness_and_bit_count(); SignednessAndBitCountResult dependent_signedness_and_bit_count = *dependent_visitor_->last_signedness_and_bit_count(); - if (std::holds_alternative( - dependent_signedness_and_bit_count.bit_count)) { - XLS_RETURN_IF_ERROR( - std::get(dependent_signedness_and_bit_count.bit_count) - ->Accept(dependent_visitor_)); - XLS_RETURN_IF_ERROR( - ResolveVariable(resolvable_signedness_and_bit_count.bit_count, - *dependent_visitor_->last_variable())); - } - if (std::holds_alternative( - dependent_signedness_and_bit_count.signedness)) { - XLS_RETURN_IF_ERROR( - std::get(dependent_signedness_and_bit_count.signedness) - ->Accept(dependent_visitor_)); - XLS_RETURN_IF_ERROR( - ResolveVariable(resolvable_signedness_and_bit_count.signedness, - *dependent_visitor_->last_variable())); - } - return absl::OkStatus(); + XLS_RETURN_IF_ERROR(ResolveIntegerTypeComponent( + resolvable_signedness_and_bit_count.bit_count, + dependent_signedness_and_bit_count.bit_count)); + return ResolveIntegerTypeComponent( + resolvable_signedness_and_bit_count.signedness, + dependent_signedness_and_bit_count.signedness); } absl::flat_hash_map& results() { @@ -174,8 +162,10 @@ class Resolver { return absl::InvalidArgumentError( absl::Substitute("Could not evaluate: $0", resolvable->ToString())); } - XLS_ASSIGN_OR_RETURN(InterpValue value, expr_evaluator_(expr)); - return NoteValue(it->second, std::move(value)); + const ParametricBinding* binding = it->second; + XLS_ASSIGN_OR_RETURN(InterpValue value, + expr_evaluator_(binding->type_annotation(), expr)); + return NoteValue(binding, std::move(value)); } // Variant that takes an int64_t `value` instead of an `Expr`. @@ -188,10 +178,14 @@ class Resolver { XLS_ASSIGN_OR_RETURN( SignednessAndBitCountResult signedness_and_bit_count, GetSignednessAndBitCount(it->second->type_annotation())); - XLS_ASSIGN_OR_RETURN(bool is_signed, - Evaluate(signedness_and_bit_count.signedness)); - XLS_ASSIGN_OR_RETURN(int64_t bit_count, - Evaluate(signedness_and_bit_count.bit_count)); + XLS_ASSIGN_OR_RETURN( + bool is_signed, + Evaluate(CreateBoolAnnotation(*variable->owner(), variable->span()), + signedness_and_bit_count.signedness)); + XLS_ASSIGN_OR_RETURN( + int64_t bit_count, + Evaluate(CreateS64Annotation(*variable->owner(), variable->span()), + signedness_and_bit_count.bit_count)); return NoteValue(it->second, is_signed ? InterpValue::MakeSBits(bit_count, value) : InterpValue::MakeUBits(bit_count, value)); @@ -207,6 +201,30 @@ class Resolver { return ResolveVariable(std::get(value), variable); } + // Resolves a variable for the signedness or bit count of an integer type like + // `xN[S][N]`, `uN[N]`, etc. The input is expected to be a field from a + // `SignednessAndBitCountResult` object obtained using the type annotation. + // If the component is not a variable, this does nothing. + template + absl::Status ResolveIntegerTypeComponent( + const std::variant& resolvable, + const std::variant& dependent) { + if (!std::holds_alternative(dependent)) { + // If the dependent value is not actually dependent on anything, there is + // nothing to do. We would hit this for the static component in a + // dependent annotation that only has one component parameterized, like + // `xN[S][32]`. + return absl::OkStatus(); + } + XLS_RETURN_IF_ERROR( + std::get(dependent)->Accept(dependent_visitor_)); + if (dependent_visitor_->last_variable().has_value()) { + XLS_RETURN_IF_ERROR( + ResolveVariable(resolvable, *dependent_visitor_->last_variable())); + } + return absl::OkStatus(); + } + // Records the given `value` for `variable` in the result map, and ensures // that it doesn't conflict with a value already recorded. absl::Status NoteValue(const ParametricBinding* variable, InterpValue value) { @@ -221,12 +239,14 @@ class Resolver { } template - absl::StatusOr Evaluate(std::variant value_or_expr) { + absl::StatusOr Evaluate(const TypeAnnotation* expected_type, + std::variant value_or_expr) { if (std::holds_alternative(value_or_expr)) { return std::get(value_or_expr); } - XLS_ASSIGN_OR_RETURN(InterpValue value, - expr_evaluator_(std::get(value_or_expr))); + XLS_ASSIGN_OR_RETURN( + InterpValue value, + expr_evaluator_(expected_type, std::get(value_or_expr))); return value.GetBitValueUnsigned(); } @@ -234,7 +254,9 @@ class Resolver { Visitor* const dependent_visitor_; absl::flat_hash_map bindings_to_resolve_; - absl::AnyInvocable(const Expr*)> expr_evaluator_; + absl::AnyInvocable( + const TypeAnnotation* expected_type, const Expr*)> + expr_evaluator_; absl::flat_hash_map results_; }; @@ -244,17 +266,21 @@ absl::StatusOr> SolveForParametrics(const TypeAnnotation* resolvable_type, const TypeAnnotation* parametric_dependent_type, absl::flat_hash_set parametrics, - absl::AnyInvocable(const Expr*)> + absl::AnyInvocable( + const TypeAnnotation* expected_type, const Expr*)> expr_evaluator) { Visitor resolvable_visitor; Visitor dependent_visitor; Resolver resolver(&resolvable_visitor, &dependent_visitor, parametrics, std::move(expr_evaluator)); - XLS_RETURN_IF_ERROR(ZipAst(resolvable_type, parametric_dependent_type, - &resolvable_visitor, &dependent_visitor, - [&](const AstNode* lhs, const AstNode* rhs) { - return resolver.AcceptMismatch(lhs, rhs); - })); + XLS_RETURN_IF_ERROR( + ZipAst(resolvable_type, parametric_dependent_type, &resolvable_visitor, + &dependent_visitor, + ZipAstOptions{.check_defs_for_name_refs = true, + .accept_mismatch_callback = [&](const AstNode* lhs, + const AstNode* rhs) { + return resolver.AcceptMismatch(lhs, rhs); + }})); return std::move(resolver.results()); } diff --git a/xls/dslx/type_system_v2/solve_for_parametrics.h b/xls/dslx/type_system_v2/solve_for_parametrics.h index 2306d06de5..91b5598052 100644 --- a/xls/dslx/type_system_v2/solve_for_parametrics.h +++ b/xls/dslx/type_system_v2/solve_for_parametrics.h @@ -44,7 +44,8 @@ absl::StatusOr> SolveForParametrics(const TypeAnnotation* resolvable_type, const TypeAnnotation* parametric_dependent_type, absl::flat_hash_set parametrics, - absl::AnyInvocable(const Expr*)> + absl::AnyInvocable( + const TypeAnnotation* expected_type, const Expr*)> expr_evaluator); } // namespace xls::dslx diff --git a/xls/dslx/type_system_v2/solve_for_parametrics_test.cc b/xls/dslx/type_system_v2/solve_for_parametrics_test.cc index 97ec8c979d..0f9d1bb6c3 100644 --- a/xls/dslx/type_system_v2/solve_for_parametrics_test.cc +++ b/xls/dslx/type_system_v2/solve_for_parametrics_test.cc @@ -86,10 +86,11 @@ const BAR: uN[4] = uN[4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, IsEmpty()); } @@ -107,13 +108,37 @@ const BAR: uN[4] = uN[4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(4)))); } +TEST_F(SolveForParametricsTest, SolveForUnWithConstant) { + XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(R"( +const X = u32:1; +fn foo(a: uN[N]) -> uN[N] { a } +const BAR: uN[X] = uN[X]:1; +)")); + XLS_ASSERT_OK_AND_ASSIGN(const Function* foo, + module->GetMemberOrError("foo")); + XLS_ASSERT_OK_AND_ASSIGN(const ConstantDef* bar, + module->GetMemberOrError("BAR")); + const ParametricBinding* n = foo->parametric_bindings()[0]; + const Param* a = foo->params()[0]; + absl::flat_hash_map values; + XLS_ASSERT_OK_AND_ASSIGN( + values, + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return InterpValue::MakeU32(5); + })); + EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(5)))); +} + TEST_F(SolveForParametricsTest, SolveForUnWithUnWithExpr) { XLS_ASSERT_OK_AND_ASSIGN(auto module, Parse(R"( const X = u32:1; @@ -128,10 +153,12 @@ const BAR: uN[4 + X] = uN[4 + X]:1; const Param* a = foo->params()[0]; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( - values, SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return InterpValue::MakeU32(5); })); + values, + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return InterpValue::MakeU32(5); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(5)))); } @@ -149,10 +176,11 @@ const BAR: u4 = u4:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(4)))); } @@ -170,10 +198,11 @@ const BAR: sN[4] = sN[4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeS32(4)))); } @@ -191,10 +220,11 @@ const BAR: xN[false][4] = xN[false][4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeS32(4)))); } @@ -212,10 +242,11 @@ const BAR: bits[4] = bits[4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeS32(4)))); } @@ -233,10 +264,11 @@ const BAR: uN[4] = uN[4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeS32(4)))); } @@ -255,10 +287,11 @@ const BAR: xN[true][4] = xN[true][4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{s, n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{s, n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(s, InterpValue::MakeU32(1)), Pair(n, InterpValue::MakeS32(4)))); } @@ -277,10 +310,11 @@ const BAR: xN[true][4] = xN[true][4]:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeS32(4)))); } @@ -299,10 +333,11 @@ const BAR: s4 = s4:1; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{s, n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{s, n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(s, InterpValue::MakeBool(1)), Pair(n, InterpValue::MakeS32(4)))); } @@ -321,10 +356,11 @@ const BAR: u32[3] = [u32:0, u32:1, u32:3]; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(3)))); } @@ -343,10 +379,11 @@ const BAR: u32[33][34] = zero!(); absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{m, n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{m, n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(m, InterpValue::MakeU32(33)), Pair(n, InterpValue::MakeU32(34)))); } @@ -365,10 +402,11 @@ const BAR: u32[33][34] = zero!(); const ParametricBinding* n = foo->parametric_bindings()[0]; const Param* a = foo->params()[0]; EXPECT_THAT( - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); }), + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + }), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("u32:33 vs. u32:34"))); } @@ -391,11 +429,12 @@ const BAR: xN[false][24][33][34] = zero!(); const Param* a = foo->params()[0]; absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( - values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{m, n, s, w}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + values, SolveForParametrics( + bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{m, n, s, w}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(s, InterpValue::MakeU32(false)), Pair(w, InterpValue::MakeU32(24)), Pair(m, InterpValue::MakeU32(33)), @@ -417,10 +456,11 @@ const BAR: (u10, (s4, sN[10]))[20] = zero!<(u10, (s4, sN[10]))[20]>(); absl::flat_hash_map values; XLS_ASSERT_OK_AND_ASSIGN( values, - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n, x}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); })); + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n, x}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + })); EXPECT_THAT(values, UnorderedElementsAre(Pair(n, InterpValue::MakeU32(10)), Pair(x, InterpValue::MakeU32(20)))); } @@ -438,10 +478,11 @@ const BAR: (uN[5], bool) = zero!<(uN[5], bool)>(); const Param* a = foo->params()[0]; absl::flat_hash_map values; EXPECT_THAT( - SolveForParametrics( - bar->type_annotation(), a->type_annotation(), - absl::flat_hash_set{n}, - [&](const Expr* expr) { return EvaluateLiteral(expr, false, 32); }), + SolveForParametrics(bar->type_annotation(), a->type_annotation(), + absl::flat_hash_set{n}, + [&](const TypeAnnotation*, const Expr* expr) { + return EvaluateLiteral(expr, false, 32); + }), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Mismatch: (uN[5], bool) vs. u32[N]"))); } diff --git a/xls/dslx/type_system_v2/typecheck_module_v2.cc b/xls/dslx/type_system_v2/typecheck_module_v2.cc index 31a0580ae8..8f183483b8 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2.cc @@ -79,18 +79,15 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { absl::Status HandleConstantDef(const ConstantDef* node) override { VLOG(5) << "HandleConstantDef: " << node->ToString(); - XLS_ASSIGN_OR_RETURN( - const NameRef* variable, - table_.DefineInternalVariable(InferenceVariableKind::kType, - const_cast(node), - GenerateInternalTypeVariableName(node))); - XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node, variable)); - XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->name_def(), variable)); + XLS_ASSIGN_OR_RETURN(const NameRef* variable, + DefineTypeVariableForVariableOrConstant(node)); XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->value(), variable)); - if (node->type_annotation() != nullptr) { - XLS_RETURN_IF_ERROR( - table_.SetTypeAnnotation(node->name_def(), node->type_annotation())); - } + return DefaultHandler(node); + } + + absl::Status HandleParam(const Param* node) override { + VLOG(5) << "HandleParam: " << node->ToString(); + XLS_RETURN_IF_ERROR(DefineTypeVariableForVariableOrConstant(node).status()); return DefaultHandler(node); } @@ -339,6 +336,25 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { } private: + // Helper that creates an internal type variable for a `ConstantDef`, `Param`, + // or similar type of node that contains a `NameDef` and optional + // `TypeAnnotation`. + template + absl::StatusOr DefineTypeVariableForVariableOrConstant( + const T* node) { + XLS_ASSIGN_OR_RETURN(const NameRef* variable, + table_.DefineInternalVariable( + InferenceVariableKind::kType, const_cast(node), + GenerateInternalTypeVariableName(node))); + XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node, variable)); + XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->name_def(), variable)); + if (node->type_annotation() != nullptr) { + XLS_RETURN_IF_ERROR( + table_.SetTypeAnnotation(node->name_def(), node->type_annotation())); + } + return variable; + } + // Helper that handles invocation nodes calling free functions, i.e. functions // that do not require callee object type info to be looked up. If a // `parametric_invocation` is specified, it is for the invocation actually @@ -599,8 +615,13 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault { XLS_ASSIGN_OR_RETURN(AstNode * cloned, CloneAst(annotation)); annotation = dynamic_cast(cloned); CHECK(annotation != nullptr); - invocation_scoped_type_annotations_.emplace(annotation, - *parametric_invocation); + for (const AstNode* next : FlattenToSet(annotation)) { + if (const auto* next_annotation = + dynamic_cast(next)) { + invocation_scoped_type_annotations_.emplace(next_annotation, + *parametric_invocation); + } + } return annotation; } diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index 2500af49c5..1603349e93 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -804,7 +804,7 @@ const Y:u32 = foo(1); TEST(TypecheckV2Test, FunctionCallPassingInTooLargeAutoSizeFails) { EXPECT_THAT(R"( -fn foo(a: u4) -> u32 { a } +fn foo(a: u4) -> u4 { a } const Y = foo(32767); )", TypecheckFails(HasSizeMismatch("uN[15]", "u4"))); @@ -813,7 +813,7 @@ const Y = foo(32767); TEST(TypecheckV2Test, FunctionCallPassingInTooLargeExplicitIntegerSizeFails) { EXPECT_THAT(R"( const X:u32 = 1; -fn foo(a: u4) -> u32 { a } +fn foo(a: u4) -> u4 { a } const Y = foo(X); )", TypecheckFails(HasSizeMismatch("uN[32]", "u4"))); @@ -830,7 +830,7 @@ const Y = foo(X); TEST(TypecheckV2Test, FunctionCallPassingInArrayForIntegerFails) { EXPECT_THAT(R"( -fn foo(a: u4) -> u32 { a } +fn foo(a: u4) -> u4 { a } const Y = foo([u32:1, u32:2]); )", TypecheckFails(HasTypeMismatch("uN[32][2]", "u4"))); @@ -918,6 +918,28 @@ const Y = foo<11>(u11:5); HasSubstr("node: `const Y = foo<11>(u11:5);`, type: uN[11]"))); } +TEST(TypecheckV2Test, + ParametricFunctionTakingIntegerOfImplicitParameterizedSize) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a } +const X = foo(u10:5); +const Y = foo(u11:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + AllOf(HasSubstr("node: `const X = foo(u10:5);`, type: uN[10]"), + HasSubstr("node: `const Y = foo(u11:5);`, type: uN[11]"))); +} + +TEST(TypecheckV2Test, ParametricFunctionWithNonInferrableParametric) { + EXPECT_THAT(R"( +fn foo(a: uN[M]) -> uN[M] { a } +const X = foo(u10:5); +)", + TypecheckFails(HasSubstr("Could not infer parametric(s): N"))); +} + TEST(TypecheckV2Test, ParametricFunctionTakingIntegerOfParameterizedSignedness) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( @@ -933,6 +955,20 @@ const Y = foo(s32:5); HasSubstr("node: `const Y = foo(s32:5);`, type: sN[32]"))); } +TEST(TypecheckV2Test, + ParametricFunctionTakingIntegerOfImplicitParameterizedSignedness) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: xN[S][32]) -> xN[S][32] { a } +const X = foo(u32:5); +const Y = foo(s32:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + AllOf(HasSubstr("node: `const X = foo(u32:5);`, type: uN[32]"), + HasSubstr("node: `const Y = foo(s32:5);`, type: sN[32]"))); +} + TEST(TypecheckV2Test, ParametricFunctionTakingIntegerOfParameterizedSignednessAndSize) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( @@ -949,6 +985,20 @@ const Y = foo(s11:5); HasSubstr("node: `const Y = foo(s11:5);`, type: sN[11]"))); } +TEST(TypecheckV2Test, + ParametricFunctionTakingIntegerOfImplicitParameterizedSignednessAndSize) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: xN[S][N]) -> xN[S][N] { a } +const X = foo(u10:5); +const Y = foo(s11:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + AllOf(HasSubstr("node: `const X = foo(u10:5);`, type: uN[10]"), + HasSubstr("node: `const Y = foo(s11:5);`, type: sN[11]"))); +} + TEST(TypecheckV2Test, ParametricFunctionTakingIntegerOfDefaultParameterizedSize) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( @@ -985,16 +1035,41 @@ const X = foo<11>(u12:5); HasSubstr("node: `const X = foo<11>(u12:5);`, type: uN[12]")); } +TEST(TypecheckV2Test, ParametricFunctionWithDefaultImplicitlyOverriddenFails) { + // In a case like this, the "overridden" value for `N` must be explicit (v1 + // agrees). + EXPECT_THAT(R"( +fn foo(a: uN[N]) -> uN[N] { a } +const X = foo<11>(u20:5); +)", + TypecheckFails(HasSizeMismatch("uN[20]", "uN[12]"))); +} + TEST(TypecheckV2Test, - ParametricFunctionTakingIntegerWithOverriddenDependentDefaultParametric) { + ParametricFunctionWithDefaultDependingOnInferredParametric) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( -fn foo(a: uN[N]) -> uN[N] { a } -const X = foo<11, 20>(u20:5); +fn foo(a: uN[M]) -> uN[M] { a } +const X = foo(u10:5); )")); XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, TypeInfoToString(result.tm)); EXPECT_THAT(type_info_string, - HasSubstr("node: `const X = foo<11, 20>(u20:5);`, type: uN[20]")); + HasSubstr("node: `const X = foo(u10:5);`, type: uN[10]")); +} + +TEST(TypecheckV2Test, + ParametricFunctionWithInferredThenDefaultThenInferredParametric) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(x: uN[A], y: uN[C][B]) -> uN[A] { + x +} +const X = foo(u3:1, [u24:6, u24:7, u24:8, u24:9]); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const X = foo(u3:1, [u24:6, u24:7, u24:8, " + "u24:9]);`, type: uN[3]")); } TEST(TypecheckV2Test, @@ -1014,6 +1089,21 @@ const Z = foo<32>(X + Y + X + 50); HasSubstr("node: `const Z = foo<32>(X + Y + X + 50);`, type: uN[32]")); } +TEST(TypecheckV2Test, + ParametricFunctionTakingIntegerOfImplicitSignednessAndSizeWithSum) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +const X = u32:3; +const Y = u32:4; +fn foo(a: uN[N]) -> uN[N] { a } +const Z = foo(X + Y + X + 50); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT( + type_info_string, + HasSubstr("node: `const Z = foo(X + Y + X + 50);`, type: uN[32]")); +} + TEST(TypecheckV2Test, ParametricFunctionTakingArrayOfParameterizedSize) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( fn foo(a: u32[N]) -> u32[N] { a } @@ -1029,6 +1119,21 @@ const Y = foo<4>([1, 2, 3, 4]); "node: `const Y = foo<4>([1, 2, 3, 4]);`, type: uN[32][4]"))); } +TEST(TypecheckV2Test, ParametricFunctionTakingArrayOfImplicitSize) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: u32[N]) -> u32[N] { a } +const X = foo([1, 2, 3]); +const Y = foo([1, 2, 3, 4]); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT( + type_info_string, + AllOf( + HasSubstr("node: `const X = foo([1, 2, 3]);`, type: uN[32][3]"), + HasSubstr("node: `const Y = foo([1, 2, 3, 4]);`, type: uN[32][4]"))); +} + TEST(TypecheckV2Test, ParametricFunctionWithArgumentMismatchingParameterizedSizeFails) { EXPECT_THAT(R"( @@ -1067,6 +1172,30 @@ const X = foo<24, 23>(4); HasSubstr("node: `const X = foo<24, 23>(4);`, type: uN[23]")); } +TEST(TypecheckV2Test, ParametricFunctionImplicitParameterPropagation) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn bar(a: uN[A], b: uN[B]) -> uN[A] { a + 1 } +fn foo(a: uN[A], b: uN[B]) -> uN[B] { bar(b, a) } +const X = foo(u23:4, u17:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const X = foo(u23:4, u17:5);`, type: uN[17]")); +} + +TEST(TypecheckV2Test, ParametricFunctionImplicitParameterExplicitPropagation) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn bar(a: uN[A], b: uN[B]) -> uN[A] { a + 1 } +fn foo(a: uN[A], b: uN[B]) -> uN[B] { bar(b, a) } +const X = foo(u23:4, u17:5); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const X = foo(u23:4, u17:5);`, type: uN[17]")); +} + TEST(TypecheckV2Test, ParametricFunctionInvocationNesting) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( fn foo(a: uN[N]) -> uN[N] { a + 1 } @@ -1080,6 +1209,46 @@ const X = foo<24>(foo<24>(4) + foo<24>(5)); "node: `const X = foo<24>(foo<24>(4) + foo<24>(5));`, type: uN[24]")); } +TEST(TypecheckV2Test, ParametricFunctionImplicitInvocationNesting) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a + 1 } +const X = foo(foo(u24:4) + foo(u24:5)); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT( + type_info_string, + HasSubstr( + "node: `const X = foo(foo(u24:4) + foo(u24:5));`, type: uN[24]")); +} + +TEST(TypecheckV2Test, + ParametricFunctionImplicitInvocationNestingWithExplicitOuter) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a + 1 } +const X = foo<24>(foo(u24:4 + foo(u24:6)) + foo(u24:5)); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const X = foo<24>(foo(u24:4 + foo(u24:6)) + " + "foo(u24:5));`, type: uN[24]")); +} + +TEST(TypecheckV2Test, + ParametricFunctionImplicitInvocationNestingWithExplicitInner) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a + 1 } +const X = foo(foo<24>(4) + foo<24>(5)); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT( + type_info_string, + HasSubstr( + "node: `const X = foo(foo<24>(4) + foo<24>(5));`, type: uN[24]")); +} + TEST(TypecheckV2Test, ParametricFunctionUsingGlobalConstantInParametricDefault) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( @@ -1106,6 +1275,19 @@ const Z = foo(u3:1); HasSubstr("node: `const Z = foo(u3:1);`, type: uN[3]")); } +TEST(TypecheckV2Test, + ParametricFunctionCallUsingGlobalConstantInImplicitParametricArgument) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a } +const X = u3:1; +const Z = foo(X); +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const Z = foo(X);`, type: uN[3]")); +} + TEST(TypecheckV2Test, ParametricFunctionCallFollowedByTypePropagation) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( fn foo(a: uN[N]) -> uN[N] { a } @@ -1118,6 +1300,19 @@ const Z = Y + 1; HasSubstr("node: `const Z = Y + 1;`, type: uN[15]")); } +TEST(TypecheckV2Test, + ParametricFunctionCallWithImplicitParameterFollowedByTypePropagation) { + XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( +fn foo(a: uN[N]) -> uN[N] { a } +const Y = foo(u15:1); +const Z = Y + 1; +)")); + XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string, + TypeInfoToString(result.tm)); + EXPECT_THAT(type_info_string, + HasSubstr("node: `const Z = Y + 1;`, type: uN[15]")); +} + TEST(TypecheckV2Test, GlobalConstantUsingParametricFunction) { XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"( fn foo(a: uN[N]) -> uN[N] { a }