Skip to content

Commit

Permalink
Change function type param representation to include parent generic p…
Browse files Browse the repository at this point in the history
…arams explicitly
  • Loading branch information
Y-Nak committed Feb 20, 2024
1 parent d5c407f commit e0346f8
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 96 deletions.
20 changes: 10 additions & 10 deletions crates/hir-analysis/src/ty/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,10 @@ impl<'db> SuperTraitCollector<'db> {
fn collect(mut self) -> BTreeSet<TraitInstId> {
let hir_trait = self.trait_.trait_(self.db);
let hir_db = self.db.as_hir_db();
let self_param = self.trait_.self_param(self.db);

for &super_ in hir_trait.super_traits(hir_db).iter() {
if let Ok(inst) = lower_trait_ref(self.db, super_, self.scope) {
if let Ok(inst) = lower_trait_ref(self.db, self_param, super_, self.scope) {
self.super_traits.insert(inst);
}
}
Expand All @@ -387,7 +389,7 @@ impl<'db> SuperTraitCollector<'db> {
{
for bound in &pred.bounds {
if let TypeBound::Trait(bound) = bound {
if let Ok(inst) = lower_trait_ref(self.db, *bound, self.scope) {
if let Ok(inst) = lower_trait_ref(self.db, self_param, *bound, self.scope) {
self.super_traits.insert(inst);
}
}
Expand Down Expand Up @@ -496,17 +498,14 @@ impl<'db> ConstraintCollector<'db> {

fn collect_constraints_from_generic_params(&mut self) {
let param_set = collect_generic_params(self.db, self.owner);
let params_list = self.owner.params(self.db);
assert!(param_set.params(self.db).len() == params_list.len(self.db.as_hir_db()));
for (&ty, hir_param) in param_set
.params(self.db)
.iter()
.zip(params_list.data(self.db.as_hir_db()))
{
let param_list = self.owner.params(self.db);

for (i, hir_param) in param_list.data(self.db.as_hir_db()).iter().enumerate() {
let GenericParam::Type(hir_param) = hir_param else {
continue;
};

let ty = param_set.param_by_original_idx(self.db, i).unwrap();
let bounds = &hir_param.bounds;
self.add_bounds(ty, bounds)
}
Expand All @@ -518,7 +517,8 @@ impl<'db> ConstraintCollector<'db> {
continue;
};

let Ok(trait_inst) = lower_trait_ref(self.db, *trait_ref, self.owner.scope(self.db))
let Ok(trait_inst) =
lower_trait_ref(self.db, bound_ty, *trait_ref, self.owner.scope(self.db))
else {
continue;
};
Expand Down
123 changes: 87 additions & 36 deletions crates/hir-analysis/src/ty/def_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use std::collections::{hash_map::Entry, BTreeSet};

use hir::{
hir_def::{
scope_graph::ScopeId, FieldDef, Func, FuncParamListId, GenericParam, IdentId,
Impl as HirImpl, ImplTrait, ItemKind, PathId, Trait, TraitRefId, TypeAlias,
scope_graph::ScopeId, FieldDef, Func, FuncParamListId, GenericParam, GenericParamListId,
IdentId, Impl as HirImpl, ImplTrait, ItemKind, PathId, Trait, TraitRefId, TypeAlias,
TypeId as HirTyId, VariantKind,
},
visitor::prelude::*,
Expand Down Expand Up @@ -274,7 +274,7 @@ impl<'db> DefAnalyzer<'db> {
/// This method verifies if
/// 1. the given `ty` has `*` kind(i.e, concrete type)
/// 2. the given `ty` is not const type
/// TODo: This method is a stop-gap implementation until we design a true
/// TODO: This method is a stop-gap implementation until we design a true
/// const type system.
fn verify_normal_star_type(&mut self, ty: HirTyId, span: DynLazySpan) -> bool {
let ty = lower_hir_ty(self.db, ty, self.scope());
Expand All @@ -291,6 +291,49 @@ impl<'db> DefAnalyzer<'db> {
}
}

// Check if the same generic parameter is already defined in the parent item.
// Other name conflict check is done in the name resolution.
//
// This check is necessary because the conflict rule
// for the generic parameter is the exceptional case where shadowing shouldn't
// occur.
fn verify_method_generic_param_conflict(
&mut self,
params: GenericParamListId,
span: LazyGenericParamListSpan,
) -> bool {
let mut is_conflict = false;
for (i, param) in params.data(self.db.as_hir_db()).iter().enumerate() {
if let Some(name) = param.name().to_opt() {
let scope = self.scope();
let parent_scope = scope.parent_item(self.db.as_hir_db()).unwrap().scope();
let path = PathId::from_ident(self.db.as_hir_db(), name);
if let EarlyResolvedPath::Full(bucket) =
resolve_path_early(self.db, path, parent_scope)
{
if let Ok(res) = bucket.pick(NameDomain::Type) {
if let NameResKind::Scope(conflict_with @ ScopeId::GenericParam(..)) =
res.kind
{
self.diags.push(
TyLowerDiag::generic_param_conflict(
span.param(i).into(),
conflict_with.name_span(self.db.as_hir_db()).unwrap(),
name,
)
.into(),
);

is_conflict = true;
}
}
}
}
}

!is_conflict
}

fn verify_self_type(&mut self, self_ty: HirTyId, span: DynLazySpan) -> bool {
let expected_ty = self.self_ty.unwrap();

Expand Down Expand Up @@ -513,12 +556,6 @@ impl<'db> Visitor for DefAnalyzer<'db> {
unreachable!()
};

// Check if the same generic parameter is already defined in the parent item.
// Other name conflict check is done in the name resolution.
//
// This check is necessary because the conflict rule
// for the generic parameter is the exceptional case where shadowing shouldn't
// occur.
if let Some(name) = param.name().to_opt() {
let scope = self.scope();
let parent_scope = scope.parent_item(self.db.as_hir_db()).unwrap().scope();
Expand Down Expand Up @@ -546,13 +583,13 @@ impl<'db> Visitor for DefAnalyzer<'db> {
match param {
GenericParam::Type(_) => {
self.current_ty = Some((
self.def.params(self.db)[idx],
self.def.original_params(self.db)[idx],
ctxt.span().unwrap().into_type_param().name().into(),
));
walk_generic_param(self, ctxt, param)
}
GenericParam::Const(_) => {
let ty = self.def.params(self.db)[idx];
let ty = self.def.original_params(self.db)[idx];
let Some(const_ty_param) = ty.const_ty_param(self.db) else {
return;
};
Expand Down Expand Up @@ -596,14 +633,17 @@ impl<'db> Visitor for DefAnalyzer<'db> {
ctxt: &mut VisitorCtxt<'_, LazyTraitRefSpan>,
trait_ref: TraitRefId,
) {
if self
let current_ty = self
.current_ty
.as_ref()
.map(|(ty, _)| ty.is_trait_self(self.db))
.unwrap_or_default()
{
.map(|(ty, _)| *ty)
.unwrap_or(TyId::invalid(self.db, InvalidCause::Other));

if current_ty.is_trait_self(self.db) {
if let Some(cycle) = self.def.collect_super_trait_cycle(self.db) {
if let Ok(trait_inst) = lower_trait_ref(self.db, trait_ref, self.scope()) {
if let Ok(trait_inst) =
lower_trait_ref(self.db, current_ty, trait_ref, self.scope())
{
if cycle.contains(trait_inst.def(self.db)) {
self.diags.push(
TraitLowerDiag::CyclicSuperTraits(ctxt.span().unwrap().path().into())
Expand All @@ -617,7 +657,7 @@ impl<'db> Visitor for DefAnalyzer<'db> {

if let (Some((ty, span)), Ok(trait_inst)) = (
&self.current_ty,
lower_trait_ref(self.db, trait_ref, self.scope()),
lower_trait_ref(self.db, current_ty, trait_ref, self.scope()),
) {
let expected_kind = trait_inst.def(self.db).expected_implementor_kind(self.db);
if !expected_kind.does_match(ty.kind(self.db)) {
Expand All @@ -630,6 +670,7 @@ impl<'db> Visitor for DefAnalyzer<'db> {

if let Some(diag) = analyze_trait_ref(
self.db,
current_ty,
trait_ref,
self.scope(),
Some(self.assumptions),
Expand Down Expand Up @@ -687,6 +728,13 @@ impl<'db> Visitor for DefAnalyzer<'db> {
return;
}

if !self.verify_method_generic_param_conflict(
hir_func.generic_params(self.db.as_hir_db()),
hir_func.lazy_span().generic_params_moved(),
) {
return;
}

let def = std::mem::replace(&mut self.def, func.into());
let constraints = std::mem::replace(&mut self.assumptions, func.constraints(self.db));

Expand Down Expand Up @@ -836,12 +884,13 @@ impl TyId {

fn analyze_trait_ref(
db: &dyn HirAnalysisDb,
self_ty: TyId,
trait_ref: TraitRefId,
scope: ScopeId,
assumptions: Option<AssumptionListId>,
span: DynLazySpan,
) -> Option<TyDiagCollection> {
let trait_inst = match lower_trait_ref(db, trait_ref, scope) {
let trait_inst = match lower_trait_ref(db, self_ty, trait_ref, scope) {
Ok(trait_ref) => trait_ref,

Err(TraitRefLowerError::ArgNumMismatch { expected, given }) => {
Expand Down Expand Up @@ -894,15 +943,15 @@ enum DefKind {
}

impl DefKind {
fn params(self, db: &dyn HirAnalysisDb) -> &[TyId] {
fn original_params(self, db: &dyn HirAnalysisDb) -> &[TyId] {
match self {
Self::Adt(def) => def.params(db),
Self::Trait(def) => def.params(db),
Self::ImplTrait(def) => def.params(db),
Self::Adt(def) => def.original_params(db),
Self::Trait(def) => def.original_params(db),
Self::ImplTrait(def) => def.original_params(db),
Self::Impl(hir_impl, _) => {
collect_generic_params(db, GenericParamOwnerId::new(db, hir_impl.into())).params(db)
}
Self::Func(def) => def.params(db),
Self::Func(def) => def.original_params(db),
}
}

Expand Down Expand Up @@ -955,9 +1004,16 @@ fn analyze_impl_trait_specific_error(
return Err(diags);
};

// 1. Checks if the trait ref is well-formed except for the satisfiability.
// 1. Checks if implementor type is well-formed except for the satisfiability.
let ty = lower_hir_ty(db, ty, impl_trait.scope());
if let Some(diag) = ty.emit_diag(db, impl_trait.lazy_span().ty().into()) {
diags.push(diag);
}

// 2. Checks if the trait ref is well-formed except for the satisfiability.
if let Some(diag) = analyze_trait_ref(
db,
ty,
trait_ref,
impl_trait.scope(),
None,
Expand All @@ -966,25 +1022,19 @@ fn analyze_impl_trait_specific_error(
diags.push(diag);
}

// 2. Checks if implementor type is well-formed except for the satisfiability.
let ty = lower_hir_ty(db, ty, impl_trait.scope());
if let Some(diag) = ty.emit_diag(db, impl_trait.lazy_span().ty().into()) {
diags.push(diag);
}

// If there is any error at the point, it means that `Implementor` is not
// well-formed and no more analysis is needed to reduce the amount of error
// messages.
if !diags.is_empty() || ty.contains_invalid(db) {
return Err(diags);
}

let trait_inst = match lower_trait_ref(db, trait_ref, impl_trait.scope()) {
let trait_inst = match lower_trait_ref(db, ty, trait_ref, impl_trait.scope()) {
Ok(trait_inst) => trait_inst,
Err(_) => return Err(vec![]),
};

// 3. Check if the ingot contains impl trait is the same as the ingot which
// 3. Check if the ingot containing impl trait is the same as the ingot which
// contains either the type or trait.
let impl_trait_ingot = impl_trait.top_mod(hir_db).ingot(hir_db);
if Some(impl_trait_ingot) != ty.ingot(db) && impl_trait_ingot != trait_inst.def(db).ingot(db) {
Expand Down Expand Up @@ -1167,8 +1217,8 @@ impl<'db> ImplTraitMethodAnalyzer<'db> {
let hir_expected_method = expected_method.0.hir_func(self.db);

// Checks if the number of parameters are the same.
let method_params = impl_method.params(self.db);
let expected_params = expected_method.0.params(self.db);
let method_params = impl_method.original_params(self.db);
let expected_params = expected_method.0.original_params(self.db);
if method_params.len() != expected_params.len() {
self.diags.push(
ImplDiag::method_param_num_mismatch(
Expand All @@ -1185,7 +1235,7 @@ impl<'db> ImplTraitMethodAnalyzer<'db> {
return;
}

// Checks if the parameter kinds are the same.
// Checks if the generic parameter kinds are the same.
for (idx, (&expected_param, &method_param)) in
expected_params.iter().zip(method_params).enumerate()
{
Expand Down Expand Up @@ -1318,7 +1368,7 @@ impl<'db> ImplTraitMethodAnalyzer<'db> {
}

// Check if the method constraints are stricter than the trait constraints.
// This check can be performed to check if the `impl_method` constraints are
// This check is performed by checking if the `impl_method` constraints are
// satisfied under the assumptions that is obtained from the `expected_method`
// constraints.
let expected_constraints = expected_method
Expand All @@ -1337,6 +1387,7 @@ impl<'db> ImplTraitMethodAnalyzer<'db> {
}

if !unsatisfied_goals.is_empty() {
unsatisfied_goals.sort_by_key(|goal| goal.ty(self.db).pretty_print(self.db));
self.diags.push(
ImplDiag::method_stricter_bound(
self.db,
Expand Down
17 changes: 17 additions & 0 deletions crates/hir-analysis/src/ty/trait_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ impl Implementor {
self.trait_(db).def(db)
}

pub(crate) fn original_params(self, db: &dyn HirAnalysisDb) -> &[TyId] {
self.params(db)
}

/// Generalizes the implementor by replacing all type parameters with fresh
/// type variables.
pub(super) fn generalize(
Expand All @@ -152,6 +156,7 @@ impl Implementor {
let hir_impl = self.hir_impl_trait(db);
let trait_ = self.trait_(db).apply_subst(db, &mut subst);
let ty = self.ty(db).apply_subst(db, &mut subst);

let params = self
.params(db)
.iter()
Expand Down Expand Up @@ -208,6 +213,9 @@ impl TraitInstId {
let mut s = self.def(db).name(db).unwrap_or("<unknown>").to_string();

let mut args = self.substs(db).iter().map(|ty| ty.pretty_print(db));
// Skip the first type parameter since it's the implementor type.
args.next();

if let Some(first) = args.next() {
s.push('<');
s.push_str(first);
Expand Down Expand Up @@ -240,7 +248,11 @@ impl TraitInstId {
)
}

/// Returns subst from the trait definition parameter to this instantiated
/// parameters.
pub(super) fn subst_table(self, db: &dyn HirAnalysisDb) -> FxHashMap<TyId, TyId> {
assert!(self.def(db).params(db).len() == self.substs(db).len());

let mut table = FxHashMap::default();
for (from, to) in self.def(db).params(db).iter().zip(self.substs(db)) {
table.insert(*from, *to);
Expand Down Expand Up @@ -279,6 +291,7 @@ impl TraitInstId {
#[salsa::tracked]
pub struct TraitDef {
pub trait_: Trait,
#[return_ref]
pub(crate) param_set: GenericParamTypeSet,
#[return_ref]
pub methods: BTreeMap<IdentId, TraitMethod>,
Expand All @@ -292,6 +305,10 @@ impl TraitDef {
pub fn self_param(self, db: &dyn HirAnalysisDb) -> TyId {
self.param_set(db).trait_self(db).unwrap()
}

pub fn original_params(self, db: &dyn HirAnalysisDb) -> &[TyId] {
self.param_set(db).original_params(db)
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
Expand Down
Loading

0 comments on commit e0346f8

Please sign in to comment.