Skip to content

Commit

Permalink
Update SymbolicToPure.ml for the loops
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 21, 2023
1 parent d9f91cf commit cf3eea5
Showing 1 changed file with 125 additions and 96 deletions.
221 changes: 125 additions & 96 deletions compiler/SymbolicToPure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ type loop_info = {
(** The map from region group ids to the types of the values given back
by the corresponding loop abstractions.
*)
back_funs : texpression RegionGroupId.Map.t option;
(** Same as {!call_info.back_funs}.
Initialized with [None], gets updated to [Some] only if we merge
the fwd/back functions.
*)
}
[@@deriving show]

Expand Down Expand Up @@ -1123,45 +1128,25 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output

(** Compute the arrow types for all the backward functions.
TODO: merge with below?
*)
let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
(** Compute the arrow types for all the backward functions. *)
let compute_back_tys (dsg : Pure.decomposed_fun_sig)
(subst : (generic_args * trait_instance_id) option) : ty list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
(* Compute *)
let inputs = List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty_from_effect_info effect_info output in
mk_arrows inputs output)
let ty = mk_arrows inputs output in
(* Substitute - TODO: normalize *)
match subst with
| None -> ty
| Some (generics, tr_self) ->
let subst = make_subst_from_generics dsg.generics generics tr_self in
ty_substitute subst ty)
(RegionGroupId.Map.values dsg.back_sg)

(** Return the instantiated pure signature of a backward function, in the
case the forward/backward functions are merged (i.e., the forward functions
return the backward functions).
*)
let translate_ret_back_inst_fun_sig_from_decomposed
(dsg : Pure.decomposed_fun_sig) (generics : generic_args)
(gid : RegionGroupId.id) : inst_fun_sig =
assert !Config.return_back_funs;
let mk_output_ty = mk_output_ty_from_effect_info in
(* Lookup the signature information *)
let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
let effect_info = back_sg.effect_info in
(* Do not prepend the forward inputs *)
let inputs = List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty effect_info output in
(* Substitute the types *)
let tr_self = UnknownTrait __FUNCTION__ in
let subst = make_subst_from_generics dsg.generics generics tr_self in
let subst = ty_substitute subst in
let inputs = List.map subst inputs in
let output = subst output in
(* Return *)
{ inputs; output }

let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
Expand All @@ -1184,7 +1169,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
let back_tys = compute_back_tys dsg in
let back_tys = compute_back_tys dsg None in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
Expand Down Expand Up @@ -1274,6 +1259,40 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars

(* Introduce variables for the backward functions *)
let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
(* We lookup the LLBC definition in an attempt to derive pretty names
for the backward functions. *)
let back_var_names =
let def_id = ctx.fun_decl.def_id in
let sg = ctx.fun_decl.signature in
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular def_id)
ctx.fun_ctx.regions_hierarchies
in
List.map
(fun (gid, _) ->
let rg = RegionGroupId.nth regions_hierarchy gid in
let region_names =
List.map
(fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
rg.regions
in
let name =
match region_names with
| [] -> "back"
| [ Some r ] -> "back" ^ r
| _ ->
(* Concatenate all the region names *)
"back"
^ String.concat "" (List.filter_map (fun x -> x) region_names)
in
Some name)
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
fresh_vars back_vars ctx

let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
| Some v -> v
Expand Down Expand Up @@ -1728,7 +1747,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
match ctx.bid with
| None ->
if !Config.return_back_funs then
let back_tys = compute_back_tys ctx.sg in
let back_tys = compute_back_tys ctx.sg None in
let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
mk_output output
else mk_output ctx.sg.fwd_output
Expand Down Expand Up @@ -1883,22 +1902,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
fid call.regions_hierarchy sg
(List.map (fun _ -> None) sg.inputs)
in
let gids =
List.map
(fun (g : T.region_var_group) -> g.id)
call.regions_hierarchy
in
let back_sgs =
List.map
(translate_ret_back_inst_fun_sig_from_decomposed dsg generics)
gids
in
let tr_self = UnknownTrait __FUNCTION__ in
let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in
(* Introduce variables for the backward functions *)
let back_tys =
List.map
(fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output)
back_sgs
in
(* Compute a proper basename for the variables *)
let back_fun_name =
let name =
Expand Down Expand Up @@ -1934,6 +1940,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let back_funs =
List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
in
let gids =
List.map
(fun (g : T.region_var_group) -> g.id)
call.regions_hierarchy
in
let back_funs_map =
RegionGroupId.Map.of_list
(List.combine gids (List.map mk_texpression_from_var back_vars))
Expand Down Expand Up @@ -2338,6 +2349,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
| V.LoopCall ->
(* We need to introduce a call to the backward function corresponding
to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
let effect_info =
get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id)
Expand Down Expand Up @@ -2367,7 +2380,10 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
let inputs =
if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
else List.concat [ fwd_inputs; back_inputs; back_state ]
in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
Expand All @@ -2391,28 +2407,43 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
(* Create the expression for the function:
- it is either a call to a top-level function, if we split the
forward/backward functions
- or a call to the variable we introduced for the backward function,
if we merge the forward/backward functions *)
let func =
if !Config.return_back_funs then
RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
else
let func_ty = mk_arrows input_tys ret_ty in
let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
{ e = Qualif func; ty = func_ty }
in
let call = mk_apps func args in
(* **Optimization**:
* =================
* We do a small optimization here: if the backward function doesn't
* have any output, we don't introduce any function call.
* See the comment in {!Config.filter_useless_monadic_calls}.
*
* TODO: use an option to disallow backward functions from updating the state.
* TODO: a backward function which only gives back shared borrows shouldn't
* update the state (state updates should only be used for mutable borrows,
* with objects like Rc for instance).
*)
if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
=================
We do a small optimization here in case we split the forward/backward
functions.
If the backward function doesn't have any output, we don't introduce
any function call.
See the comment in {!Config.filter_useless_monadic_calls}.
TODO: use an option to disallow backward functions from updating the state.
TODO: a backward function which only gives back shared borrows shouldn't
update the state (state updates should only be used for mutable borrows,
with objects like Rc for instance).
*)
if
(not !Config.return_back_funs)
&& !Config.filter_useless_monadic_calls
&& outputs = [] && nstate = None
then (
(* No outputs - we do a small sanity check: the backward function
* should have exactly the same number of inputs as the forward:
* this number can be different only if the forward function returned
* a value containing mutable borrows, which can't be the case... *)
should have exactly the same number of inputs as the forward:
this number can be different only if the forward function returned
a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
Expand Down Expand Up @@ -2860,35 +2891,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Introduce variables for the backward functions.
We lookup the LLBC definition in an attempt to derive pretty names
for those functions. *)
let back_var_names =
let def_id = ctx.fun_decl.def_id in
let sg = ctx.fun_decl.signature in
let regions_hierarchy =
LlbcAstUtils.FunIdMap.find (FRegular def_id)
ctx.fun_ctx.regions_hierarchies
in
List.map
(fun (gid, _) ->
let rg = RegionGroupId.nth regions_hierarchy gid in
let region_names =
List.map
(fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
rg.regions
in
let name =
match region_names with
| [] -> "back"
| [ Some r ] -> "back" ^ r
| _ ->
(* Concatenate all the region names *)
"back"
^ String.concat "" (List.filter_map (fun x -> x) region_names)
in
Some name)
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in
let _, back_vars = fresh_vars back_vars ctx in
let _, back_vars = fresh_back_vars_for_current_fun ctx in

(* Create the return expressions *)
let vars = fwd_var :: back_vars in
Expand Down Expand Up @@ -2964,8 +2967,32 @@ and translate_forward_end (ectx : C.eval_ctx)

(* Introduce a fresh output value for the forward function *)
let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
(* Introduce fresh variables for the backward functions of the loop.
For now, the backward functions of the loop are the same as the
backward functions of the outer function.
*)
let ctx, back_funs_map, back_funs =
if !Config.return_back_funs then
let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
let back_funs =
List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
in
let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
let back_funs_map =
RegionGroupId.Map.of_list
(List.combine gids (List.map mk_texpression_from_var back_vars))
in
(ctx, Some back_funs_map, back_funs)
else (ctx, None, [])
in

(* Introduce patterns *)
let args, ctx, out_pats =
(* Create the pattern for the output value *)
let output_pat = mk_typed_pattern_from_var output_var None in
(* Add the returned backward functions (they might be empty) *)
let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in

(* Depending on the function effects:
* - add the fuel
Expand All @@ -2988,6 +3015,7 @@ and translate_forward_end (ectx : C.eval_ctx)
loop_info with
forward_inputs = Some args;
forward_output_no_state_no_result = Some output_var;
back_funs = back_funs_map;
}
in
let ctx =
Expand Down Expand Up @@ -3143,6 +3171,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
forward_inputs = None;
forward_output_no_state_no_result = None;
back_outputs = rg_to_given_back_tys;
back_funs = None;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in
Expand Down

0 comments on commit cf3eea5

Please sign in to comment.