Skip to content

Commit

Permalink
Simplify the type of the merged fwd/back functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 21, 2023
1 parent ccfcadc commit 2f68144
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 46 deletions.
26 changes: 26 additions & 0 deletions compiler/Config.ml
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,32 @@ let filter_useless_monadic_calls = ref true
*)
let filter_useless_functions = ref true

(** Simplify the forward/backward functions, in case we merge them
(i.e., the forward functions return the backward functions).
The simplification occurs as follows:
- if a forward function returns the unit type and has non-trivial backward
functions, then we remove the returned output.
- if a backward function doesn't have inputs, we evaluate it inside the
forward function and don't wrap it in a result.
Example:
{[
// LLBC:
fn incr(x: &mut u32) { *x += 1 }
// Translation without simplification:
let incr (x : u32) : result (unit * result u32) = ...
^^^^ ^^^^^^
| remove this result
remove the unit
// Translation with simplification:
let incr (x : u32) : result u32 = ...
]}
*)
let simplify_merged_fwd_backs = ref true

(** Use short names for the record fields.
Some backends can't disambiguate records when their field names have collisions.
Expand Down
6 changes: 6 additions & 0 deletions compiler/Pure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,11 @@ type fun_sig_info = {
fwd_info : inputs_info;
(** Information about the inputs of the forward function *)
effect_info : fun_effect_info;
ignore_output : bool;
(** In case we merge the forward/backward functions: should we ignore
the output (happens for forward functions if the output type is
[unit] and there are non-filtered backward functions)?
*)
}
[@@deriving show]

Expand Down Expand Up @@ -939,6 +944,7 @@ type back_sg_info = {
We derive those from the names of the inputs of the original LLBC
function. *)
effect_info : fun_effect_info;
filter : bool; (** Should we filter this backward function? *)
}
[@@deriving show]

Expand Down
7 changes: 4 additions & 3 deletions compiler/PureMicroPasses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
let fun_sig = def.signature in
let fwd_info = fun_sig.fwd_info in
let fwd_effect_info = fwd_info.effect_info in
let ignore_output = fwd_info.ignore_output in

(* Generate the loop definition *)
let loop_fwd_effect_info = fwd_effect_info in
Expand All @@ -1358,7 +1359,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
}
in

{ fwd_info; effect_info = loop_fwd_effect_info }
{ fwd_info; effect_info = loop_fwd_effect_info; ignore_output }
in
assert (fun_sig_info_is_wf loop_fwd_sig_info);

Expand Down Expand Up @@ -2187,7 +2188,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
} =
decl.signature
in
let { fwd_info; effect_info } = fwd_info in
let { fwd_info; effect_info; ignore_output } = fwd_info in

let {
has_fuel;
Expand All @@ -2212,7 +2213,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
}
in

let fwd_info = { fwd_info; effect_info } in
let fwd_info = { fwd_info; effect_info; ignore_output } in
assert (fun_sig_info_is_wf fwd_info);
let signature =
{
Expand Down
1 change: 1 addition & 0 deletions compiler/PureUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty =

let mk_bool_ty : ty = TLiteral TBool
let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args)
let ty_is_unit ty : bool = ty = mk_unit_ty

let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = TTuple; variant_id = None } in
Expand Down
159 changes: 116 additions & 43 deletions compiler/SymbolicToPure.ml
Original file line number Diff line number Diff line change
Expand Up @@ -979,30 +979,6 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
in
(* Compute the backward output, without the effect information *)
let fwd_output = translate_fwd_ty type_infos sg.output in
(* The additinoal information *)
let fwd_info =
(* *)
let has_fuel = fwd_fuel <> [] in
let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
let num_inputs_with_fuel_no_state =
(* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
List.length fwd_fuel + num_inputs_no_fuel_no_state
in
let fwd_info : inputs_info =
{
has_fuel;
num_inputs_no_fuel_no_state;
num_inputs_with_fuel_no_state;
num_inputs_with_fuel_with_state =
(* We use the fact that [fwd_state_ty] has length 1 if there is a state,
and 0 otherwise *)
num_inputs_with_fuel_no_state + List.length fwd_state_ty;
}
in
let info = { fwd_info; effect_info = fwd_effect_info } in
assert (fun_sig_info_is_wf info);
info
in

(* Compute the type information for the backward function *)
(* Small helper to translate types for backward functions *)
Expand Down Expand Up @@ -1086,13 +1062,17 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
in
let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
let filter =
!Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
in
let info =
{
inputs;
inputs_no_state;
outputs;
output_names;
effect_info = back_effect_info;
filter;
}
in
(gid, info)
Expand All @@ -1102,6 +1082,39 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
(List.map compute_back_info_for_group regions_hierarchy)
in

(* The additional information about the forward function *)
let fwd_info =
(* *)
let has_fuel = fwd_fuel <> [] in
let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
let num_inputs_with_fuel_no_state =
(* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
List.length fwd_fuel + num_inputs_no_fuel_no_state
in
let fwd_info : inputs_info =
{
has_fuel;
num_inputs_no_fuel_no_state;
num_inputs_with_fuel_no_state;
num_inputs_with_fuel_with_state =
(* We use the fact that [fwd_state_ty] has length 1 if there is a state,
and 0 otherwise *)
num_inputs_with_fuel_no_state + List.length fwd_state_ty;
}
in
let ignore_output =
if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
ty_is_unit fwd_output
&& List.exists
(fun (info : back_sg_info) -> not info.filter)
(RegionGroupId.Map.values back_sg)
else false
in
let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in
assert (fun_sig_info_is_wf info);
info
in

(* Generic parameters *)
let generics = translate_generic_params sg.generics in

Expand Down Expand Up @@ -1134,6 +1147,13 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output

let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
(inputs : ty list) (ty : ty) : ty =
let output =
if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
in
if effect_info.can_fail && inputs <> [] then mk_result_ty output else output

(** Compute the arrow types for all the backward functions.
If a backward function has no inputs/outputs we filter it.
Expand All @@ -1151,7 +1171,9 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
None
else
let output = mk_simpl_tuple_ty outputs in
let output = mk_output_ty_from_effect_info effect_info output in
let output =
mk_back_output_ty_from_effect_info effect_info inputs output
in
let ty = mk_arrows inputs output in
(* Substitute - TODO: normalize *)
let ty =
Expand All @@ -1166,6 +1188,25 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
Some ty)
(RegionGroupId.Map.values dsg.back_sg)

(** In case we merge the fwd/back functions: compute the output type of
a function, from a decomposed signature. *)
let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
assert !Config.return_back_funs;
(* Compute the arrow types for all the backward functions *)
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output =
(* We might need to ignore the output of the forward function
(if it is unit for instance) *)
let tys =
if dsg.fwd_info.ignore_output then back_tys
else dsg.fwd_output :: back_tys
in
mk_simpl_tuple_ty tys
in
mk_output_ty_from_effect_info effect_info output

let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
Expand All @@ -1180,19 +1221,12 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid, info.effect_info))
(RegionGroupId.Map.bindings dsg.back_sg))
in
(* Two cases depending on whether we split the forward/backward functions
or not *)
let mk_output_ty = mk_output_ty_from_effect_info in

let inputs, output =
(* Two cases depending on whether we split the forward/backward functions or not *)
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
let output = mk_output_ty effect_info output in
let output = compute_output_ty_from_decomposed dsg in
let inputs = dsg.fwd_inputs in
(inputs, output))
else
Expand Down Expand Up @@ -1785,7 +1819,11 @@ and translate_panic (ctx : bs_ctx) : texpression =
if !Config.return_back_funs then
let back_tys = compute_back_tys ctx.sg None in
let back_tys = List.filter_map (fun x -> x) back_tys in
let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
let tys =
if ctx.sg.fwd_info.ignore_output then back_tys
else ctx.sg.fwd_output :: back_tys
in
let output = mk_simpl_tuple_ty tys in
mk_output output
else mk_output ctx.sg.fwd_output
| Some bid ->
Expand All @@ -1798,6 +1836,9 @@ and translate_panic (ctx : bs_ctx) : texpression =
Remark: for now, we can't get there if we are inside a loop.
If inside a loop, we use {!translate_return_with_loop}.
Remark: in case we merge the forward/backward functions, we introduce
those in [translate_forward_end].
*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
Expand Down Expand Up @@ -2648,6 +2689,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
If (true_e, false_e) )
in
let ty = true_e.ty in
log#ldebug
(lazy
("true_e.ty: "
^ pure_ty_to_string ctx true_e.ty
^ "\n\nfalse_e.ty: "
^ pure_ty_to_string ctx false_e.ty));
assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
Expand Down Expand Up @@ -2941,37 +2988,63 @@ and translate_forward_end (ectx : C.eval_ctx)
in
let fwd_e = translate_one_end ctx None in

(* Introduce the backward functions *)
(* Introduce the backward functions. *)
let back_el =
List.map
(fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
translate_one_end ctx (Some gid))
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in

(* Compute whether the backward expressions should be evaluated straight
away or not (i.e., if we should bind them with monadic let-bindings
or not). We evaluate them straight away if they can fail and have no
inputs *)
let evaluate_backs =
List.map
(fun (sg : back_sg_info) ->
if !Config.simplify_merged_fwd_backs then
sg.inputs = [] && sg.effect_info.can_fail
else false)
(RegionGroupId.Map.values ctx.sg.back_sg)
in

(* Introduce variables for the backward functions.
We lookup the LLBC definition in an attempt to derive pretty names
for those functions. *)
let _, back_vars = fresh_back_vars_for_current_fun ctx in

(* Create the return expressions *)
let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in
let vars =
let back_vars = List.filter_map (fun x -> x) back_vars in
if ctx.sg.fwd_info.ignore_output then back_vars
else fwd_var :: back_vars
in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
let state_var = List.map mk_texpression_from_var state_var in
let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
let ret = mk_result_return_texpression ret in

(* Bind the expressions for the backward function and the expression
for the computation of the forward output *)
(* Introduce all the let-bindings *)

(* Combine:
- the backward variables
- whether we should evaluate the expression for the backward function
(i.e., should we use a monadic let-binding or not - we do if the
backward functions don't have inputs and can fail)
- the expressions for the backward functions
*)
let back_vars_els =
List.filter_map
(fun (v, el) -> match v with None -> None | Some v -> Some (v, el))
(List.combine back_vars back_el)
(fun (v, (eval, el)) ->
match v with None -> None | Some v -> Some (v, eval, el))
(List.combine back_vars (List.combine evaluate_backs back_el))
in
let e =
List.fold_right
(fun (var, back_e) e ->
mk_let false (mk_typed_pattern_from_var var None) back_e e)
(fun (var, evaluate, back_e) e ->
mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
back_vars_els ret
in
(* Bind the expression for the forward output *)
Expand Down

0 comments on commit 2f68144

Please sign in to comment.