diff --git a/core/src/closurize.rs b/core/src/closurize.rs index 62f7ef7f53..7eec2d420b 100644 --- a/core/src/closurize.rs +++ b/core/src/closurize.rs @@ -289,9 +289,10 @@ pub fn should_share(t: &Term) -> bool { | Term::Var(_) | Term::Enum(_) | Term::Fun(_, _) + | Term::Closure(_) + | Term::Type { .. } // match acts like a function, and is a WHNF - | Term::Match {..} - | Term::Type(_) => false, + | Term::Match {..} => false, _ => true, } } diff --git a/core/src/eval/mod.rs b/core/src/eval/mod.rs index cd0c4dd351..93c0da0441 100644 --- a/core/src/eval/mod.rs +++ b/core/src/eval/mod.rs @@ -1165,7 +1165,7 @@ pub fn subst( // We could recurse here, because types can contain terms which would then be subject to // substitution. Not recursing should be fine, though, because a type in term position // turns into a contract, and we don't substitute inside contracts either currently. - | v @ Term::Type(_) => RichTerm::new(v, pos), + | v @ Term::Type {..} => RichTerm::new(v, pos), Term::EnumVariant { tag, arg, attrs } => { let arg = subst(cache, arg, initial_env, env); diff --git a/core/src/eval/operation.rs b/core/src/eval/operation.rs index 6c8a054fd8..d6a0fa1168 100644 --- a/core/src/eval/operation.rs +++ b/core/src/eval/operation.rs @@ -247,7 +247,7 @@ impl VirtualMachine { Term::Array(..) => "Array", Term::Record(..) | Term::RecRecord(..) => "Record", Term::Lbl(..) => "Label", - Term::Type(_) => "Type", + Term::Type { .. } => "Type", Term::ForeignId(_) => "ForeignId", _ => "Other", }; @@ -1736,15 +1736,25 @@ impl VirtualMachine { } } BinaryOp::ContractApply | BinaryOp::ContractCheck => { - // The translation of a type might return any kind of contract, including e.g. a - // record or a custom contract. The result thus needs to be passed to `b_op` again. - // In that case, we don't bother tracking the argument and updating the label: this - // will be done by the next call to `b_op`. - if let Term::Type(typ) = &*t1 { + // Doing just one `if let Term::Type` and putting the call to `increment!` there + // looks sensible at first, but it's annoying to explain to rustc and clippy that + // we match on `typ` but use it only if the `metrics` feature is enabled (we get + // unused variable warning otherwise). It's simpler to just make a separate `if` + // conditionally included. + #[cfg(feature = "metrics")] + if let Term::Type { typ, .. } = &*t1 { increment!(format!( "primop:contract/apply:{}", typ.pretty_print_cap(40) )); + } + + let t1 = if let Term::Type { typ: _, contract } = &*t1 { + // The contract generation from a static type might return any kind of + // contract, including e.g. a record or a custom contract. The result needs to + // be evaluated first, and then passed to `b_op` again. In that case, we don't + // bother tracking the argument and updating the label: this will be done by + // the next call to `b_op`. // We set the stack to represent the evaluation context ` [.] label` and // proceed to evaluate `` @@ -1765,10 +1775,12 @@ impl VirtualMachine { ); return Ok(Closure { - body: typ.contract()?, + body: contract.clone(), env: env1, }); - } + } else { + t1 + }; let t2 = t2.into_owned(); diff --git a/core/src/parser/uniterm.rs b/core/src/parser/uniterm.rs index 07b1bb8c38..bcef7fc933 100644 --- a/core/src/parser/uniterm.rs +++ b/core/src/parser/uniterm.rs @@ -126,12 +126,16 @@ impl TryFrom for RichTerm { let rt = match node { UniTermNode::Var(id) => RichTerm::new(Term::Var(id), pos), UniTermNode::Record(r) => RichTerm::try_from(r)?, - UniTermNode::Type(mut ty) => { - ty.fix_type_vars(pos.unwrap())?; - if let TypeF::Flat(rt) = ty.typ { + UniTermNode::Type(mut typ) => { + typ.fix_type_vars(pos.unwrap())?; + if let TypeF::Flat(rt) = typ.typ { rt.with_pos(pos) } else { - RichTerm::new(Term::Type(ty), pos) + let contract = typ + .contract() + .map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?; + + RichTerm::new(Term::Type { typ, contract }, pos) } } UniTermNode::Term(rt) => rt, @@ -478,7 +482,7 @@ impl TryFrom for RichTerm { let result = if ur.tail.is_some() || (ur.is_record_type() && !ur.fields.is_empty()) { let tail_span = ur.tail.as_ref().and_then(|t| t.1.into_opt()); // We unwrap all positions: at this stage of the parsing, they must all be set - let mut ty = ur + let mut typ = ur .into_type_strict() .map_err(|cause| ParseError::InvalidRecordType { tail_span, @@ -486,8 +490,12 @@ impl TryFrom for RichTerm { cause, })?; - ty.fix_type_vars(pos.unwrap())?; - Ok(RichTerm::new(Term::Type(ty), pos)) + typ.fix_type_vars(pos.unwrap())?; + let contract = typ + .contract() + .map_err(|err| ParseError::UnboundTypeVariables(vec![err.0]))?; + + Ok(RichTerm::new(Term::Type { typ, contract }, pos)) } else { ur.check_typed_field_without_def()?; diff --git a/core/src/pretty.rs b/core/src/pretty.rs index 58ac14f99d..67d7820f03 100644 --- a/core/src/pretty.rs +++ b/core/src/pretty.rs @@ -1038,7 +1038,7 @@ where .append(allocator.as_string(f.to_string_lossy()).double_quotes()), ResolvedImport(id) => allocator.text(format!("import ")), // This type is in term position, so we don't need to add parentheses. - Type(ty) => ty.pretty(allocator), + Type { typ, contract: _ } => typ.pretty(allocator), ParseError(_) => allocator.text("%"), RuntimeError(_) => allocator.text("%"), Closure(idx) => allocator.text(format!("%")), diff --git a/core/src/term/mod.rs b/core/src/term/mod.rs index acf87974a7..3e50f8fbb7 100644 --- a/core/src/term/mod.rs +++ b/core/src/term/mod.rs @@ -21,8 +21,7 @@ use string::NickelString; use crate::{ error::{EvalError, ParseError}, - eval::cache::CacheIndex, - eval::Environment, + eval::{cache::CacheIndex, Environment}, identifier::LocIdent, impl_display_from_pretty, label::{Label, MergeLabel}, @@ -217,7 +216,14 @@ pub enum Term { /// /// During evaluation, this will get turned into a contract. #[serde(skip)] - Type(Type), + Type { + /// The static type. + typ: Type, + /// The conversion of this type to a contract, that is, `typ.contract()?`. This field + /// serves as a caching mechanism so we only run the contract generation code once per type + /// written by the user. + contract: RichTerm, + }, /// A custom contract. The content must be a function (or function-like terms like a match /// expression) of two arguments: a label and the value to be checked. In particular, it must @@ -362,7 +368,16 @@ impl PartialEq for Term { (Self::Annotated(l0, l1), Self::Annotated(r0, r1)) => l0 == r0 && l1 == r1, (Self::Import(l0), Self::Import(r0)) => l0 == r0, (Self::ResolvedImport(l0), Self::ResolvedImport(r0)) => l0 == r0, - (Self::Type(l0), Self::Type(r0)) => l0 == r0, + ( + Self::Type { + typ: l0, + contract: l1, + }, + Self::Type { + typ: r0, + contract: r1, + }, + ) => l0 == r0 && l1 == r1, (Self::ParseError(l0), Self::ParseError(r0)) => l0 == r0, (Self::RuntimeError(l0), Self::RuntimeError(r0)) => l0 == r0, // We don't compare closure, because we can't, without the evaluation cache at hand. @@ -948,7 +963,7 @@ impl Term { Term::SealingKey(_) => Some("SealingKey".to_owned()), Term::Sealed(..) => Some("Sealed".to_owned()), Term::Annotated(..) => Some("Annotated".to_owned()), - Term::Type(_) => Some("Type".to_owned()), + Term::Type { .. } => Some("Type".to_owned()), Term::ForeignId(_) => Some("ForeignId".to_owned()), Term::CustomContract(_) => Some("CustomContract".to_owned()), Term::Let(..) @@ -998,9 +1013,9 @@ impl Term { | Term::EnumVariant {..} | Term::Record(..) | Term::Array(..) - | Term::Type(_) | Term::ForeignId(_) - | Term::SealingKey(_) => true, + | Term::SealingKey(_) + | Term::Type {..} => true, Term::Let(..) | Term::LetPattern(..) | Term::FunPattern(..) @@ -1078,7 +1093,7 @@ impl Term { | Term::ResolvedImport(_) | Term::StrChunks(_) | Term::RecRecord(..) - | Term::Type(_) + | Term::Type { .. } | Term::ParseError(_) | Term::EnumVariant { .. } | Term::RuntimeError(_) => false, @@ -1115,6 +1130,7 @@ impl Term { | Term::Op1(UnaryOp::BoolOr, _) => true, // A number with a minus sign as a prefix isn't a proper atom Term::Num(n) if *n >= 0 => true, + Term::Type {typ, contract: _} => typ.fmt_is_atom(), Term::Let(..) | Term::Num(..) | Term::EnumVariant {..} @@ -1131,7 +1147,6 @@ impl Term { | Term::Annotated(..) | Term::Import(..) | Term::ResolvedImport(..) - | Term::Type(_) | Term::Closure(_) | Term::ParseError(_) | Term::RuntimeError(_) => false, @@ -2322,8 +2337,11 @@ impl Traverse for RichTerm { let term = term.traverse(f, order)?; RichTerm::new(Term::Annotated(annot, term), pos) } - Term::Type(ty) => { - RichTerm::new(Term::Type(ty.traverse(f, order)?), pos) + Term::Type { typ, contract } => { + let typ = typ.traverse(f, order)?; + let contract = contract.traverse(f, order)?; + + RichTerm::new(Term::Type { typ, contract }, pos) } _ => rt, }); @@ -2417,7 +2435,10 @@ impl Traverse for RichTerm { Term::Annotated(annot, t) => t .traverse_ref(f, state) .or_else(|| annot.traverse_ref(f, state)), - Term::Type(ty) => ty.traverse_ref(f, state), + Term::Type { typ, contract } => { + typ.traverse_ref(f, state)?; + contract.traverse_ref(f, state) + } } } } @@ -2430,9 +2451,10 @@ impl Traverse for RichTerm { self.traverse( &mut |rt: RichTerm| { match_sharedterm!(match (rt.term) { - Term::Type(ty) => ty - .traverse(f, order) - .map(|ty| RichTerm::new(Term::Type(ty), rt.pos)), + Term::Type { typ, contract } => { + let typ = typ.traverse(f, order)?; + Ok(RichTerm::new(Term::Type { typ, contract }, rt.pos)) + } _ => Ok(rt), }) }, @@ -2447,7 +2469,7 @@ impl Traverse for RichTerm { ) -> Option { self.traverse_ref( &mut |rt: &RichTerm, state: &S| match &*rt.term { - Term::Type(ty) => ty.traverse_ref(f, state).into(), + Term::Type { typ, contract: _ } => typ.traverse_ref(f, state).into(), _ => TraverseControl::Continue, }, state, diff --git a/core/src/transform/free_vars.rs b/core/src/transform/free_vars.rs index 823719c485..e4ab64d3fc 100644 --- a/core/src/transform/free_vars.rs +++ b/core/src/transform/free_vars.rs @@ -180,8 +180,9 @@ impl CollectFreeVars for RichTerm { t.collect_free_vars(free_vars); } - Term::Type(ty) => { - ty.collect_free_vars(free_vars); + Term::Type { typ, contract } => { + typ.collect_free_vars(free_vars); + contract.collect_free_vars(free_vars); } Term::Closure(_) => { unreachable!("should never see closures at the transformation stage"); diff --git a/core/src/typ.rs b/core/src/typ.rs index a7652c4bf2..1370272fc6 100644 --- a/core/src/typ.rs +++ b/core/src/typ.rs @@ -51,6 +51,7 @@ use crate::{ position::TermPos, pretty::PrettyPrintCap, stdlib::internals, + term::pattern::compile::Compile, term::{ array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap, MatchBranch, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, @@ -987,8 +988,8 @@ impl Subcontract for EnumRows { // For example, for an enum type [| 'foo, 'bar, 'Baz T |], the function looks like: // // ``` - // fun l x => - // x |> match { + // fun label value => + // value |> match { // 'foo => 'Ok x, // 'bar => 'Ok x, // 'Baz variant_arg => 'Ok ('Baz (%apply_contract% T label_arg variant_arg)), @@ -1075,7 +1076,9 @@ impl Subcontract for EnumRows { body: default, }); - let match_expr = mk_app!(Term::Match(MatchData { branches }), mk_term::var(value_arg)); + // We pre-compile the match expression, so that it's not compiled again and again at each + // application of the contract. + let match_expr = MatchData { branches }.compile(mk_term::var(value_arg), TermPos::None); let case = mk_fun!(label_arg, value_arg, match_expr); Ok(mk_app!(internals::enumeration(), case)) diff --git a/core/src/typecheck/eq.rs b/core/src/typecheck/eq.rs index e6fb49ab03..0cceb9e377 100644 --- a/core/src/typecheck/eq.rs +++ b/core/src/typecheck/eq.rs @@ -463,7 +463,19 @@ fn contract_eq_bounded( (Op1(UnaryOp::RecordAccess(id1), t1), Op1(UnaryOp::RecordAccess(id2), t2)) => { id1 == id2 && contract_eq_bounded(state, t1, env1, t2, env2) } - (Type(ty1), Type(ty2)) => type_eq_bounded( + // Contract is just a caching mechanism. `typ` should be the source of truth for equality + // (and it's probably easier to prove that type are equal rather than their generated + // contract version). + ( + Type { + typ: ty1, + contract: _, + }, + Type { + typ: ty2, + contract: _, + }, + ) => type_eq_bounded( state, &GenericUnifType::from_type(ty1.clone(), env1), env1, diff --git a/core/src/typecheck/mod.rs b/core/src/typecheck/mod.rs index ff5823dc7a..a003386d68 100644 --- a/core/src/typecheck/mod.rs +++ b/core/src/typecheck/mod.rs @@ -1619,7 +1619,9 @@ fn walk( Term::Annotated(annot, rt) => { walk_annotated(state, ctxt, visitor, annot, rt) } - Term::Type(ty) => walk_type(state, ctxt, visitor, ty), + // The contract field is just a caching mechanism, and should be set to `None` at this + // point anyway. We can safely ignore it. + Term::Type { typ, contract: _ } => walk_type(state, ctxt, visitor, typ), Term::Closure(_) => unreachable!("should never see a closure at typechecking time"), } } @@ -2350,7 +2352,7 @@ fn check( ty.unify(ty_import, state, &ctxt) .map_err(|err| err.into_typecheck_err(state, rt.pos)) } - Term::Type(typ) => { + Term::Type { typ, contract: _ } => { if let Some(flat) = typ.find_flat() { Err(TypecheckError::FlatTypeInTermPosition { flat, pos: *pos }) } else { diff --git a/lsp/nls/src/field_walker.rs b/lsp/nls/src/field_walker.rs index b62d68a671..bf5674b8ea 100644 --- a/lsp/nls/src/field_walker.rs +++ b/lsp/nls/src/field_walker.rs @@ -440,7 +440,7 @@ impl<'a> FieldResolver<'a> { let defs = self.resolve_annot(annot); defs.chain(self.resolve_container(term)).collect() } - Term::Type(typ) => self.resolve_type(typ), + Term::Type { typ, contract: _ } => self.resolve_type(typ), _ => Default::default(), };