From b3c8e30b688375870e7c997250efddcbc72be492 Mon Sep 17 00:00:00 2001 From: Yann Hamdaoui Date: Fri, 2 Aug 2024 08:36:42 +0200 Subject: [PATCH] [Optimization] Cache contract generation and pre-compile some match expression (#2013) * Pre-generate contracts After the introduction of a proper node for types in the AST, we switched to a lazy way of generating contract: we keep the type as it is, and convert it to a contract only once it's actually applied. While simpler, this approach has the drawback of potentially wasting computations by running the contract generation code many times for the same contract, as some metrics showed (up to 1.6 thousand times on large codebases). This commit add a `contract` field to `Term::Type`, and statically generates - in the parser - the contract corresponding to a type. This is what we used to do before the introduction of a type node. Doing so, we only generate at most one contract per user annotation, regardless of how many times the contract is applied at runtime. * Pre-compile match in enum contract This commit saves some re-compilation of the match expression generated by contracts converted from enum types by generating the compiled version directly instead of a match. * Fix unused import warnings --- core/src/closurize.rs | 5 +-- core/src/eval/mod.rs | 2 +- core/src/eval/operation.rs | 28 ++++++++++++----- core/src/parser/uniterm.rs | 22 +++++++++----- core/src/pretty.rs | 2 +- core/src/term/mod.rs | 54 +++++++++++++++++++++++---------- core/src/transform/free_vars.rs | 5 +-- core/src/typ.rs | 9 ++++-- core/src/typecheck/eq.rs | 14 ++++++++- core/src/typecheck/mod.rs | 6 ++-- lsp/nls/src/field_walker.rs | 2 +- 11 files changed, 105 insertions(+), 44 deletions(-) diff --git a/core/src/closurize.rs b/core/src/closurize.rs index 62f7ef7f53..7eec2d420b 100644 --- a/core/src/closurize.rs +++ b/core/src/closurize.rs @@ -289,9 +289,10 @@ pub fn should_share(t: &Term) -> bool { | Term::Var(_) | Term::Enum(_) | Term::Fun(_, _) + | Term::Closure(_) + | Term::Type { .. } // match acts like a function, and is a WHNF - | Term::Match {..} - | Term::Type(_) => false, + | Term::Match {..} => false, _ => true, } } diff --git a/core/src/eval/mod.rs b/core/src/eval/mod.rs index cd0c4dd351..93c0da0441 100644 --- a/core/src/eval/mod.rs +++ b/core/src/eval/mod.rs @@ -1165,7 +1165,7 @@ pub fn subst( // We could recurse here, because types can contain terms which would then be subject to // substitution. Not recursing should be fine, though, because a type in term position // turns into a contract, and we don't substitute inside contracts either currently. - | v @ Term::Type(_) => RichTerm::new(v, pos), + | v @ Term::Type {..} => RichTerm::new(v, pos), Term::EnumVariant { tag, arg, attrs } => { let arg = subst(cache, arg, initial_env, env); diff --git a/core/src/eval/operation.rs b/core/src/eval/operation.rs index 6c8a054fd8..d6a0fa1168 100644 --- a/core/src/eval/operation.rs +++ b/core/src/eval/operation.rs @@ -247,7 +247,7 @@ impl VirtualMachine { Term::Array(..) => "Array", Term::Record(..) | Term::RecRecord(..) => "Record", Term::Lbl(..) => "Label", - Term::Type(_) => "Type", + Term::Type { .. } => "Type", Term::ForeignId(_) => "ForeignId", _ => "Other", }; @@ -1736,15 +1736,25 @@ impl VirtualMachine { } } BinaryOp::ContractApply | BinaryOp::ContractCheck => { - // The translation of a type might return any kind of contract, including e.g. a - // record or a custom contract. The result thus needs to be passed to `b_op` again. - // In that case, we don't bother tracking the argument and updating the label: this - // will be done by the next call to `b_op`. - if let Term::Type(typ) = &*t1 { + // Doing just one `if let Term::Type` and putting the call to `increment!` there + // looks sensible at first, but it's annoying to explain to rustc and clippy that + // we match on `typ` but use it only if the `metrics` feature is enabled (we get + // unused variable warning otherwise). It's simpler to just make a separate `if` + // conditionally included. + #[cfg(feature = "metrics")] + if let Term::Type { typ, .. } = &*t1 { increment!(format!( "primop:contract/apply:{}", typ.pretty_print_cap(40) )); + } + + let t1 = if let Term::Type { typ: _, contract } = &*t1 { + // The contract generation from a static type might return any kind of + // contract, including e.g. a record or a custom contract. The result needs to + // be evaluated first, and then passed to `b_op` again. In that case, we don't + // bother tracking the argument and updating the label: this will be done by + // the next call to `b_op`. // We set the stack to represent the evaluation context ` [.] label` and // proceed to evaluate `` @@ -1765,10 +1775,12 @@ impl VirtualMachine { ); return Ok(Closure { - body: typ.contract()?, + body: contract.clone(), env: env1, }); - } + } else { + t1 + }; let t2 = t2.into_owned(); diff --git a/core/src/parser/uniterm.rs b/core/src/parser/uniterm.rs index 07b1bb8c38..bcef7fc933 100644 --- a/core/src/parser/uniterm.rs +++ b/core/src/parser/uniterm.rs @@ -126,12 +126,16 @@ impl TryFrom for RichTerm { let rt = match node { UniTermNode::Var(id) => RichTerm::new(Term::Var(id), pos), UniTermNode::Record(r) => RichTerm::try_from(r)?, - UniTermNode::Type(mut ty) => { - ty.fix_type_vars(pos.unwrap())?; - if let TypeF::Flat(rt) = ty.typ { + UniTermNode::Type(mut typ) => { + typ.fix_type_vars(pos.unwrap())?; + if let TypeF::Flat(rt) = typ.typ { rt.with_pos(pos) } else { - RichTerm::new(Term::Type(ty), pos) + let contract = typ + .contract() + .map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?; + + RichTerm::new(Term::Type { typ, contract }, pos) } } UniTermNode::Term(rt) => rt, @@ -478,7 +482,7 @@ impl TryFrom for RichTerm { let result = if ur.tail.is_some() || (ur.is_record_type() && !ur.fields.is_empty()) { let tail_span = ur.tail.as_ref().and_then(|t| t.1.into_opt()); // We unwrap all positions: at this stage of the parsing, they must all be set - let mut ty = ur + let mut typ = ur .into_type_strict() .map_err(|cause| ParseError::InvalidRecordType { tail_span, @@ -486,8 +490,12 @@ impl TryFrom for RichTerm { cause, })?; - ty.fix_type_vars(pos.unwrap())?; - Ok(RichTerm::new(Term::Type(ty), pos)) + typ.fix_type_vars(pos.unwrap())?; + let contract = typ + .contract() + .map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?; + + Ok(RichTerm::new(Term::Type { typ, contract }, pos)) } else { ur.check_typed_field_without_def()?; diff --git a/core/src/pretty.rs b/core/src/pretty.rs index 58ac14f99d..67d7820f03 100644 --- a/core/src/pretty.rs +++ b/core/src/pretty.rs @@ -1038,7 +1038,7 @@ where .append(allocator.as_string(f.to_string_lossy()).double_quotes()), ResolvedImport(id) => allocator.text(format!("import ")), // This type is in term position, so we don't need to add parentheses. - Type(ty) => ty.pretty(allocator), + Type { typ, contract: _ } => typ.pretty(allocator), ParseError(_) => allocator.text("%"), RuntimeError(_) => allocator.text("%"), Closure(idx) => allocator.text(format!("%")), diff --git a/core/src/term/mod.rs b/core/src/term/mod.rs index acf87974a7..3e50f8fbb7 100644 --- a/core/src/term/mod.rs +++ b/core/src/term/mod.rs @@ -21,8 +21,7 @@ use string::NickelString; use crate::{ error::{EvalError, ParseError}, - eval::cache::CacheIndex, - eval::Environment, + eval::{cache::CacheIndex, Environment}, identifier::LocIdent, impl_display_from_pretty, label::{Label, MergeLabel}, @@ -217,7 +216,14 @@ pub enum Term { /// /// During evaluation, this will get turned into a contract. #[serde(skip)] - Type(Type), + Type { + /// The static type. + typ: Type, + /// The conversion of this type to a contract, that is, `typ.contract()?`. This field + /// serves as a caching mechanism so we only run the contract generation code once per type + /// written by the user. + contract: RichTerm, + }, /// A custom contract. The content must be a function (or function-like terms like a match /// expression) of two arguments: a label and the value to be checked. In particular, it must @@ -362,7 +368,16 @@ impl PartialEq for Term { (Self::Annotated(l0, l1), Self::Annotated(r0, r1)) => l0 == r0 && l1 == r1, (Self::Import(l0), Self::Import(r0)) => l0 == r0, (Self::ResolvedImport(l0), Self::ResolvedImport(r0)) => l0 == r0, - (Self::Type(l0), Self::Type(r0)) => l0 == r0, + ( + Self::Type { + typ: l0, + contract: l1, + }, + Self::Type { + typ: r0, + contract: r1, + }, + ) => l0 == r0 && l1 == r1, (Self::ParseError(l0), Self::ParseError(r0)) => l0 == r0, (Self::RuntimeError(l0), Self::RuntimeError(r0)) => l0 == r0, // We don't compare closure, because we can't, without the evaluation cache at hand. @@ -948,7 +963,7 @@ impl Term { Term::SealingKey(_) => Some("SealingKey".to_owned()), Term::Sealed(..) => Some("Sealed".to_owned()), Term::Annotated(..) => Some("Annotated".to_owned()), - Term::Type(_) => Some("Type".to_owned()), + Term::Type { .. } => Some("Type".to_owned()), Term::ForeignId(_) => Some("ForeignId".to_owned()), Term::CustomContract(_) => Some("CustomContract".to_owned()), Term::Let(..) @@ -998,9 +1013,9 @@ impl Term { | Term::EnumVariant {..} | Term::Record(..) | Term::Array(..) - | Term::Type(_) | Term::ForeignId(_) - | Term::SealingKey(_) => true, + | Term::SealingKey(_) + | Term::Type {..} => true, Term::Let(..) | Term::LetPattern(..) | Term::FunPattern(..) @@ -1078,7 +1093,7 @@ impl Term { | Term::ResolvedImport(_) | Term::StrChunks(_) | Term::RecRecord(..) - | Term::Type(_) + | Term::Type { .. } | Term::ParseError(_) | Term::EnumVariant { .. } | Term::RuntimeError(_) => false, @@ -1115,6 +1130,7 @@ impl Term { | Term::Op1(UnaryOp::BoolOr, _) => true, // A number with a minus sign as a prefix isn't a proper atom Term::Num(n) if *n >= 0 => true, + Term::Type {typ, contract: _} => typ.fmt_is_atom(), Term::Let(..) | Term::Num(..) | Term::EnumVariant {..} @@ -1131,7 +1147,6 @@ impl Term { | Term::Annotated(..) | Term::Import(..) | Term::ResolvedImport(..) - | Term::Type(_) | Term::Closure(_) | Term::ParseError(_) | Term::RuntimeError(_) => false, @@ -2322,8 +2337,11 @@ impl Traverse for RichTerm { let term = term.traverse(f, order)?; RichTerm::new(Term::Annotated(annot, term), pos) } - Term::Type(ty) => { - RichTerm::new(Term::Type(ty.traverse(f, order)?), pos) + Term::Type { typ, contract } => { + let typ = typ.traverse(f, order)?; + let contract = contract.traverse(f, order)?; + + RichTerm::new(Term::Type { typ, contract }, pos) } _ => rt, }); @@ -2417,7 +2435,10 @@ impl Traverse for RichTerm { Term::Annotated(annot, t) => t .traverse_ref(f, state) .or_else(|| annot.traverse_ref(f, state)), - Term::Type(ty) => ty.traverse_ref(f, state), + Term::Type { typ, contract } => { + typ.traverse_ref(f, state)?; + contract.traverse_ref(f, state) + } } } } @@ -2430,9 +2451,10 @@ impl Traverse for RichTerm { self.traverse( &mut |rt: RichTerm| { match_sharedterm!(match (rt.term) { - Term::Type(ty) => ty - .traverse(f, order) - .map(|ty| RichTerm::new(Term::Type(ty), rt.pos)), + Term::Type { typ, contract } => { + let typ = typ.traverse(f, order)?; + Ok(RichTerm::new(Term::Type { typ, contract }, rt.pos)) + } _ => Ok(rt), }) }, @@ -2447,7 +2469,7 @@ impl Traverse for RichTerm { ) -> Option { self.traverse_ref( &mut |rt: &RichTerm, state: &S| match &*rt.term { - Term::Type(ty) => ty.traverse_ref(f, state).into(), + Term::Type { typ, contract: _ } => typ.traverse_ref(f, state).into(), _ => TraverseControl::Continue, }, state, diff --git a/core/src/transform/free_vars.rs b/core/src/transform/free_vars.rs index 823719c485..e4ab64d3fc 100644 --- a/core/src/transform/free_vars.rs +++ b/core/src/transform/free_vars.rs @@ -180,8 +180,9 @@ impl CollectFreeVars for RichTerm { t.collect_free_vars(free_vars); } - Term::Type(ty) => { - ty.collect_free_vars(free_vars); + Term::Type { typ, contract } => { + typ.collect_free_vars(free_vars); + contract.collect_free_vars(free_vars); } Term::Closure(_) => { unreachable!("should never see closures at the transformation stage"); diff --git a/core/src/typ.rs b/core/src/typ.rs index a7652c4bf2..1370272fc6 100644 --- a/core/src/typ.rs +++ b/core/src/typ.rs @@ -51,6 +51,7 @@ use crate::{ position::TermPos, pretty::PrettyPrintCap, stdlib::internals, + term::pattern::compile::Compile, term::{ array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap, MatchBranch, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, @@ -987,8 +988,8 @@ impl Subcontract for EnumRows { // For example, for an enum type [| 'foo, 'bar, 'Baz T |], the function looks like: // // ``` - // fun l x => - // x |> match { + // fun label value => + // value |> match { // 'foo => 'Ok x, // 'bar => 'Ok x, // 'Baz variant_arg => 'Ok ('Baz (%apply_contract% T label_arg variant_arg)), @@ -1075,7 +1076,9 @@ impl Subcontract for EnumRows { body: default, }); - let match_expr = mk_app!(Term::Match(MatchData { branches }), mk_term::var(value_arg)); + // We pre-compile the match expression, so that it's not compiled again and again at each + // application of the contract. + let match_expr = MatchData { branches }.compile(mk_term::var(value_arg), TermPos::None); let case = mk_fun!(label_arg, value_arg, match_expr); Ok(mk_app!(internals::enumeration(), case)) diff --git a/core/src/typecheck/eq.rs b/core/src/typecheck/eq.rs index e6fb49ab03..0cceb9e377 100644 --- a/core/src/typecheck/eq.rs +++ b/core/src/typecheck/eq.rs @@ -463,7 +463,19 @@ fn contract_eq_bounded( (Op1(UnaryOp::RecordAccess(id1), t1), Op1(UnaryOp::RecordAccess(id2), t2)) => { id1 == id2 && contract_eq_bounded(state, t1, env1, t2, env2) } - (Type(ty1), Type(ty2)) => type_eq_bounded( + // Contract is just a caching mechanism. `typ` should be the source of truth for equality + // (and it's probably easier to prove that type are equal rather than their generated + // contract version). + ( + Type { + typ: ty1, + contract: _, + }, + Type { + typ: ty2, + contract: _, + }, + ) => type_eq_bounded( state, &GenericUnifType::from_type(ty1.clone(), env1), env1, diff --git a/core/src/typecheck/mod.rs b/core/src/typecheck/mod.rs index ff5823dc7a..a003386d68 100644 --- a/core/src/typecheck/mod.rs +++ b/core/src/typecheck/mod.rs @@ -1619,7 +1619,9 @@ fn walk( Term::Annotated(annot, rt) => { walk_annotated(state, ctxt, visitor, annot, rt) } - Term::Type(ty) => walk_type(state, ctxt, visitor, ty), + // The contract field is just a caching mechanism, and should be set to `None` at this + // point anyway. We can safely ignore it. + Term::Type { typ, contract: _ } => walk_type(state, ctxt, visitor, typ), Term::Closure(_) => unreachable!("should never see a closure at typechecking time"), } } @@ -2350,7 +2352,7 @@ fn check( ty.unify(ty_import, state, &ctxt) .map_err(|err| err.into_typecheck_err(state, rt.pos)) } - Term::Type(typ) => { + Term::Type { typ, contract: _ } => { if let Some(flat) = typ.find_flat() { Err(TypecheckError::FlatTypeInTermPosition { flat, pos: *pos }) } else { diff --git a/lsp/nls/src/field_walker.rs b/lsp/nls/src/field_walker.rs index b62d68a671..bf5674b8ea 100644 --- a/lsp/nls/src/field_walker.rs +++ b/lsp/nls/src/field_walker.rs @@ -440,7 +440,7 @@ impl<'a> FieldResolver<'a> { let defs = self.resolve_annot(annot); defs.chain(self.resolve_container(term)).collect() } - Term::Type(typ) => self.resolve_type(typ), + Term::Type { typ, contract: _ } => self.resolve_type(typ), _ => Default::default(), };