Skip to content

Commit

Permalink
Improve the micro passes to eliminate pattern let f := fun x => g x
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 22, 2023
1 parent aa5e257 commit b6ef8ee
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 54 deletions.
45 changes: 43 additions & 2 deletions compiler/PureMicroPasses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,15 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let y1 = x1 in
...
]}
Simplify arrows:
{[
let f := fun x => g x in
...
~~>
let f := g in
...
]}
*)
let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
let obj =
Expand Down Expand Up @@ -739,6 +748,23 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
super#visit_expression env e.e
| _ -> super#visit_Let env monadic lv rv next
else super#visit_Let env monadic lv rv next
| Lambda _ ->
if not monadic then
(* Arrow case *)
let pats, e = destruct_lambdas rv in
let g, args = destruct_apps e in
if List.length pats = List.length args then
(* Check if the arguments are exactly the lambdas *)
let check_pat_arg ((pat, arg) : typed_pattern * texpression) =
match (pat.value, arg.e) with
| PatVar (v, _), Var vid -> v.id = vid
| _ -> false
in
if List.for_all check_pat_arg (List.combine pats args) then
self#visit_Let env monadic lv g next
else super#visit_Let env monadic lv rv next
else super#visit_Let env monadic lv rv next
else super#visit_Let env monadic lv rv next
| _ -> super#visit_Let env monadic lv rv next
end
in
Expand Down Expand Up @@ -1934,9 +1960,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Inline the useless variable reassignments *)
let inline_named_vars = true in
let inline_pure = true in
let def =
inline_useless_var_reassignments ctx inline_named_vars inline_pure def
let inline_useless_var_reassignments ctx =
inline_useless_var_reassignments ctx inline_named_vars inline_pure
in
let def = inline_useless_var_reassignments ctx def in
log#ldebug
(lazy
("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
Expand Down Expand Up @@ -1982,6 +2009,20 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
log#ldebug
(lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));

(* Simplify the let-bindings - some simplifications may have been unlocked by
the pass above (for instance, the lambda simplification) *)
let def = simplify_let_bindings ctx def in
log#ldebug
(lazy
("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n"));

(* Inline the useless vars again *)
let def = inline_useless_var_reassignments ctx def in
log#ldebug
(lazy
("inline_useless_var_assignments (pass 2):\n\n"
^ fun_decl_to_string ctx def ^ "\n"));

(* Decompose the monadic let-bindings - used by Coq *)
let def =
if !Config.decompose_monadic_let_bindings then (
Expand Down
26 changes: 8 additions & 18 deletions tests/coq/misc/Loops.v
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ Definition list_nth_mut_loop
(T : Type) (n : nat) (ls : List_t T) (i : u32) :
result (T * (T -> result (List_t T)))
:=
p <- list_nth_mut_loop_loop T n ls i;
let (t, back) := p in
let back1 := fun (ret : T) => back ret in
Return (t, back1)
p <- list_nth_mut_loop_loop T n ls i; let (t, back) := p in Return (t, back)
.

(** [loops::list_nth_shared_loop]: loop 0:
Expand Down Expand Up @@ -265,7 +262,7 @@ Definition id_mut
(T : Type) (ls : List_t T) :
result ((List_t T) * (List_t T -> result (List_t T)))
:=
let back := fun (ret : List_t T) => Return ret in Return (ls, back)
Return (ls, Return)
.

(** [loops::id_shared]:
Expand Down Expand Up @@ -382,9 +379,7 @@ Definition list_nth_mut_loop_pair
:=
t <- list_nth_mut_loop_pair_loop T n ls0 ls1 i;
let (p, back_'a, back_'b) := t in
let back_'a1 := fun (ret : T) => back_'a ret in
let back_'b1 := fun (ret : T) => back_'b ret in
Return (p, back_'a1, back_'b1)
Return (p, back_'a, back_'b)
.

(** [loops::list_nth_shared_loop_pair]: loop 0:
Expand Down Expand Up @@ -465,8 +460,7 @@ Definition list_nth_mut_loop_pair_merge
:=
p <- list_nth_mut_loop_pair_merge_loop T n ls0 ls1 i;
let (p1, back_'a) := p in
let back_'a1 := fun (ret : (T * T)) => back_'a ret in
Return (p1, back_'a1)
Return (p1, back_'a)
.

(** [loops::list_nth_shared_loop_pair_merge]: loop 0:
Expand Down Expand Up @@ -542,8 +536,7 @@ Definition list_nth_mut_shared_loop_pair
:=
p <- list_nth_mut_shared_loop_pair_loop T n ls0 ls1 i;
let (p1, back_'a) := p in
let back_'a1 := fun (ret : T) => back_'a ret in
Return (p1, back_'a1)
Return (p1, back_'a)
.

(** [loops::list_nth_mut_shared_loop_pair_merge]: loop 0:
Expand Down Expand Up @@ -585,8 +578,7 @@ Definition list_nth_mut_shared_loop_pair_merge
:=
p <- list_nth_mut_shared_loop_pair_merge_loop T n ls0 ls1 i;
let (p1, back_'a) := p in
let back_'a1 := fun (ret : T) => back_'a ret in
Return (p1, back_'a1)
Return (p1, back_'a)
.

(** [loops::list_nth_shared_mut_loop_pair]: loop 0:
Expand Down Expand Up @@ -628,8 +620,7 @@ Definition list_nth_shared_mut_loop_pair
:=
p <- list_nth_shared_mut_loop_pair_loop T n ls0 ls1 i;
let (p1, back_'b) := p in
let back_'b1 := fun (ret : T) => back_'b ret in
Return (p1, back_'b1)
Return (p1, back_'b)
.

(** [loops::list_nth_shared_mut_loop_pair_merge]: loop 0:
Expand Down Expand Up @@ -671,8 +662,7 @@ Definition list_nth_shared_mut_loop_pair_merge
:=
p <- list_nth_shared_mut_loop_pair_merge_loop T n ls0 ls1 i;
let (p1, back_'a) := p in
let back_'a1 := fun (ret : T) => back_'a ret in
Return (p1, back_'a1)
Return (p1, back_'a)
.

End Loops.
25 changes: 8 additions & 17 deletions tests/fstar/misc/Loops.Funs.fst
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ let list_nth_mut_loop
(t : Type0) (ls : list_t t) (i : u32) :
result (t & (t -> result (list_t t)))
=
let* (x, back) = list_nth_mut_loop_loop t ls i in
let back1 = fun ret -> back ret in
Return (x, back1)
let* (x, back) = list_nth_mut_loop_loop t ls i in Return (x, back)

(** [loops::list_nth_shared_loop]: loop 0:
Source: 'src/loops.rs', lines 91:0-101:1 *)
Expand Down Expand Up @@ -201,7 +199,7 @@ let id_mut
(t : Type0) (ls : list_t t) :
result ((list_t t) & (list_t t -> result (list_t t)))
=
let back = fun ret -> Return ret in Return (ls, back)
Return (ls, Return)

(** [loops::id_shared]:
Source: 'src/loops.rs', lines 139:0-139:45 *)
Expand Down Expand Up @@ -296,9 +294,7 @@ let list_nth_mut_loop_pair
result ((t & t) & (t -> result (list_t t)) & (t -> result (list_t t)))
=
let* (p, back_'a, back_'b) = list_nth_mut_loop_pair_loop t ls0 ls1 i in
let back_'a1 = fun ret -> back_'a ret in
let back_'b1 = fun ret -> back_'b ret in
Return (p, back_'a1, back_'b1)
Return (p, back_'a, back_'b)

(** [loops::list_nth_shared_loop_pair]: loop 0:
Source: 'src/loops.rs', lines 198:0-219:1 *)
Expand Down Expand Up @@ -362,8 +358,7 @@ let list_nth_mut_loop_pair_merge
result ((t & t) & ((t & t) -> result ((list_t t) & (list_t t))))
=
let* (p, back_'a) = list_nth_mut_loop_pair_merge_loop t ls0 ls1 i in
let back_'a1 = fun ret -> back_'a ret in
Return (p, back_'a1)
Return (p, back_'a)

(** [loops::list_nth_shared_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 241:0-256:1 *)
Expand Down Expand Up @@ -425,8 +420,7 @@ let list_nth_mut_shared_loop_pair
result ((t & t) & (t -> result (list_t t)))
=
let* (p, back_'a) = list_nth_mut_shared_loop_pair_loop t ls0 ls1 i in
let back_'a1 = fun ret -> back_'a ret in
Return (p, back_'a1)
Return (p, back_'a)

(** [loops::list_nth_mut_shared_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 278:0-293:1 *)
Expand Down Expand Up @@ -462,8 +456,7 @@ let list_nth_mut_shared_loop_pair_merge
result ((t & t) & (t -> result (list_t t)))
=
let* (p, back_'a) = list_nth_mut_shared_loop_pair_merge_loop t ls0 ls1 i in
let back_'a1 = fun ret -> back_'a ret in
Return (p, back_'a1)
Return (p, back_'a)

(** [loops::list_nth_shared_mut_loop_pair]: loop 0:
Source: 'src/loops.rs', lines 297:0-312:1 *)
Expand Down Expand Up @@ -498,8 +491,7 @@ let list_nth_shared_mut_loop_pair
result ((t & t) & (t -> result (list_t t)))
=
let* (p, back_'b) = list_nth_shared_mut_loop_pair_loop t ls0 ls1 i in
let back_'b1 = fun ret -> back_'b ret in
Return (p, back_'b1)
Return (p, back_'b)

(** [loops::list_nth_shared_mut_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 316:0-331:1 *)
Expand Down Expand Up @@ -535,6 +527,5 @@ let list_nth_shared_mut_loop_pair_merge
result ((t & t) & (t -> result (list_t t)))
=
let* (p, back_'a) = list_nth_shared_mut_loop_pair_merge_loop t ls0 ls1 i in
let back_'a1 = fun ret -> back_'a ret in
Return (p, back_'a1)
Return (p, back_'a)

25 changes: 8 additions & 17 deletions tests/lean/Loops.lean
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ def list_nth_mut_loop
(T : Type) (ls : List T) (i : U32) : Result (T × (T → Result (List T))) :=
do
let (t, back) ← list_nth_mut_loop_loop T ls i
let back1 := fun ret => back ret
Result.ret (t, back1)
Result.ret (t, back)

/- [loops::list_nth_shared_loop]: loop 0:
Source: 'src/loops.rs', lines 91:0-101:1 -/
Expand Down Expand Up @@ -207,8 +206,7 @@ def id_mut
(T : Type) (ls : List T) :
Result ((List T) × (List T → Result (List T)))
:=
let back := fun ret => Result.ret ret
Result.ret (ls, back)
Result.ret (ls, Result.ret)

/- [loops::id_shared]:
Source: 'src/loops.rs', lines 139:0-139:45 -/
Expand Down Expand Up @@ -308,9 +306,7 @@ def list_nth_mut_loop_pair
:=
do
let (p, back_'a, back_'b) ← list_nth_mut_loop_pair_loop T ls0 ls1 i
let back_'a1 := fun ret => back_'a ret
let back_'b1 := fun ret => back_'b ret
Result.ret (p, back_'a1, back_'b1)
Result.ret (p, back_'a, back_'b)

/- [loops::list_nth_shared_loop_pair]: loop 0:
Source: 'src/loops.rs', lines 198:0-219:1 -/
Expand Down Expand Up @@ -372,8 +368,7 @@ def list_nth_mut_loop_pair_merge
:=
do
let (p, back_'a) ← list_nth_mut_loop_pair_merge_loop T ls0 ls1 i
let back_'a1 := fun ret => back_'a ret
Result.ret (p, back_'a1)
Result.ret (p, back_'a)

/- [loops::list_nth_shared_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 241:0-256:1 -/
Expand Down Expand Up @@ -432,8 +427,7 @@ def list_nth_mut_shared_loop_pair
:=
do
let (p, back_'a) ← list_nth_mut_shared_loop_pair_loop T ls0 ls1 i
let back_'a1 := fun ret => back_'a ret
Result.ret (p, back_'a1)
Result.ret (p, back_'a)

/- [loops::list_nth_mut_shared_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 278:0-293:1 -/
Expand Down Expand Up @@ -470,8 +464,7 @@ def list_nth_mut_shared_loop_pair_merge
:=
do
let (p, back_'a) ← list_nth_mut_shared_loop_pair_merge_loop T ls0 ls1 i
let back_'a1 := fun ret => back_'a ret
Result.ret (p, back_'a1)
Result.ret (p, back_'a)

/- [loops::list_nth_shared_mut_loop_pair]: loop 0:
Source: 'src/loops.rs', lines 297:0-312:1 -/
Expand Down Expand Up @@ -507,8 +500,7 @@ def list_nth_shared_mut_loop_pair
:=
do
let (p, back_'b) ← list_nth_shared_mut_loop_pair_loop T ls0 ls1 i
let back_'b1 := fun ret => back_'b ret
Result.ret (p, back_'b1)
Result.ret (p, back_'b)

/- [loops::list_nth_shared_mut_loop_pair_merge]: loop 0:
Source: 'src/loops.rs', lines 316:0-331:1 -/
Expand Down Expand Up @@ -545,7 +537,6 @@ def list_nth_shared_mut_loop_pair_merge
:=
do
let (p, back_'a) ← list_nth_shared_mut_loop_pair_merge_loop T ls0 ls1 i
let back_'a1 := fun ret => back_'a ret
Result.ret (p, back_'a1)
Result.ret (p, back_'a)

end loops

0 comments on commit b6ef8ee

Please sign in to comment.