From 0721a5e125e87fcafbaad6fc7904f6fc8995ba28 Mon Sep 17 00:00:00 2001 From: Jacques-Henri Jourdan Date: Thu, 16 Jan 2025 15:02:24 +0100 Subject: [PATCH] Attributes: use .value_str() instead of .symbol, better error messages. value_str() returns the unescaped string, while the string in .symbol may be escaped. --- creusot/src/contracts_items/attributes.rs | 83 +++++++++-------------- creusot/src/ctx.rs | 14 ++-- 2 files changed, 40 insertions(+), 57 deletions(-) diff --git a/creusot/src/contracts_items/attributes.rs b/creusot/src/contracts_items/attributes.rs index 161eb71f8a..60cbba7796 100644 --- a/creusot/src/contracts_items/attributes.rs +++ b/creusot/src/contracts_items/attributes.rs @@ -1,6 +1,6 @@ //! Defines all the internal creusot attributes. -use rustc_ast::{AttrArgs, AttrArgsEq, Attribute, Param}; +use rustc_ast::{AttrArgs, Attribute, Param}; use rustc_hir::def_id::DefId; use rustc_middle::ty::TyCtxt; use rustc_span::Symbol; @@ -27,7 +27,7 @@ macro_rules! attribute_functions { #[doc = concat!("Detect if `def_id` has the attribute `", stringify!($($p)*), "`")] pub(crate) fn $fn_name(tcx: TyCtxt, def_id: DefId) -> bool { let path = &path_to_str!($($p)*); - let has_attr = get_attr(tcx.get_attrs_unchecked(def_id), path).is_some(); + let has_attr = get_attr(tcx, tcx.get_attrs_unchecked(def_id), path).is_some(); attribute_functions!(@negate $($not)? has_attr) } )+ @@ -60,12 +60,8 @@ attribute_functions! { } pub fn get_invariant_expl(tcx: TyCtxt, def_id: DefId) -> Option { - get_attr(tcx.get_attrs_unchecked(def_id), &["creusot", "spec", "invariant"]).map(|a| { - match a.args { - AttrArgs::Eq(_, AttrArgsEq::Hir(ref expl)) => expl.symbol.to_string(), - _ => "expl:loop invariant".to_string(), - } - }) + get_attr(tcx, tcx.get_attrs_unchecked(def_id), &["creusot", "spec", "invariant"]) + .map(|a| a.value_str().map_or("expl:loop invariant".to_string(), |s| s.to_string())) } pub(crate) fn no_mir(tcx: TyCtxt, def_id: DefId) -> bool { @@ -82,21 +78,19 @@ pub(crate) fn is_pearlite(tcx: TyCtxt, def_id: DefId) -> bool { /// Get the string on the right of `creusot::builtin = ...` pub(crate) fn get_builtin(tcx: TyCtxt, def_id: DefId) -> Option { - get_attr(tcx.get_attrs_unchecked(def_id), &["creusot", "builtins"]).and_then(|a| { - match &a.args { - AttrArgs::Eq(_, AttrArgsEq::Hir(l)) => Some(l.symbol), - _ => None, - } + get_attr(tcx, tcx.get_attrs_unchecked(def_id), &["creusot", "builtins"]).map(|a| { + a.value_str().unwrap_or_else(|| { + tcx.dcx().span_fatal( + a.span, + "Attribute `creusot::builtin` should be followed by a string.".to_string(), + ) + }) }) } pub(crate) fn opacity_witness_name(tcx: TyCtxt, def_id: DefId) -> Option { - get_attr(tcx.get_attrs_unchecked(def_id), &["creusot", "clause", "open"]).and_then(|item| { - match &item.args { - AttrArgs::Eq(_, AttrArgsEq::Hir(l)) => Some(l.symbol), - _ => None, - } - }) + get_attr(tcx, tcx.get_attrs_unchecked(def_id), &["creusot", "clause", "open"]) + .map(|a| a.value_str().expect("invalid creusot::clause::open")) } pub(crate) fn why3_attrs(tcx: TyCtxt, def_id: DefId) -> Vec { @@ -122,39 +116,15 @@ pub(crate) fn creusot_clause_attrs<'tcx>( } pub(crate) fn get_creusot_item(tcx: TyCtxt, def_id: DefId) -> Option { - match &get_attr(tcx.get_attrs_unchecked(def_id), &["creusot", "item"])?.args { - AttrArgs::Eq(_, AttrArgsEq::Hir(l)) => Some(l.symbol), - _ => unreachable!("invalid creusot::item attribute"), - } -} - -pub(crate) fn is_open_inv_param(p: &Param) -> bool { - return get_attr(&p.attrs, &["creusot", "open_inv"]).is_some(); + Some( + get_attr(tcx, tcx.get_attrs_unchecked(def_id), &["creusot", "item"])? + .value_str() + .expect("invalid creusot::item attribute"), + ) } -fn get_attr<'a>( - attrs: &'a [rustc_ast::Attribute], - path: &[&str], -) -> Option<&'a rustc_ast::AttrItem> { - for attr in attrs.iter() { - if attr.is_doc_comment() { - continue; - } - - let attr = attr.get_normal_item(); - - if attr.path.segments.len() != path.len() { - continue; - } - - let matches = - attr.path.segments.iter().zip(path.iter()).all(|(seg, s)| seg.ident.as_str() == *s); - - if matches { - return Some(attr); - } - } - None +pub(crate) fn is_open_inv_param<'tcx>(tcx: TyCtxt<'tcx>, p: &Param) -> bool { + return get_attr(tcx, &p.attrs, &["creusot", "open_inv"]).is_some(); } fn get_attrs<'a>(attrs: &'a [Attribute], path: &[&str]) -> Vec<&'a Attribute> { @@ -180,3 +150,16 @@ fn get_attrs<'a>(attrs: &'a [Attribute], path: &[&str]) -> Vec<&'a Attribute> { } matched } + +fn get_attr<'a, 'tcx>( + tcx: TyCtxt<'tcx>, + attrs: &'a [rustc_ast::Attribute], + path: &[&str], +) -> Option<&'a Attribute> { + let matched = get_attrs(attrs, path); + match matched.len() { + 0 => return None, + 1 => return Some(matched[0]), + _ => tcx.dcx().span_fatal(matched[0].span, "Unexpected duplicate attribute.".to_string()), + } +} diff --git a/creusot/src/ctx.rs b/creusot/src/ctx.rs index 83e56d5777..4c20f36c66 100644 --- a/creusot/src/ctx.rs +++ b/creusot/src/ctx.rs @@ -174,8 +174,8 @@ impl<'tcx> Deref for TranslationCtx<'tcx> { } fn gather_params_open_inv(tcx: TyCtxt) -> HashMap> { - struct VisitFns<'a>(HashMap>, &'a ResolverAstLowering); - impl<'a> Visitor<'a> for VisitFns<'a> { + struct VisitFns<'tcx, 'a>(TyCtxt<'tcx>, HashMap>, &'a ResolverAstLowering); + impl<'tcx, 'a> Visitor<'a> for VisitFns<'tcx, 'a> { fn visit_fn(&mut self, fk: FnKind<'a>, _: Span, node: NodeId) { let decl = match fk { FnKind::Fn(_, _, FnSig { decl, .. }, _, _, _) => decl, @@ -183,21 +183,21 @@ fn gather_params_open_inv(tcx: TyCtxt) -> HashMap> { }; let mut open_inv_params = vec![]; for (i, p) in decl.inputs.iter().enumerate() { - if is_open_inv_param(p) { + if is_open_inv_param(self.0, p) { open_inv_params.push(i); } } - let defid = self.1.node_id_to_def_id[&node].to_def_id(); - assert!(self.0.insert(defid, open_inv_params).is_none()); + let defid = self.2.node_id_to_def_id[&node].to_def_id(); + assert!(self.1.insert(defid, open_inv_params).is_none()); walk_fn(self, fk) } } let (resolver, cr) = &*tcx.resolver_for_lowering().borrow(); - let mut visit = VisitFns(HashMap::new(), resolver); + let mut visit = VisitFns(tcx, HashMap::new(), resolver); visit.visit_crate(cr); - visit.0 + visit.1 } impl<'tcx> TranslationCtx<'tcx> {