Skip to content

Commit

Permalink
[Optimization] Cache contract generation and pre-compile some match e…
Browse files Browse the repository at this point in the history
…xpression (#2013)

* Pre-generate contracts

After the introduction of a proper node for types in the AST, we
switched to a lazy way of generating contract: we keep the type as it
is, and convert it to a contract only once it's actually applied. While
simpler, this approach has the drawback of potentially wasting
computations by running the contract generation code many times for the
same contract, as some metrics showed (up to 1.6 thousand times on large
codebases).

This commit add a `contract` field to `Term::Type`, and statically
generates - in the parser - the contract corresponding to a type. This
is what we used to do before the introduction of a type node. Doing so,
we only generate at most one contract per user annotation, regardless of
how many times the contract is applied at runtime.

* Pre-compile match in enum contract

This commit saves some re-compilation of the match expression generated
by contracts converted from enum types by generating the compiled
version directly instead of a match.

* Fix unused import warnings
  • Loading branch information
yannham authored Aug 2, 2024
1 parent 734bcc9 commit b3c8e30
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 44 deletions.
5 changes: 3 additions & 2 deletions core/src/closurize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,10 @@ pub fn should_share(t: &Term) -> bool {
| Term::Var(_)
| Term::Enum(_)
| Term::Fun(_, _)
| Term::Closure(_)
| Term::Type { .. }
// match acts like a function, and is a WHNF
| Term::Match {..}
| Term::Type(_) => false,
| Term::Match {..} => false,
_ => true,
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,7 @@ pub fn subst<C: Cache>(
// We could recurse here, because types can contain terms which would then be subject to
// substitution. Not recursing should be fine, though, because a type in term position
// turns into a contract, and we don't substitute inside contracts either currently.
| v @ Term::Type(_) => RichTerm::new(v, pos),
| v @ Term::Type {..} => RichTerm::new(v, pos),
Term::EnumVariant { tag, arg, attrs } => {
let arg = subst(cache, arg, initial_env, env);

Expand Down
28 changes: 20 additions & 8 deletions core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
Term::Array(..) => "Array",
Term::Record(..) | Term::RecRecord(..) => "Record",
Term::Lbl(..) => "Label",
Term::Type(_) => "Type",
Term::Type { .. } => "Type",
Term::ForeignId(_) => "ForeignId",
_ => "Other",
};
Expand Down Expand Up @@ -1736,15 +1736,25 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
}
}
BinaryOp::ContractApply | BinaryOp::ContractCheck => {
// The translation of a type might return any kind of contract, including e.g. a
// record or a custom contract. The result thus needs to be passed to `b_op` again.
// In that case, we don't bother tracking the argument and updating the label: this
// will be done by the next call to `b_op`.
if let Term::Type(typ) = &*t1 {
// Doing just one `if let Term::Type` and putting the call to `increment!` there
// looks sensible at first, but it's annoying to explain to rustc and clippy that
// we match on `typ` but use it only if the `metrics` feature is enabled (we get
// unused variable warning otherwise). It's simpler to just make a separate `if`
// conditionally included.
#[cfg(feature = "metrics")]
if let Term::Type { typ, .. } = &*t1 {
increment!(format!(
"primop:contract/apply:{}",
typ.pretty_print_cap(40)
));
}

let t1 = if let Term::Type { typ: _, contract } = &*t1 {
// The contract generation from a static type might return any kind of
// contract, including e.g. a record or a custom contract. The result needs to
// be evaluated first, and then passed to `b_op` again. In that case, we don't
// bother tracking the argument and updating the label: this will be done by
// the next call to `b_op`.

// We set the stack to represent the evaluation context `<b_op> [.] label` and
// proceed to evaluate `<typ.contract()>`
Expand All @@ -1765,10 +1775,12 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {
);

return Ok(Closure {
body: typ.contract()?,
body: contract.clone(),
env: env1,
});
}
} else {
t1
};

let t2 = t2.into_owned();

Expand Down
22 changes: 15 additions & 7 deletions core/src/parser/uniterm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,16 @@ impl TryFrom<UniTerm> for RichTerm {
let rt = match node {
UniTermNode::Var(id) => RichTerm::new(Term::Var(id), pos),
UniTermNode::Record(r) => RichTerm::try_from(r)?,
UniTermNode::Type(mut ty) => {
ty.fix_type_vars(pos.unwrap())?;
if let TypeF::Flat(rt) = ty.typ {
UniTermNode::Type(mut typ) => {
typ.fix_type_vars(pos.unwrap())?;
if let TypeF::Flat(rt) = typ.typ {
rt.with_pos(pos)
} else {
RichTerm::new(Term::Type(ty), pos)
let contract = typ
.contract()
.map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?;

RichTerm::new(Term::Type { typ, contract }, pos)
}
}
UniTermNode::Term(rt) => rt,
Expand Down Expand Up @@ -478,16 +482,20 @@ impl TryFrom<UniRecord> for RichTerm {
let result = if ur.tail.is_some() || (ur.is_record_type() && !ur.fields.is_empty()) {
let tail_span = ur.tail.as_ref().and_then(|t| t.1.into_opt());
// We unwrap all positions: at this stage of the parsing, they must all be set
let mut ty = ur
let mut typ = ur
.into_type_strict()
.map_err(|cause| ParseError::InvalidRecordType {
tail_span,
record_span: pos.unwrap(),
cause,
})?;

ty.fix_type_vars(pos.unwrap())?;
Ok(RichTerm::new(Term::Type(ty), pos))
typ.fix_type_vars(pos.unwrap())?;
let contract = typ
.contract()
.map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?;

Ok(RichTerm::new(Term::Type { typ, contract }, pos))
} else {
ur.check_typed_field_without_def()?;

Expand Down
2 changes: 1 addition & 1 deletion core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ where
.append(allocator.as_string(f.to_string_lossy()).double_quotes()),
ResolvedImport(id) => allocator.text(format!("import <file_id: {id:?}>")),
// This type is in term position, so we don't need to add parentheses.
Type(ty) => ty.pretty(allocator),
Type { typ, contract: _ } => typ.pretty(allocator),
ParseError(_) => allocator.text("%<PARSE ERROR>"),
RuntimeError(_) => allocator.text("%<RUNTIME ERROR>"),
Closure(idx) => allocator.text(format!("%<closure@{idx:p}>")),
Expand Down
54 changes: 38 additions & 16 deletions core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ use string::NickelString;

use crate::{
error::{EvalError, ParseError},
eval::cache::CacheIndex,
eval::Environment,
eval::{cache::CacheIndex, Environment},
identifier::LocIdent,
impl_display_from_pretty,
label::{Label, MergeLabel},
Expand Down Expand Up @@ -217,7 +216,14 @@ pub enum Term {
///
/// During evaluation, this will get turned into a contract.
#[serde(skip)]
Type(Type),
Type {
/// The static type.
typ: Type,
/// The conversion of this type to a contract, that is, `typ.contract()?`. This field
/// serves as a caching mechanism so we only run the contract generation code once per type
/// written by the user.
contract: RichTerm,
},

/// A custom contract. The content must be a function (or function-like terms like a match
/// expression) of two arguments: a label and the value to be checked. In particular, it must
Expand Down Expand Up @@ -362,7 +368,16 @@ impl PartialEq for Term {
(Self::Annotated(l0, l1), Self::Annotated(r0, r1)) => l0 == r0 && l1 == r1,
(Self::Import(l0), Self::Import(r0)) => l0 == r0,
(Self::ResolvedImport(l0), Self::ResolvedImport(r0)) => l0 == r0,
(Self::Type(l0), Self::Type(r0)) => l0 == r0,
(
Self::Type {
typ: l0,
contract: l1,
},
Self::Type {
typ: r0,
contract: r1,
},
) => l0 == r0 && l1 == r1,
(Self::ParseError(l0), Self::ParseError(r0)) => l0 == r0,
(Self::RuntimeError(l0), Self::RuntimeError(r0)) => l0 == r0,
// We don't compare closure, because we can't, without the evaluation cache at hand.
Expand Down Expand Up @@ -948,7 +963,7 @@ impl Term {
Term::SealingKey(_) => Some("SealingKey".to_owned()),
Term::Sealed(..) => Some("Sealed".to_owned()),
Term::Annotated(..) => Some("Annotated".to_owned()),
Term::Type(_) => Some("Type".to_owned()),
Term::Type { .. } => Some("Type".to_owned()),
Term::ForeignId(_) => Some("ForeignId".to_owned()),
Term::CustomContract(_) => Some("CustomContract".to_owned()),
Term::Let(..)
Expand Down Expand Up @@ -998,9 +1013,9 @@ impl Term {
| Term::EnumVariant {..}
| Term::Record(..)
| Term::Array(..)
| Term::Type(_)
| Term::ForeignId(_)
| Term::SealingKey(_) => true,
| Term::SealingKey(_)
| Term::Type {..} => true,
Term::Let(..)
| Term::LetPattern(..)
| Term::FunPattern(..)
Expand Down Expand Up @@ -1078,7 +1093,7 @@ impl Term {
| Term::ResolvedImport(_)
| Term::StrChunks(_)
| Term::RecRecord(..)
| Term::Type(_)
| Term::Type { .. }
| Term::ParseError(_)
| Term::EnumVariant { .. }
| Term::RuntimeError(_) => false,
Expand Down Expand Up @@ -1115,6 +1130,7 @@ impl Term {
| Term::Op1(UnaryOp::BoolOr, _) => true,
// A number with a minus sign as a prefix isn't a proper atom
Term::Num(n) if *n >= 0 => true,
Term::Type {typ, contract: _} => typ.fmt_is_atom(),
Term::Let(..)
| Term::Num(..)
| Term::EnumVariant {..}
Expand All @@ -1131,7 +1147,6 @@ impl Term {
| Term::Annotated(..)
| Term::Import(..)
| Term::ResolvedImport(..)
| Term::Type(_)
| Term::Closure(_)
| Term::ParseError(_)
| Term::RuntimeError(_) => false,
Expand Down Expand Up @@ -2322,8 +2337,11 @@ impl Traverse<RichTerm> for RichTerm {
let term = term.traverse(f, order)?;
RichTerm::new(Term::Annotated(annot, term), pos)
}
Term::Type(ty) => {
RichTerm::new(Term::Type(ty.traverse(f, order)?), pos)
Term::Type { typ, contract } => {
let typ = typ.traverse(f, order)?;
let contract = contract.traverse(f, order)?;

RichTerm::new(Term::Type { typ, contract }, pos)
}
_ => rt,
});
Expand Down Expand Up @@ -2417,7 +2435,10 @@ impl Traverse<RichTerm> for RichTerm {
Term::Annotated(annot, t) => t
.traverse_ref(f, state)
.or_else(|| annot.traverse_ref(f, state)),
Term::Type(ty) => ty.traverse_ref(f, state),
Term::Type { typ, contract } => {
typ.traverse_ref(f, state)?;
contract.traverse_ref(f, state)
}
}
}
}
Expand All @@ -2430,9 +2451,10 @@ impl Traverse<Type> for RichTerm {
self.traverse(
&mut |rt: RichTerm| {
match_sharedterm!(match (rt.term) {
Term::Type(ty) => ty
.traverse(f, order)
.map(|ty| RichTerm::new(Term::Type(ty), rt.pos)),
Term::Type { typ, contract } => {
let typ = typ.traverse(f, order)?;
Ok(RichTerm::new(Term::Type { typ, contract }, rt.pos))
}
_ => Ok(rt),
})
},
Expand All @@ -2447,7 +2469,7 @@ impl Traverse<Type> for RichTerm {
) -> Option<U> {
self.traverse_ref(
&mut |rt: &RichTerm, state: &S| match &*rt.term {
Term::Type(ty) => ty.traverse_ref(f, state).into(),
Term::Type { typ, contract: _ } => typ.traverse_ref(f, state).into(),
_ => TraverseControl::Continue,
},
state,
Expand Down
5 changes: 3 additions & 2 deletions core/src/transform/free_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ impl CollectFreeVars for RichTerm {

t.collect_free_vars(free_vars);
}
Term::Type(ty) => {
ty.collect_free_vars(free_vars);
Term::Type { typ, contract } => {
typ.collect_free_vars(free_vars);
contract.collect_free_vars(free_vars);
}
Term::Closure(_) => {
unreachable!("should never see closures at the transformation stage");
Expand Down
9 changes: 6 additions & 3 deletions core/src/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ use crate::{
position::TermPos,
pretty::PrettyPrintCap,
stdlib::internals,
term::pattern::compile::Compile,
term::{
array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap,
MatchBranch, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder,
Expand Down Expand Up @@ -987,8 +988,8 @@ impl Subcontract for EnumRows {
// For example, for an enum type [| 'foo, 'bar, 'Baz T |], the function looks like:
//
// ```
// fun l x =>
// x |> match {
// fun label value =>
// value |> match {
// 'foo => 'Ok x,
// 'bar => 'Ok x,
// 'Baz variant_arg => 'Ok ('Baz (%apply_contract% T label_arg variant_arg)),
Expand Down Expand Up @@ -1075,7 +1076,9 @@ impl Subcontract for EnumRows {
body: default,
});

let match_expr = mk_app!(Term::Match(MatchData { branches }), mk_term::var(value_arg));
// We pre-compile the match expression, so that it's not compiled again and again at each
// application of the contract.
let match_expr = MatchData { branches }.compile(mk_term::var(value_arg), TermPos::None);

let case = mk_fun!(label_arg, value_arg, match_expr);
Ok(mk_app!(internals::enumeration(), case))
Expand Down
14 changes: 13 additions & 1 deletion core/src/typecheck/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,19 @@ fn contract_eq_bounded<E: TermEnvironment>(
(Op1(UnaryOp::RecordAccess(id1), t1), Op1(UnaryOp::RecordAccess(id2), t2)) => {
id1 == id2 && contract_eq_bounded(state, t1, env1, t2, env2)
}
(Type(ty1), Type(ty2)) => type_eq_bounded(
// Contract is just a caching mechanism. `typ` should be the source of truth for equality
// (and it's probably easier to prove that type are equal rather than their generated
// contract version).
(
Type {
typ: ty1,
contract: _,
},
Type {
typ: ty2,
contract: _,
},
) => type_eq_bounded(
state,
&GenericUnifType::from_type(ty1.clone(), env1),
env1,
Expand Down
6 changes: 4 additions & 2 deletions core/src/typecheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,9 @@ fn walk<V: TypecheckVisitor>(
Term::Annotated(annot, rt) => {
walk_annotated(state, ctxt, visitor, annot, rt)
}
Term::Type(ty) => walk_type(state, ctxt, visitor, ty),
// The contract field is just a caching mechanism, and should be set to `None` at this
// point anyway. We can safely ignore it.
Term::Type { typ, contract: _ } => walk_type(state, ctxt, visitor, typ),
Term::Closure(_) => unreachable!("should never see a closure at typechecking time"),
}
}
Expand Down Expand Up @@ -2350,7 +2352,7 @@ fn check<V: TypecheckVisitor>(
ty.unify(ty_import, state, &ctxt)
.map_err(|err| err.into_typecheck_err(state, rt.pos))
}
Term::Type(typ) => {
Term::Type { typ, contract: _ } => {
if let Some(flat) = typ.find_flat() {
Err(TypecheckError::FlatTypeInTermPosition { flat, pos: *pos })
} else {
Expand Down
2 changes: 1 addition & 1 deletion lsp/nls/src/field_walker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ impl<'a> FieldResolver<'a> {
let defs = self.resolve_annot(annot);
defs.chain(self.resolve_container(term)).collect()
}
Term::Type(typ) => self.resolve_type(typ),
Term::Type { typ, contract: _ } => self.resolve_type(typ),
_ => Default::default(),
};

Expand Down

0 comments on commit b3c8e30

Please sign in to comment.