Skip to content

Commit

Permalink
feat: sugar for SatisfiesM (#1029)
Browse files Browse the repository at this point in the history
Co-authored-by: Mario Carneiro <[email protected]>
  • Loading branch information
kim-em and digama0 authored Nov 13, 2024
1 parent e0d8449 commit 11da075
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 8 deletions.
3 changes: 3 additions & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import Batteries.Data.UnionFind
import Batteries.Data.Vector
import Batteries.Lean.AttributeExtra
import Batteries.Lean.Delaborator
import Batteries.Lean.EStateM
import Batteries.Lean.Except
import Batteries.Lean.Expr
import Batteries.Lean.Float
import Batteries.Lean.HashMap
import Batteries.Lean.HashSet
import Batteries.Lean.IO.Process
import Batteries.Lean.Json
import Batteries.Lean.LawfulMonad
import Batteries.Lean.Meta.Basic
import Batteries.Lean.Meta.DiscrTree
import Batteries.Lean.Meta.Expr
Expand All @@ -59,6 +61,7 @@ import Batteries.Lean.NameMapAttribute
import Batteries.Lean.PersistentHashMap
import Batteries.Lean.PersistentHashSet
import Batteries.Lean.Position
import Batteries.Lean.SatisfiesM
import Batteries.Lean.Syntax
import Batteries.Lean.System.IO
import Batteries.Lean.TagAttribute
Expand Down
130 changes: 124 additions & 6 deletions Batteries/Classes/SatisfiesM.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/-
Copyright (c) 2022 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
Authors: Mario Carneiro, Kim Morrison
-/
import Batteries.Lean.EStateM
import Batteries.Lean.Except
import Batteries.Tactic.Lint

/-!
## SatisfiesM
Expand All @@ -12,6 +15,13 @@ and enables Hoare-like reasoning over monadic expressions. For example, given a
function `f : α → m β`, to say that the return value of `f` satisfies `Q` whenever
the input satisfies `P`, we write `∀ a, P a → SatisfiesM Q (f a)`.
For any monad equipped with `MonadSatisfying m`
one can lift `SatisfiesM` to a monadic value in `Subtype`,
using `satisfying x h : m {a // p a}`, where `x : m α` and `h : SatisfiesM p x`.
This includes `Option`, `ReaderT`, `StateT`, and `ExceptT`, and the Lean monad stack.
(Although it is not entirely clear one should treat the Lean monad stack as lawful,
even though Lean accepts this.)
## Notes
`SatisfiesM` is not yet a satisfactory solution for verifying the behaviour of large scale monadic
Expand All @@ -23,7 +33,7 @@ presumably requiring more syntactic support (and smarter `do` blocks) from Lean.
Or it may be that such a solution will look different!
This is an open research program, and for now one should not be overly ambitious using `SatisfiesM`.
In particular lemmas about pure operations on data structures in `batteries` except for `HashMap`
In particular lemmas about pure operations on data structures in `Batteries` except for `HashMap`
should avoid `SatisfiesM` for now, so that it is easy to migrate to other approaches in future.
-/

Expand Down Expand Up @@ -158,25 +168,133 @@ end SatisfiesM
by revert x; intro | .ok _, ⟨.ok ⟨_, h⟩, rfl⟩, _, rfl => exact h,
fun h => match x with | .ok a => ⟨.ok ⟨a, h _ rfl⟩, rfl⟩ | .error e => ⟨.error e, rfl⟩⟩

theorem SatisfiesM_EStateM_eq :
SatisfiesM (m := EStateM ε σ) p x ↔ ∀ s a s', x.run s = .ok a s' → p a := by
constructor
· rintro ⟨x, rfl⟩ s a s' h
match w : x.run s with
| .ok a s' => simp at h; exact h.1
| .error e s' => simp [w] at h
· intro w
refine ⟨?_, ?_⟩
· intro s
match q : x.run s with
| .ok a s' => exact .ok ⟨a, w s a s' q⟩ s'
| .error e s' => exact .error e s'
· ext s
rw [EStateM.run_map, EStateM.run]
split <;> simp_all

@[simp] theorem SatisfiesM_ReaderT_eq [Monad m] :
SatisfiesM (m := ReaderT ρ m) p x ↔ ∀ s, SatisfiesM p (x s) :=
SatisfiesM (m := ReaderT ρ m) p x ↔ ∀ s, SatisfiesM p (x.run s) :=
(exists_congr fun a => by exact ⟨fun eq _ => eq ▸ rfl, funext⟩).trans Classical.skolem.symm

theorem SatisfiesM_StateRefT_eq [Monad m] :
SatisfiesM (m := StateRefT' ω σ m) p x ↔ ∀ s, SatisfiesM p (x s) := by simp
SatisfiesM (m := StateRefT' ω σ m) p x ↔ ∀ s, SatisfiesM p (x s) := by simp [ReaderT.run]

@[simp] theorem SatisfiesM_StateT_eq [Monad m] [LawfulMonad m] :
SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x s) := by
SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x.run s) := by
change SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x s)
refine .trans ⟨fun ⟨f, eq⟩ => eq ▸ ?_, fun ⟨f, h⟩ => ?_⟩ Classical.skolem.symm
· refine ⟨fun s => (fun ⟨⟨a, h⟩, s'⟩ => ⟨⟨a, s'⟩, h⟩) <$> f s, fun s => ?_⟩
rw [← comp_map, map_eq_pure_bind]; rfl
· refine ⟨fun s => (fun ⟨⟨a, s'⟩, h⟩ => ⟨⟨a, h⟩, s'⟩) <$> f s, funext fun s => ?_⟩
show _ >>= _ = _; simp [← h]

@[simp] theorem SatisfiesM_ExceptT_eq [Monad m] [LawfulMonad m] :
SatisfiesM (m := ExceptT ρ m) (α := α) p x ↔ SatisfiesM (m := m) (∀ a, · = .ok a → p a) x := by
SatisfiesM (m := ExceptT ρ m) (α := α) p x ↔
SatisfiesM (m := m) (∀ a, · = .ok a → p a) x.run := by
change _ ↔ SatisfiesM (m := m) (∀ a, · = .ok a → p a) x
refine ⟨fun ⟨f, eq⟩ => eq ▸ ?_, fun ⟨f, eq⟩ => eq ▸ ?_⟩
· exists (fun | .ok ⟨a, h⟩ => ⟨.ok a, fun | _, rfl => h⟩ | .error e => ⟨.error e, nofun⟩) <$> f
show _ = _ >>= _; rw [← comp_map, map_eq_pure_bind]; congr; funext a; cases a <;> rfl
· exists ((fun | ⟨.ok a, h⟩ => .ok ⟨a, h _ rfl⟩ | ⟨.error e, _⟩ => .error e) <$> f : m _)
show _ >>= _ = _; simp [← comp_map, ← bind_pure_comp]; congr; funext ⟨a, h⟩; cases a <;> rfl

/--
If a monad has `MonadSatisfying m`, then we can lift a `h : SatisfiesM (m := m) p x` predicate
to monadic value `satisfying x p : m { x // p x }`.
Reader, state, and exception monads have `MonadSatisfying` instances if the base monad does.
-/
class MonadSatisfying (m : Type u → Type v) [Functor m] [LawfulFunctor m] where
/-- Lift a `SatisfiesM` predicate to a monadic value. -/
satisfying {p : α → Prop} {x : m α} (h : SatisfiesM (m := m) p x) : m {a // p a}
/-- The value of the lifted monadic value is equal to the original monadic value. -/
val_eq {p : α → Prop} {x : m α} (h : SatisfiesM (m := m) p x) : Subtype.val <$> satisfying h = x

export MonadSatisfying (satisfying)

namespace MonadSatisfying

instance : MonadSatisfying Id where
satisfying {α p x} h := ⟨x, by obtain ⟨⟨_, h⟩, rfl⟩ := h; exact h⟩
val_eq {α p x} h := rfl

instance : MonadSatisfying Option where
satisfying {α p x?} h :=
have h' := SatisfiesM_Option_eq.mp h
match x? with
| none => none
| some x => some ⟨x, h' x rfl⟩
val_eq {α p x?} h := by cases x? <;> simp

instance : MonadSatisfying (Except ε) where
satisfying {α p x?} h :=
have h' := SatisfiesM_Except_eq.mp h
match x? with
| .ok x => .ok ⟨x, h' x rfl⟩
| .error e => .error e
val_eq {α p x?} h := by cases x? <;> simp

-- This will be redundant after nightly-2024-11-08.
attribute [ext] ReaderT.ext

instance [Monad m] [LawfulMonad m][MonadSatisfying m] : MonadSatisfying (ReaderT ρ m) where
satisfying {α p x} h :=
have h' := SatisfiesM_ReaderT_eq.mp h
fun r => satisfying (h' r)
val_eq {α p x} h := by
have h' := SatisfiesM_ReaderT_eq.mp h
ext r
rw [ReaderT.run_map, ← MonadSatisfying.val_eq (h' r)]
rfl

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (StateRefT' ω σ m) :=
inferInstanceAs <| MonadSatisfying (ReaderT _ _)

-- This will be redundant after nightly-2024-11-08.
attribute [ext] StateT.ext

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (StateT ρ m) where
satisfying {α p x} h :=
have h' := SatisfiesM_StateT_eq.mp h
fun r => (fun ⟨⟨a, r'⟩, h⟩ => ⟨⟨a, h⟩, r'⟩) <$> satisfying (h' r)
val_eq {α p x} h := by
have h' := SatisfiesM_StateT_eq.mp h
ext r
rw [← MonadSatisfying.val_eq (h' r), StateT.run_map]
simp [StateT.run]

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (ExceptT ε m) where
satisfying {α p x} h :=
let x' := satisfying (SatisfiesM_ExceptT_eq.mp h)
ExceptT.mk ((fun ⟨y, w⟩ => y.pmap fun a h => ⟨a, w _ h⟩) <$> x')
val_eq {α p x} h:= by
ext
rw [← MonadSatisfying.val_eq (SatisfiesM_ExceptT_eq.mp h)]
simp

instance : MonadSatisfying (EStateM ε σ) where
satisfying {α p x} h :=
have h' := SatisfiesM_EStateM_eq.mp h
fun s => match w : x.run s with
| .ok a s' => .ok ⟨a, h' s a s' w⟩ s'
| .error e s' => .error e s'
val_eq {α p x} h := by
ext s
rw [EStateM.run_map, EStateM.run]
simp only
split <;> simp_all

end MonadSatisfying
2 changes: 1 addition & 1 deletion Batteries/Data/HashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` the
have : m'.1.size > 0 := by
have := Array.size_mapM (m := StateT (ULift Nat) Id) (go .nil) m.buckets.1
simp [SatisfiesM_StateT_eq, SatisfiesM_Id_eq] at this
simp [this, Id.run, StateT.run, m.2.2, m']
simp [this, Id.run, m.2.2, m']
⟨m'.2.1, m'.1, this⟩
where
/-- Inner loop of `filterMap`. Note that this reverses the bucket lists,
Expand Down
40 changes: 40 additions & 0 deletions Batteries/Lean/EStateM.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/

namespace EStateM

namespace Result

/-- Map a function over an `EStateM.Result`, preserving states and errors. -/
def map {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) : Result ε σ β :=
match x with
| .ok a s' => .ok (f a) s'
| .error e s' => .error e s'

@[simp] theorem map_ok {ε σ α β : Type u} (f : α → β) (a : α) (s : σ) :
(Result.ok a s : Result ε σ α).map f = .ok (f a) s := rfl

@[simp] theorem map_error {ε σ α β : Type u} (f : α → β) (e : ε) (s : σ) :
(Result.error e s : Result ε σ α).map f = .error e s := rfl

@[simp] theorem map_eq_ok {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) (b : β) (s : σ) :
x.map f = .ok b s ↔ ∃ a, x = .ok a s ∧ b = f a := by
cases x <;> simp [and_assoc, and_comm, eq_comm]

@[simp] theorem map_eq_error {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) (e : ε) (s : σ) :
x.map f = .error e s ↔ x = .error e s := by
cases x <;> simp [eq_comm]

end Result

@[simp] theorem run_map (f : α → β) (x : EStateM ε σ α) :
(f <$> x).run s = (x.run s).map f := rfl

@[ext] theorem ext {ε σ α : Type u} (x y : EStateM ε σ α) (h : ∀ s, x.run s = y.run s) : x = y := by
funext s
exact h s

end EStateM
51 changes: 50 additions & 1 deletion Batteries/Lean/Except.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,56 @@ import Lean.Util.Trace

open Lean

namespace Except

/-- Visualize an `Except` using a checkmark or a cross. -/
def Except.emoji : Except ε α → String
def emoji : Except ε α → String
| .error _ => crossEmoji
| .ok _ => checkEmoji

@[simp] theorem map_error {ε : Type u} (f : α → β) (e : ε) :
f <$> (.error e : Except ε α) = .error e := rfl

@[simp] theorem map_ok {ε : Type u} (f : α → β) (x : α) :
f <$> (.ok x : Except ε α) = .ok (f x) := rfl

/-- Map a function over an `Except` value, using a proof that the value is `.ok`. -/
def pmap {ε : Type u} {α β : Type v} (x : Except ε α) (f : (a : α) → x = .ok a → β) : Except ε β :=
match x with
| .error e => .error e
| .ok a => .ok (f a rfl)

@[simp] theorem pmap_error {ε : Type u} {α β : Type v} (e : ε)
(f : (a : α) → Except.error e = Except.ok a → β) :
Except.pmap (.error e) f = .error e := rfl

@[simp] theorem pmap_ok {ε : Type u} {α β : Type v} (a : α)
(f : (a' : α) → (.ok a : Except ε α) = .ok a' → β) :
Except.pmap (.ok a) f = .ok (f a rfl) := rfl

@[simp] theorem pmap_id {ε : Type u} {α : Type v} (x : Except ε α) :
x.pmap (fun a _ => a) = x := by cases x <;> simp

@[simp] theorem map_pmap (g : β → γ) (x : Except ε α) (f : (a : α) → x = .ok a → β) :
g <$> x.pmap f = x.pmap fun a h => g (f a h) := by
cases x <;> simp

end Except

namespace ExceptT

-- This will be redundant after nightly-2024-11-08.
attribute [ext] ExceptT.ext

@[simp] theorem run_mk {m : Type u → Type v} (x : m (Except ε α)) : (ExceptT.mk x).run = x := rfl
@[simp] theorem mk_run (x : ExceptT ε m α) : ExceptT.mk (ExceptT.run x) = x := rfl

@[simp]
theorem map_mk [Monad m] [LawfulMonad m] (f : α → β) (x : m (Except ε α)) :
f <$> ExceptT.mk x = ExceptT.mk ((f <$> · ) <$> x) := by
simp only [Functor.map, Except.map, ExceptT.map, map_eq_pure_bind]
congr
funext a
split <;> simp

end ExceptT
30 changes: 30 additions & 0 deletions Batteries/Lean/LawfulMonad.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
import Batteries.Classes.SatisfiesM
import Lean.Elab.Command

/-!
# Construct `LawfulMonad` instances for the Lean monad stack.
-/

open Lean Elab Term Tactic Command

instance : LawfulMonad (EIO ε) := inferInstanceAs <| LawfulMonad (EStateM _ _)
instance : LawfulMonad BaseIO := inferInstanceAs <| LawfulMonad (EIO _)
instance : LawfulMonad IO := inferInstanceAs <| LawfulMonad (EIO _)

instance : LawfulMonad (EST ε σ) := inferInstanceAs <| LawfulMonad (EStateM _ _)

instance : LawfulMonad CoreM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ (EIO Exception))
instance : LawfulMonad MetaM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ CoreM)
instance : LawfulMonad TermElabM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ MetaM)
instance : LawfulMonad TacticM :=
inferInstanceAs <| LawfulMonad (ReaderT _ $ StateRefT' _ _ $ TermElabM)
instance : LawfulMonad CommandElabM :=
inferInstanceAs <| LawfulMonad (ReaderT _ $ StateRefT' _ _ $ EIO _)
35 changes: 35 additions & 0 deletions Batteries/Lean/SatisfiesM.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
import Batteries.Classes.SatisfiesM
import Batteries.Lean.LawfulMonad
import Lean.Elab.Command

/-!
# Construct `MonadSatisfying` instances for the Lean monad stack.
-/

open Lean Elab Term Tactic Command

instance : MonadSatisfying (EIO ε) := inferInstanceAs <| MonadSatisfying (EStateM _ _)
instance : MonadSatisfying BaseIO := inferInstanceAs <| MonadSatisfying (EIO _)
instance : MonadSatisfying IO := inferInstanceAs <| MonadSatisfying (EIO _)

instance : MonadSatisfying (EST ε σ) := inferInstanceAs <| MonadSatisfying (EStateM _ _)

instance : MonadSatisfying CoreM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ (EIO _))

instance : MonadSatisfying MetaM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ CoreM)

instance : MonadSatisfying TermElabM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ MetaM)

instance : MonadSatisfying TacticM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ $ StateRefT' _ _ TermElabM)

instance : MonadSatisfying CommandElabM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ $ StateRefT' _ _ (EIO _))
8 changes: 8 additions & 0 deletions BatteriesTest/satisfying.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import Batteries.Lean.SatisfiesM
import Batteries.Data.Array.Monadic

open Lean Meta Array Elab Term Tactic Command

example (xs : Array Expr) : MetaM { ts : Array Expr // ts.size = xs.size } := do
let r ← satisfying (xs.size_mapM inferType)
return r

0 comments on commit 11da075

Please sign in to comment.