Skip to content

Commit

Permalink
Merge pull request #418 from AeneasVerif/son/lean
Browse files Browse the repository at this point in the history
Bump Lean to v4.15.0 and fix some issues
  • Loading branch information
sonmarcho authored Jan 17, 2025
2 parents 866188d + ee78db1 commit 35058c1
Show file tree
Hide file tree
Showing 47 changed files with 382 additions and 358 deletions.
9 changes: 6 additions & 3 deletions backends/lean/Base/Arith/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,17 @@ def intTac (tacName : String) (splitAllDisjs splitGoalConjs : Bool)
-/
Utils.tryTac (
-- TODO: is there a simproc to simplify propositional logic?
Utils.simpAll {failIfUnchanged := false, maxSteps := 75} true [``reduceIte] []
Utils.simpAll {failIfUnchanged := false, maxSteps := 1000} true [``reduceIte] []
[``and_self, ``false_implies, ``true_implies, ``Prod.mk.injEq,
``not_false_eq_true, ``not_true_eq_false,
``true_and, ``and_true, ``false_and, ``and_false,
``true_or, ``or_true,``false_or, ``or_false] [])
``true_or, ``or_true,``false_or, ``or_false,
``Bool.true_eq_false, ``Bool.false_eq_true] [])
allGoalsNoRecover (do
trace[Arith] "Goal after simplification: {← getMainGoal}"
Tactic.Omega.omegaTactic {})
trace[Arith] "Calling omega"
Tactic.Omega.omegaTactic {}
trace[Arith] "Omega solved the goal")
if splitAllDisjs then do
/- In order to improve performance, we first try to prove the goal without splitting. If it
fails, we split. -/
Expand Down
6 changes: 4 additions & 2 deletions backends/lean/Base/Arith/ScalarNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ macro_rules

open Lean.Parser.Tactic in
open Mathlib.Tactic.Ring in
macro "scalar_nf" cfg:(config)? loc:(location)? : tactic =>
`(tactic| ring_nf $(cfg)? $(loc)?)
macro "scalar_nf" cfg:optConfig loc:(location)? : tactic =>
`(tactic| ring_nf $cfg:optConfig $(loc)?)

example : True := by ring_nf

end Arith
11 changes: 5 additions & 6 deletions backends/lean/Base/Diverge/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ namespace FixI
-- However, by parameterizing Funs with those parameters, we can state
-- and prove lemmas like Funs.is_valid_p_is_valid_p
inductive Funs (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :
List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) :=
List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) where
| Nil : Funs id a b []
| Cons {ity : Type v} {oty : ity → Type w} {tys : List in_out_ty}
(f : kk_ty id a b → (x:ity) → Result (oty x)) (tl : Funs id a b tys) :
Expand Down Expand Up @@ -794,7 +794,7 @@ namespace FixII
-- and prove lemmas like Funs.is_valid_p_is_valid_p
inductive Funs (id : Type u) (ty : id → Type v)
(a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) :
List in_out_ty.{v, w, x} → Type (max (u + 1) (max (v + 1) (max (w + 1) (x + 1)))) :=
List in_out_ty.{v, w, x} → Type (max (u + 1) (max (v + 1) (max (w + 1) (x + 1)))) where
| Nil : Funs id ty a b []
| Cons {it: Type v} {ity : it → Type w} {oty : it → Type x} {tys : List in_out_ty}
(f : kk_ty id ty a b → (i:it) → (x:ity i) → Result (oty i)) (tl : Funs id ty a b tys) :
Expand Down Expand Up @@ -1223,7 +1223,7 @@ namespace Ex5
apply is_valid_p_bind <;> try simp_all

/- An example which uses map -/
inductive Tree (a : Type) :=
inductive Tree (a : Type) where
| leaf (x : a)
| node (tl : List (Tree a))

Expand Down Expand Up @@ -1262,7 +1262,7 @@ namespace Ex5
:= by
have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a)
simp [id]
conv => lhs; rw [Heq]; simp; rw [id_body]
conv => lhs; rw [Heq]; simp; unfold id_body
rfl

end Ex5
Expand Down Expand Up @@ -1304,7 +1304,6 @@ namespace Ex6
intro x; try simp at x
simp only [list_nth_body]
-- Prove the validity of the individual bodies
intro k x
split <;> try simp
split <;> simp

Expand Down Expand Up @@ -1507,7 +1506,7 @@ namespace Ex9
/- An example which uses map -/
open Primitives FixII Ex8

inductive Tree (a : Type u) :=
inductive Tree (a : Type u) where
| leaf (x : a)
| node (tl : List (Tree a))

Expand Down
94 changes: 54 additions & 40 deletions backends/lean/Base/Diverge/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def mkProdsType (xl : List Expr) : MetaM Expr :=
-/
def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do
-- Look for the first parameter which appears in the subsequent parameters
let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) :=
let rec splitAux (in_tys : List Expr) : MetaM (Std.HashSet FVarId × List Expr × List Expr) :=
match in_tys with
| [] => do
let fvars ← getFVarIds (← inferType out_ty)
Expand All @@ -150,18 +150,18 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×
-- We must split later: update the fvars set
let fvars := fvars.insertMany (← getFVarIds (← inferType ty))
pure (fvars, [], ty :: in_args)
let (_, in_tys, in_args) ← splitAux in_tys.data
let (_, in_tys, in_args) ← splitAux in_tys.toList
pure (Array.mk in_tys, Array.mk in_args)

/- Apply a lambda expression to some arguments, simplifying the lambdas -/
def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do
lambdaTelescopeN e xs.size fun vars body =>
-- Create the substitution
let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList)
let s : Std.HashMap FVarId Expr := Std.HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList)
-- Substitute in the body
pure (body.replace fun e =>
match e with
| Expr.fvar fvarId => match s.find? fvarId with
| Expr.fvar fvarId => match s.get? fvarId with
| none => e
| some v => v
| _ => none)
Expand Down Expand Up @@ -447,10 +447,10 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
Remark: the continuation has an indexed type; we use the index (a finite number of
type `Fin`) to control which function we call at the recursive call site. -/
let nameToInfo : HashMap Name (Nat × TypeInfo) :=
let nameToInfo : Std.HashMap Name (Nat × TypeInfo) :=
let bl := preDefs.mapIdx fun i d =>
(d.declName, (i.val, paramInOutTys.get! i.val))
HashMap.ofList bl.toList
(d.declName, (i, paramInOutTys.get! i))
Std.HashMap.ofList bl.toList

trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}"

Expand All @@ -466,7 +466,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Check if this is a recursive call
if f.isConst then
let name := f.constName!
match nameToInfo.find? name with
match nameToInfo.get? name with
| none => pure e
| some (id, type_info) =>
trace[Diverge.def.genBody.visit] "this is a recursive call"
Expand Down Expand Up @@ -499,9 +499,9 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
it without arguments (if we give it to a higher-order
function for instance) and there are actually no type parameters.
-/
if (nameToInfo.find? name).isSome then do
if (nameToInfo.get? name).isSome then do
-- Checking the type information
match nameToInfo.find? name with
match nameToInfo.get? name with
| none => pure e
| some (id, type_info) =>
trace[Diverge.def.genBody.visit] "this is a recursive call"
Expand Down Expand Up @@ -539,7 +539,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- (over which we match to retrieve the individual arguments).
lambdaTelescope body fun args body => do
-- Split the arguments between the type parameters and the "regular" inputs
let (_, type_info) := nameToInfo.find! preDef.declName
let (_, type_info) := nameToInfo.get! preDef.declName
let (param_args, args) := args.toList.splitAt type_info.num_params
let body ← mkProdsMatchOrUnit args body
trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}"
Expand Down Expand Up @@ -666,7 +666,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
pure e
else pure e
match e with
| .const _ _ => throwError "Unimplemented" -- Shouldn't get there?
| .const _ _ => proveNoKExprIsValid k_var e
| .bvar _
| .fvar _
| .lit _
Expand Down Expand Up @@ -853,7 +853,7 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args :
- if yes: we have to lookup theorems in div spec database and continue -/
trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}"
let argsFVars ← args.mapM getFVarIds
let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) HashSet.empty
let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) Std.HashSet.empty
trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}"
if ¬ allArgsFVars.contains kk_var.fvarId! then do
-- Simple case
Expand All @@ -878,8 +878,8 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr)
-- Introduce fresh meta-variables for the universes
let ul : List (Name × Level) ←
thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar))
let ulMap : HashMap Name Level := HashMap.ofList ul
let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x)
let ulMap : Std.HashMap Name Level := Std.HashMap.ofList ul
let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.get! x)
trace[Diverge.def.valid] "Trying with theorem {thName}: {thTy}"
-- Introduce meta variables for the universally quantified variables
let (mvars, _binders, thTyBody) ← forallMetaTelescope thTy
Expand Down Expand Up @@ -1133,9 +1133,9 @@ def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array TypeInfo) (preDe
let defs ← preDefs.mapIdxM fun idx preDef => do
lambdaTelescope preDef.value fun xs _ => do
-- Retrieve the parameters info
let type_info := paramInOutTys.get! idx.val
let type_info := paramInOutTys.get! idx
-- Create the index
let idx ← mkFinVal grSize idx.val
let idx ← mkFinVal grSize idx
-- Group the inputs into two tuples
let (params_args, input_args) := xs.toList.splitAt type_info.num_params
let params ← mkSigmasVal type_info.params_ty params_args
Expand Down Expand Up @@ -1256,11 +1256,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let (params, in_tys) ← splitInputArgs in_tys out_ty
trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}"
let num_params := params.size
let params_ty ← mkSigmasType params.data
let in_ty ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data)
let params_ty ← mkSigmasType params.toList
let in_ty ← mkSigmasMatchOrUnit params.toList (← mkProdsType in_tys.toList)
-- Retrieve the type in the "Result"
let out_ty ← getResultTy out_ty
let out_ty ← mkSigmasMatchOrUnit params.data out_ty
let out_ty ← mkSigmasMatchOrUnit params.toList out_ty
trace[Diverge.def] "inOutTy: {preDef.declName}: {params_ty}, {in_tys}, {out_ty}"
pure ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩))
trace[Diverge.def] "paramInOutTys: {paramInOutTys}"
Expand Down Expand Up @@ -1386,31 +1386,43 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
else return ()
catch _ => s.restore

namespace Term

-- The following three functions are copy-pasted from Lean.Elab.MutualDef.lean
open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef

-- Copy/pasted from Lean.Elab.Term.withHeaderSecVars (because the definition is private)
private def Term.withHeaderSecVars {α} (vars : Array Expr) (includedVars : List Name) (headers : Array DefViewElabHeader)
private def withHeaderSecVars {α} (vars : Array Expr) (sc : Command.Scope) (headers : Array DefViewElabHeader)
(k : Array Expr → TermElabM α) : TermElabM α := do
let (_, used) ← collectUsed.run {}
let mut revSectionFVars : Std.HashMap FVarId Name := {}
for (uid, var) in (← read).sectionFVars do
revSectionFVars := revSectionFVars.insert var.fvarId! uid
let (_, used) ← collectUsed revSectionFVars |>.run {}
let (lctx, localInsts, vars) ← removeUnused vars used
withLCtx lctx localInsts <| k vars
where
collectUsed : StateRefT CollectFVars.State MetaM Unit := do
collectUsed revSectionFVars : StateRefT CollectFVars.State MetaM Unit := do
-- directly referenced in headers
headers.forM (·.type.collectFVars)
-- included by `include`
vars.forM fun var => do
let ldecl ← getFVarLocalDecl var
if includedVars.contains ldecl.userName then
modify (·.add ldecl.fvarId)
for var in vars do
if let some uid := revSectionFVars[var.fvarId!]? then
if sc.includedVars.contains uid then
modify (·.add var.fvarId!)
-- transitively referenced
get >>= (·.addDependencies) >>= set
for var in (← get).fvarIds do
if let some uid := revSectionFVars[var]? then
if sc.omittedVars.contains uid then
throwError "cannot omit referenced section variable '{Expr.fvar var}'"
-- instances (`addDependencies` unnecessary as by definition they may only reference variables
-- already included)
vars.forM fun var => do
for var in vars do
let ldecl ← getFVarLocalDecl var
if let some uid := revSectionFVars[var.fvarId!]? then
if sc.omittedVars.contains uid then
continue
let st ← get
if ldecl.binderInfo.isInstImplicit && (← getFVars ldecl.type).all st.fvarSet.contains then
modify (·.add ldecl.fvarId)
Expand All @@ -1422,7 +1434,7 @@ def isExample (views : Array DefView) : Bool :=
views.any (·.kind.isExample)

open Language in
def Term.elabMutualDef (vars : Array Expr) (includedVars : List Name) (views : Array DefView) : TermElabM Unit :=
def elabMutualDef (vars : Array Expr) (sc : Command.Scope) (views : Array DefView) : TermElabM Unit :=
if isExample views then
withoutModifyingEnv do
-- save correct environment in info tree
Expand Down Expand Up @@ -1451,33 +1463,32 @@ where
addTermInfo' view.declId funFVar
let values ←
try
let values ← elabFunValues headers vars includedVars
let values ← elabFunValues headers vars sc
Term.synthesizeSyntheticMVarsNoPostponing
values.mapM (instantiateMVars ·)
values.mapM (instantiateMVarsProfiling ·)
catch ex =>
logException ex
headers.mapM fun header => mkSorry header.type (synthetic := true)
let headers ← headers.mapM instantiateMVarsAtHeader
let letRecsToLift ← getLetRecsToLift
let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
checkLetRecsToLiftTypes funFVars letRecsToLift
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars includedVars headers else withUsed vars headers values letRecsToLift) fun vars => do
(if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars sc headers else withUsed vars headers values letRecsToLift) fun vars => do
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
for preDef in preDefs do
trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamTypesPreDecls preDefs
let preDefs ← instantiateMVarsAtPreDecls preDefs
let preDefs ← shareCommonPreDefs preDefs
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
for preDef in preDefs do
trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
checkForHiddenUnivLevels allUserLevelNames preDefs
addPreDefinitions preDefs -- MODIFICATION 2: we use our custom function here
processDeriving headers
for view in views, header in headers do
-- NOTE: this should be the full `ref`, and thus needs to be done after any snapshotting
-- that depends only on a part of the ref
addDeclarationRanges header.declName view.ref
addDeclarationRangesForBuiltin header.declName view.modifiers.stx view.ref

processDeriving (headers : Array DefViewElabHeader) := do
for header in headers, view in views do
Expand All @@ -1488,7 +1499,7 @@ where
unless (← processDefDeriving className header.declName) do
throwError "failed to synthesize instance '{className}' for '{header.declName}'"

#check Command.elabMutualDef
end Term

-- Copy/pasted from Lean.Elab.MutualDef
open Command in
Expand All @@ -1502,7 +1513,7 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
let mut reusedAllHeaders := true
for h : i in [0:ds.size], headerPromise in headerPromises do
let d := ds[i]
let modifiers ← elabModifiers d[0]
let modifiers ← elabModifiers d[0]
if ds.size > 1 && modifiers.isNonrec then
throwErrorAt d "invalid use of 'nonrec' modifier in 'mutual' block"
let mut view ← mkDefView modifiers d[2] -- MODIFICATION: changed the index to 2
Expand Down Expand Up @@ -1533,8 +1544,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
if let some snap := snap? then
-- no non-fatal diagnostics at this point
snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
let includedVars := (← getScope).includedVars
runTermElabM fun vars => Term.elabMutualDef vars includedVars views
let sc ← getScope
runTermElabM fun vars => Term.elabMutualDef vars sc views

syntax (name := divergentDef)
declModifiers "divergent" Lean.Parser.Command.definition : command
Expand Down Expand Up @@ -1688,8 +1699,11 @@ namespace Tests

#check infinite_loop.unfold

-- Testing a degenerate case
-- Another degenerate case
def infinite_loop1_call : Result Unit := Result.ok ()
divergent def infinite_loop1 : Result Unit :=
do
infinite_loop1_call
infinite_loop1

#check infinite_loop1.unfold
Expand Down
8 changes: 2 additions & 6 deletions backends/lean/Base/Diverge/ElabBase.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ initialize divspecAttr : DivSpecAttr ← do
let (_, _, fExpr) ← lambdaMetaTelescope fExpr.consumeMData
trace[Diverge] "Registering divspec theorem for {fExpr}"
-- Convert the function expression to a discrimination tree key
-- We use the default configuration
let config : WhnfCoreConfig := {}
DiscrTree.mkPath fExpr config)
DiscrTree.mkPath fExpr)
let env := ext.addEntry env (fKey, thName)
setEnv env
trace[Diverge] "Saved the environment"
Expand All @@ -71,9 +69,7 @@ initialize divspecAttr : DivSpecAttr ← do
pure { attr := attrImpl, ext := ext }

def DivSpecAttr.find? (s : DivSpecAttr) (e : Expr) : MetaM (Array Name) := do
-- We use the default configuration
let config : WhnfCoreConfig := {}
(s.ext.getState (← getEnv)).getMatch e config
(s.ext.getState (← getEnv)).getMatch e

def DivSpecAttr.getState (s : DivSpecAttr) : MetaM (DiscrTree Name) := do
pure (s.ext.getState (← getEnv))
Expand Down
Loading

0 comments on commit 35058c1

Please sign in to comment.