Skip to content

Commit

Permalink
Cleanup a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 11, 2023
1 parent c23a376 commit ee669c4
Showing 1 changed file with 47 additions and 16 deletions.
63 changes: 47 additions & 16 deletions backends/lean/Base/Diverge/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ open Utils

open WF in

def mkProdType (x y : Expr) : MetaM Expr :=
mkAppM ``Prod #[x, y]

def mkProd (x y : Expr) : MetaM Expr :=
mkAppM ``Prod.mk #[x, y]

Expand All @@ -47,6 +50,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do
else
pure (args.get! 0, args.get! 1)

/- Make a sigma type.
`x` should be a variable, and `ty` and type which (might) uses `x`
-/
def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do
trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}"
let alpha ← inferType x
let beta ← mkLambdaFVars #[x] sty
trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})"
mkAppOptM ``Sigma #[some alpha, some beta]

/- Generate a Sigma type from a list of *variables* (all the expressions
must be variables).
Expand All @@ -64,16 +78,12 @@ def mkSigmasType (xl : List Expr) : MetaM Expr :=
pure (Expr.const ``PUnit [Level.succ .zero])
| [x] => do
trace[Diverge.def.sigmas] "mkSigmasType: [{x}]"
let ty ← Lean.Meta.inferType x
let ty ← inferType x
pure ty
| x :: xl => do
trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]"
let alpha ← Lean.Meta.inferType x
let sty ← mkSigmasType xl
trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}"
let beta ← mkLambdaFVars #[x] sty
trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})"
mkAppOptM ``Sigma #[some alpha, some beta]
mkSigmaType x sty

/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`).
Expand All @@ -90,11 +100,11 @@ def mkProdsType (xl : List Expr) : MetaM Expr :=
pure (Expr.const ``PUnit [Level.succ .zero])
| [x] => do
trace[Diverge.def.prods] "mkProdsType: [{x}]"
let ty ← Lean.Meta.inferType x
let ty ← inferType x
pure ty
| x :: xl => do
trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]"
let ty ← Lean.Meta.inferType x
let ty ← inferType x
let xl_ty ← mkProdsType xl
mkAppM ``Prod #[ty, xl_ty]

Expand All @@ -114,7 +124,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×
let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) :=
match in_tys with
| [] => do
let fvars ← getFVarIds (← Lean.Meta.inferType out_ty)
let fvars ← getFVarIds (← inferType out_ty)
pure (fvars, [], [])
| ty :: in_tys => do
let (fvars, in_tys, in_args) ← splitAux in_tys
Expand All @@ -132,7 +142,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×
pure (fvars, [ty], in_args)
else
-- We must split later: update the fvars set
let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty))
let fvars := fvars.insertMany (← getFVarIds (← inferType ty))
pure (fvars, [], ty :: in_args)
let (_, in_tys, in_args) ← splitAux in_tys.data
pure (Array.mk in_tys, Array.mk in_args)
Expand Down Expand Up @@ -250,13 +260,13 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
| Sigma.mk x ... -- the hole is given by a recursive call on the tail
``` -/
trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
let alpha ← inferType fst
let snd_ty ← mkSigmasType xl
let beta ← mkLambdaFVars #[fst] snd_ty
let snd ← mkSigmasMatch xl out (index + 1)
let mk ← mkLambdaFVars #[fst] snd
-- Introduce the "scrut" variable
let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty
let scrut_ty ← mkSigmaType fst snd_ty
withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
Expand Down Expand Up @@ -294,12 +304,12 @@ partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Meta
mkLambdaFVars #[x] out
| fst :: xl => do
trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
let alpha ← inferType fst
let beta ← mkProdsType xl
let snd ← mkProdsMatch xl out (index + 1)
let mk ← mkLambdaFVars #[fst] snd
-- Introduce the "scrut" variable
let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta
let scrut_ty ← mkProdType alpha beta
withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
Expand Down Expand Up @@ -1265,8 +1275,29 @@ elab_rules : command

namespace Tests
/- Some examples of partial functions -/

divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
/- section HigherOrder
open FixI
-- The index type
variable {id : Type u}
-- The input/output types
variable {a : id → Type v} {b : (i:id) → a i → Type w}
-- Example with a higher-order function
theorem map_is_valid
{{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
(Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x))
(k : (a → Result b) → a → Result b)
(ls : List a) :
is_valid_p k (λ k => Ex5.map (f k) ls) :=
induction ls <;> simp [map]
apply is_valid_p_bind <;> try simp_all
intros
apply is_valid_p_bind <;> try simp_all
end HigherOrder -/

divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
| [] => .fail .panic
| x :: ls =>
Expand Down

0 comments on commit ee669c4

Please sign in to comment.