Skip to content

Commit

Permalink
Filter the useless backward functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 21, 2023
1 parent cf3eea5 commit d4b3d0e
Showing 1 changed file with 145 additions and 75 deletions.
220 changes: 145 additions & 75 deletions compiler/SymbolicToPure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
back_funs : texpression RegionGroupId.Map.t option;
back_funs : texpression option RegionGroupId.Map.t option;
(** If we do not split between the forward/backward functions: the
variables we introduced for the backward functions.
Expand All @@ -78,6 +78,10 @@ type call_info = {
here
...
]}
The expression might be [None] in case the backward function
has to be filtered (because it does nothing - the backward
functions for shared borrows for instance).
*)
}
[@@deriving show]
Expand Down Expand Up @@ -125,7 +129,7 @@ type loop_info = {
(** The map from region group ids to the types of the values given back
by the corresponding loop abstractions.
*)
back_funs : texpression RegionGroupId.Map.t option;
back_funs : texpression option RegionGroupId.Map.t option;
(** Same as {!call_info.back_funs}.
Initialized with [None], gets updated to [Some] only if we merge
the fwd/back functions.
Expand Down Expand Up @@ -777,8 +781,8 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)

let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
(args : texpression list)
(back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx
=
(back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) :
bs_ctx =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
let info = { forward; forward_inputs = args; back_funs } in
Expand All @@ -790,13 +794,15 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
[back_args]: the *additional* list of inputs received by the backward function,
including the state.
Returns the updated context and the expression corresponding to the function.
Returns the updated context and the expression corresponding to the function
that we need to call. This function may be [None] if it has to be ignored
(because it does nothing).
*)
let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
(call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
(inherited_args : texpression list) (back_args : texpression list)
(generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
bs_ctx * texpression =
bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
Expand Down Expand Up @@ -827,7 +833,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
in
let func_ty = mk_arrows input_tys ret_ty in
let func = { id = FunOrOp fun_id; generics } in
{ e = Qualif func; ty = func_ty }
Some { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
({ ctx with calls; abstractions }, func)
Expand Down Expand Up @@ -1128,23 +1134,36 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output

(** Compute the arrow types for all the backward functions. *)
(** Compute the arrow types for all the backward functions.
If a backward function has no inputs/outputs we filter it.
*)
let compute_back_tys (dsg : Pure.decomposed_fun_sig)
(subst : (generic_args * trait_instance_id) option) : ty list =
(subst : (generic_args * trait_instance_id) option) : ty option list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
(* Compute *)
(* Compute the input/output types *)
let inputs = List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty_from_effect_info effect_info output in
let ty = mk_arrows inputs output in
(* Substitute - TODO: normalize *)
match subst with
| None -> ty
| Some (generics, tr_self) ->
let subst = make_subst_from_generics dsg.generics generics tr_self in
ty_substitute subst ty)
let outputs = back_sg.outputs in
(* Filter if necessary *)
if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
None
else
let output = mk_simpl_tuple_ty outputs in
let output = mk_output_ty_from_effect_info effect_info output in
let ty = mk_arrows inputs output in
(* Substitute - TODO: normalize *)
let ty =
match subst with
| None -> ty
| Some (generics, tr_self) ->
let subst =
make_subst_from_generics dsg.generics generics tr_self
in
ty_substitute subst ty
in
Some ty)
(RegionGroupId.Map.values dsg.back_sg)

let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
Expand All @@ -1169,7 +1188,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
let back_tys = compute_back_tys dsg None in
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
Expand Down Expand Up @@ -1259,8 +1278,19 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars

let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) :
bs_ctx * var option list =
List.fold_left_map
(fun ctx var ->
match var with
| None -> (ctx, None)
| Some (name, ty) ->
let ctx, var = fresh_var name ty ctx in
(ctx, Some var))
ctx vars

(* Introduce variables for the backward functions *)
let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
(* We lookup the LLBC definition in an attempt to derive pretty names
for the backward functions. *)
let back_var_names =
Expand Down Expand Up @@ -1291,7 +1321,13 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
fresh_vars back_vars ctx
let back_vars =
List.map
(fun (name, ty) ->
match ty with None -> None | Some ty -> Some (name, ty))
back_vars
in
fresh_opt_vars back_vars ctx

let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
Expand Down Expand Up @@ -1748,6 +1784,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
| None ->
if !Config.return_back_funs then
let back_tys = compute_back_tys ctx.sg None in
let back_tys = List.filter_map (fun x -> x) back_tys in
let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
mk_output output
else mk_output ctx.sg.fwd_output
Expand Down Expand Up @@ -1933,21 +1970,33 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
name ^ "_back"
in
let ctx, back_vars =
fresh_vars
(List.map (fun ty -> (Some back_fun_name, ty)) back_tys)
fresh_opt_vars
(List.map
(fun ty ->
match ty with
| None -> None
| Some ty -> Some (Some back_fun_name, ty))
back_tys)
ctx
in
let back_funs =
List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
List.filter_map
(fun v ->
match v with
| None -> None
| Some v -> Some (mk_typed_pattern_from_var v None))
back_vars
in
let gids =
List.map
(fun (g : T.region_var_group) -> g.id)
call.regions_hierarchy
in
let back_vars =
List.map (Option.map mk_texpression_from_var) back_vars
in
let back_funs_map =
RegionGroupId.Map.of_list
(List.combine gids (List.map mk_texpression_from_var back_vars))
RegionGroupId.Map.of_list (List.combine gids back_vars)
in
(ctx, Some back_funs_map, back_funs)
else (ctx, None, [])
Expand Down Expand Up @@ -2220,15 +2269,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
log#ldebug
(lazy
(let args = List.map (texpression_to_string ctx) args in
"func: "
^ texpression_to_string ctx func
^ "\nfunc type: "
^ pure_ty_to_string ctx func.ty
^ "\n\nargs:\n" ^ String.concat "\n" args));
let call = mk_apps func args in
(* **Optimization**:
=================
We do a small optimization here if we split the forward/backward functions.
Expand All @@ -2252,7 +2292,22 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e
else
(* The backward function might also have been filtered if we do not
split the forward/backward functions *)
match func with
| None -> next_e
| Some func ->
log#ldebug
(lazy
(let args = List.map (texpression_to_string ctx) args in
"func: "
^ texpression_to_string ctx func
^ "\nfunc type: "
^ pure_ty_to_string ctx func.ty
^ "\n\nargs:\n" ^ String.concat "\n" args));
let call = mk_apps func args in
mk_let effect_info.can_fail output call next_e

and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
Expand Down Expand Up @@ -2348,7 +2403,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopSynthInput ->
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
| V.LoopCall ->
| V.LoopCall -> (
(* We need to introduce a call to the backward function corresponding
to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
Expand Down Expand Up @@ -2419,9 +2474,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let func_ty = mk_arrows input_tys ret_ty in
let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
{ e = Qualif func; ty = func_ty }
Some { e = Qualif func; ty = func_ty }
in
let call = mk_apps func args in
(* **Optimization**:
=================
We do a small optimization here in case we split the forward/backward
Expand All @@ -2447,38 +2501,44 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
(* Add meta-information - this is slightly hacky: we look at the
values consumed by the abstraction (note that those come from
*before* we applied the fixed-point context) and use them to
guide the naming of the output vars.
Also, we need to convert the backward outputs from patterns to
variables.
Finally, in practice, this works well only for loop bodies:
we do this only in this case.
TODO: improve the heuristics, to give weight to the hints for
instance.
*)
let next_e =
if ctx.inside_loop then
let consumed_values = abs_to_consumed ctx ectx abs in
let var_values = List.combine outputs consumed_values in
let var_values =
List.filter_map
(fun (var, v) ->
match var.Pure.value with
| PatVar (var, _) -> Some (var, v)
| _ -> None)
var_values
(* In case we merge the fwd/back functions we filter the backward
functions elsewhere *)
match func with
| None -> next_e
| Some func ->
let call = mk_apps func args in
(* Add meta-information - this is slightly hacky: we look at the
values consumed by the abstraction (note that those come from
*before* we applied the fixed-point context) and use them to
guide the naming of the output vars.
Also, we need to convert the backward outputs from patterns to
variables.
Finally, in practice, this works well only for loop bodies:
we do this only in this case.
TODO: improve the heuristics, to give weight to the hints for
instance.
*)
let next_e =
if ctx.inside_loop then
let consumed_values = abs_to_consumed ctx ectx abs in
let var_values = List.combine outputs consumed_values in
let var_values =
List.filter_map
(fun (var, v) ->
match var.Pure.value with
| PatVar (var, _) -> Some (var, v)
| _ -> None)
var_values
in
let vars, values = List.split var_values in
mk_emeta_symbolic_assignments vars values next_e
else next_e
in
let vars, values = List.split var_values in
mk_emeta_symbolic_assignments vars values next_e
else next_e
in

(* Create the let-binding *)
mk_let effect_info.can_fail output call next_e
(* Create the let-binding *)
mk_let effect_info.can_fail output call next_e)

and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
Expand Down Expand Up @@ -2894,7 +2954,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let _, back_vars = fresh_back_vars_for_current_fun ctx in

(* Create the return expressions *)
let vars = fwd_var :: back_vars in
let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
let state_var = List.map mk_texpression_from_var state_var in
Expand All @@ -2903,12 +2963,16 @@ and translate_forward_end (ectx : C.eval_ctx)

(* Bind the expressions for the backward function and the expression
for the computation of the forward output *)
let back_vars_els =
List.filter_map
(fun (v, el) -> match v with None -> None | Some v -> Some (v, el))
(List.combine back_vars back_el)
in
let e =
List.fold_right
(fun (var, back_e) e ->
mk_let false (mk_typed_pattern_from_var var None) back_e e)
(List.combine back_vars back_el)
ret
back_vars_els ret
in
(* Bind the expression for the forward output *)
let fwd_var = mk_typed_pattern_from_var fwd_var None in
Expand Down Expand Up @@ -2976,12 +3040,18 @@ and translate_forward_end (ectx : C.eval_ctx)
if !Config.return_back_funs then
let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
let back_funs =
List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
List.filter_map
(fun v ->
match v with
| None -> None
| Some v -> Some (mk_typed_pattern_from_var v None))
back_vars
in
let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
let back_funs_map =
RegionGroupId.Map.of_list
(List.combine gids (List.map mk_texpression_from_var back_vars))
(List.combine gids
(List.map (Option.map mk_texpression_from_var) back_vars))
in
(ctx, Some back_funs_map, back_funs)
else (ctx, None, [])
Expand Down

0 comments on commit d4b3d0e

Please sign in to comment.