Skip to content

Commit

Permalink
Merge Records into ADT
Browse files Browse the repository at this point in the history
  • Loading branch information
Halbaroth committed Jul 9, 2024
1 parent a7cf228 commit 0e0a6a1
Show file tree
Hide file tree
Showing 22 changed files with 193 additions and 664 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals_intf Intervals_core Intervals
Ite_rel Matching Matching_types Polynome Records Records_rel
Ite_rel Matching Matching_types Polynome
Satml_frontend_hybrid Satml_frontend Satml Sat_solver Sat_solver_sig
Sig Sig_rel Theory Uf Use Rel_utils Bitlist
; structures
Expand Down
16 changes: 11 additions & 5 deletions src/lib/frontend/cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,18 @@ let rec make_term quant_basename t =
E.mk_term (Sy.Op Sy.Concat)
[mk_term t1; mk_term t2] ty

| TTdot (t, s) ->
E.mk_term (Sy.Op (Sy.Access (Uid.of_hstring s))) [mk_term t] ty

| TTrecord lbs ->
| TTrecord (ty, lbs) ->
let lbs = List.map (fun (_, t) -> mk_term t) lbs in
E.mk_record lbs ty
let cstr =
match ty with
| Tadt (name, params, true) ->
begin match Ty.type_body name params with
| [{ constr; _ }] -> Uid.show constr
| _ -> assert false
end
| _ -> assert false
in
E.mk_constr (Uid.of_string cstr) lbs ty

| TTlet (binders, t2) ->
let binders =
Expand Down
160 changes: 11 additions & 149 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -530,86 +530,17 @@ let rec dty_to_ty ?(update = false) ?(is_var = false) dty =
| _ -> unsupported "Type %a" DE.Ty.print dty

and handle_ty_app ?(update = false) ty_c l =
(* Applies the substitutions in [tysubsts] to each encountered type
variable. *)
let rec apply_ty_substs tysubsts ty =
match ty with
| Ty.Tvar { v; _ } ->
Ty.M.find v tysubsts

| Text (tyl, hs) ->
Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs)

| Tfarray (ti, tv) ->
Tfarray (
apply_ty_substs tysubsts ti,
apply_ty_substs tysubsts tv
)

| Tadt (hs, tyl) ->
Tadt (hs, List.map (apply_ty_substs tysubsts) tyl)

| Trecord ({ args; lbs; _ } as rcrd) ->
Trecord {
rcrd with
args = List.map (apply_ty_substs tysubsts) args;
lbs = List.map (
fun (hs, t) ->
hs, apply_ty_substs tysubsts t
) lbs;
}

| _ -> ty
in
let tyl = List.map (dty_to_ty ~update) l in
(* Recover the initial versions of the types and apply them on the provided
type arguments stored in [tyl]. *)
match Cache.find_ty ty_c with
| Tadt (hs, _) -> Tadt (hs, tyl)

| Trecord { args; _ } as ty ->
let tysubsts =
List.fold_left2 (
fun acc tv ty ->
match tv with
| Ty.Tvar { v; _ } -> Ty.M.add v ty acc
| _ -> assert false
) Ty.M.empty args tyl
in
apply_ty_substs tysubsts ty

| Tadt (hs, _, record) -> Tadt (hs, tyl, record)
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

(** Handles a simple type declaration. *)
let mk_ty_decl (ty_c: DE.ty_cst) =
match DT.definition ty_c with
| Some (
Adt { cases = [| { cstr = { id_ty; _ } as cstr; dstrs; _ } |]; _ } as adt
) ->
(* Records and adts that only have one case are treated in the same way,
and considered as records. *)
Nest.add_nest [adt];
let tyvl = Cache.store_ty_vars_ret id_ty in
let rev_lbs =
Array.fold_left (
fun acc c ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(Uid.of_dolmen id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c

) [] dstrs
in
let lbs = List.rev rev_lbs in
let record_constr = Uid.of_dolmen cstr in
let ty = Ty.trecord ~record_constr tyvl (Uid.of_dolmen ty_c) lbs in
Cache.store_ty ty_c ty

| Some (Adt { cases; _ } as adt) ->
Nest.add_nest [adt];
let uid = Uid.of_dolmen ty_c in
Expand Down Expand Up @@ -668,30 +599,7 @@ let mk_term_decl ({ id_ty; path; tags; _ } as tcst: DE.term_cst) =
let mk_mr_ty_decls (tdl: DE.ty_cst list) =
let handle_ty_decl (ty: Ty.t) (tdef: DE.Ty.def option) =
match ty, tdef with
| Trecord { args; name; record_constr; _ },
Some (
Adt { cases = [| { dstrs; _ } |]; ty = ty_c; _ }
) ->
let rev_lbs =
Array.fold_left (
fun acc c ->
match c with
| Some (DE.{ id_ty; _ } as id) ->
let pty = dty_to_ty id_ty in
(Uid.of_dolmen id, pty) :: acc
| _ ->
Fmt.failwith
"Unexpected null label for some field of the record type %a"
DE.Ty.Const.print ty_c
) [] dstrs
in
let lbs = List.rev rev_lbs in
let ty =
Ty.trecord ~record_constr args name lbs
in
Cache.store_ty ty_c ty

| Tadt (hs, tyl), Some (Adt { cases; ty = ty_c; _ }) ->
| Tadt (hs, tyl, _), Some (Adt { cases; ty = ty_c; _ }) ->
let rev_cs =
Array.fold_left (
fun accl DE.{ cstr; dstrs; _ } ->
Expand All @@ -713,37 +621,16 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =

| _ -> assert false
in
(* If there are adts in the list of type declarations then records are
converted to adts, because that's how it's done in the legacy typechecker.
But it might be more efficient not to do that. *)
let rev_tdefs, contains_adts =
List.fold_left (
fun (acc, ca) ty_c ->
match DT.definition ty_c with
| Some (Adt { record; cases; _ } as df)
when not record && Array.length cases > 1 ->
df :: acc, true
| Some (Adt _ as df) ->
df :: acc, ca
| Some Abstract | None ->
assert false
) ([], false) tdl
in
let rev_tdefs = List.rev_map (fun td -> Option.get @@ DT.definition td) tdl in
Nest.add_nest rev_tdefs;
let rev_l =
List.fold_left (
fun acc tdef ->
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
| DE.Adt { cases; ty = ty_c; _ } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let uid = Uid.of_dolmen ty_c in
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else
Ty.t_adt uid tyvl
in
let ty = Ty.t_adt uid tyvl in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

Expand Down Expand Up @@ -970,14 +857,8 @@ let rec mk_expr
E.mk_term sy [] ty

| B.Constructor _ ->
begin match dty_to_ty term_ty with
| Trecord _ as ty ->
E.mk_record [] ty
| Tadt _ as ty ->
E.mk_constr (Uid.of_dolmen tcst) [] ty
| ty ->
Fmt.failwith "unexpected type %a@." Ty.print ty
end
let ty = dty_to_ty term_ty in
E.mk_constr (Uid.of_dolmen tcst) [] ty

| _ -> unsupported "Constant term %a" DE.Term.print term
end
Expand Down Expand Up @@ -1018,10 +899,7 @@ let rec mk_expr
let e = aux_mk_expr x in
let sy =
match Cache.find_ty adt with
| Trecord _ ->
Sy.Op (Sy.Access (Uid.of_dolmen destr))
| Tadt _ ->
Sy.destruct (Uid.of_dolmen destr)
| Tadt _ -> Sy.destruct (Uid.of_dolmen destr)
| _ -> assert false
in
E.mk_term sy [e] ty
Expand Down Expand Up @@ -1053,11 +931,6 @@ let rec mk_expr
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
replace the tester of a record by the true literal. *)
E.vrai
| _ -> assert false
end

Expand Down Expand Up @@ -1342,20 +1215,9 @@ let rec mk_expr

| B.Constructor _, _ ->
let ty = dty_to_ty term_ty in
begin match ty with
| Ty.Tadt _ ->
let sy = Sy.constr @@ Uid.of_dolmen tcst in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty
| Ty.Trecord _ ->
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_record l ty
| _ ->
Fmt.failwith
"Constructor error: %a does not belong to a record nor an\
algebraic data type"
DE.Term.print app_term
end
let sy = Sy.constr @@ Uid.of_dolmen tcst in
let l = List.map (fun t -> aux_mk_expr t) args in
E.mk_term sy l ty

| B.Coercion, [ x ] ->
begin match DT.view (DE.Term.ty x), DT.view term_ty with
Expand Down
22 changes: 1 addition & 21 deletions src/lib/frontend/models.ml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ module Pp_smtlib_term = struct
asprintf "%a" Ty.pp_smtlib t

let rec print fmt t =
let {Expr.f;xs;ty; _} = Expr.term_view t in
let {Expr.f;xs; _} = Expr.term_view t in
match f, xs with

| Sy.Lit lit, xs ->
Expand Down Expand Up @@ -159,26 +159,6 @@ module Pp_smtlib_term = struct
| Sy.Op Sy.Extract (i, j), [e] ->
fprintf fmt "%a^{%d,%d}" print e i j

| Sy.Op (Sy.Access field), [e] ->
if Options.get_output_smtlib () then
fprintf fmt "(%a %a)" Uid.pp field print e
else
fprintf fmt "%a.%a" print e Uid.pp field

| Sy.Op (Sy.Record), _ ->
begin match ty with
| Ty.Trecord { Ty.lbs = lbs; _ } ->
assert (List.length xs = List.length lbs);
fprintf fmt "{";
ignore (List.fold_left2 (fun first (field,_) e ->
fprintf fmt "%s%a = %a" (if first then "" else "; ")
Uid.pp field print e;
false
) true lbs xs);
fprintf fmt "}";
| _ -> assert false
end

(* TODO: introduce PrefixOp in the future to simplify this ? *)
| Sy.Op op, [e1; e2] when op == Sy.Pow || op == Sy.Integer_round ||
op == Sy.Max_real || op == Sy.Max_int ||
Expand Down
Loading

0 comments on commit 0e0a6a1

Please sign in to comment.