diff --git a/Makefile b/Makefile index 24acb2e..68c56e6 100644 --- a/Makefile +++ b/Makefile @@ -4,5 +4,8 @@ dumpast: opam-install-dev-deps: opam install ocamlformat ocaml-lsp-server ppx_tools -show-ppx-test: - dune exec -- pp/pp.exe test/test.ml +show-ppx-test-encoders: + dune exec -- pp/pp.exe test/test_encoders.ml + +show-ppx-test-decoders: + dune exec -- pp/pp.exe test/test_decoders.ml diff --git a/README.md b/README.md index 7d12368..6c2a415 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [mattjbray/ocaml-decoders](https://github.com/mattjbray/ocaml-decoders) is an excellent library for writing decoders using decoding combinators. However, writing out decoders by hand for more complicated types can be quite time-intensive. -This library helps by automatically producing the appropriate decoder for a particular type. +This library helps by automatically producing the appropriate decoder (and encoder!) for a particular type. There are two primary ways in which this library can be of use. (More details of both follows.) @@ -78,27 +78,21 @@ type bar = Int of int | String of string [@@deriving_inline decoders] [@@@deriving.end] ``` -Then, after running `dune build --auto-promote`, our file will become: +Then, after running `dune build --auto-promote`, our file will become (after applying `ocamlformat`): ```ocaml (* In file foo.ml *) module D = Decoders_yojson.Safe.Decode type bar = Int of int | String of string [@@deriving_inline decoders] - let _ = fun (_ : bar) -> () + let bar_decoder = let open D in - single_field - (function - | "Int" -> - let open D in - let (>>=::) fst rest = uncons rest fst in - D.int >>=:: ((fun arg0 -> succeed (Int arg0))) - | "String" -> - let open D in - let (>>=::) fst rest = uncons rest fst in - D.string >>=:: ((fun arg0 -> succeed (String arg0))) - | any -> D.fail @@ (Printf.sprintf "Unrecognized field: %s" any)) + single_field (function + | "Int" -> D.int >|= fun arg -> Int arg + | "String" -> D.string >|= fun arg -> String arg + | any -> D.fail @@ Printf.sprintf "Unrecognized field: %s" any) + let _ = bar_decoder [@@@deriving.end] ``` @@ -116,68 +110,114 @@ and op = Add | Sub | Mul | Div [@@deriving_inline decoders] [@@@deriving.end] ``` -after invoking `dune build --auto-promote` will yield: +after invoking `dune build --auto-promote` (plus `ocamlformat`) will yield: ```ocaml (* In file foo.ml *) type expr = Num of int | BinOp of op * expr * expr and op = Add | Sub | Mul | Div [@@deriving_inline decoders] - + let _ = fun (_ : expr) -> () let _ = fun (_ : op) -> () + [@@@ocaml.warning "-27"] + let expr_decoder op_decoder = - D.fix - (fun expr_decoder_aux -> - let open D in - single_field - (function - | "Num" -> - let open D in - let (>>=::) fst rest = uncons rest fst in - D.int >>=:: ((fun arg0 -> succeed (Num arg0))) - | "BinOp" -> - let open D in - let (>>=::) fst rest = uncons rest fst in - op_decoder >>=:: - ((fun arg0 -> - expr_decoder_aux >>=:: - (fun arg1 -> - expr_decoder_aux >>=:: - (fun arg2 -> - succeed (BinOp (arg0, arg1, arg2)))))) - | any -> D.fail @@ (Printf.sprintf "Unrecognized field: %s" any))) + D.fix (fun expr_decoder_aux -> + let open D in + single_field (function + | "Num" -> D.int >|= fun arg -> Num arg + | "BinOp" -> + let open D in + let ( >>=:: ) fst rest = uncons rest fst in + op_decoder >>=:: fun arg0 -> + expr_decoder_aux >>=:: fun arg1 -> + expr_decoder_aux >>=:: fun arg2 -> + succeed (BinOp (arg0, arg1, arg2)) + | any -> D.fail @@ Printf.sprintf "Unrecognized field: %s" any)) + let _ = expr_decoder + let op_decoder op_decoder = let open D in - single_field - (function - | "Add" -> succeed Add - | "Sub" -> succeed Sub - | "Mul" -> succeed Mul - | "Div" -> succeed Div - | any -> D.fail @@ (Printf.sprintf "Unrecognized field: %s" any)) + single_field (function + | "Add" -> succeed Add + | "Sub" -> succeed Sub + | "Mul" -> succeed Mul + | "Div" -> succeed Div + | any -> D.fail @@ Printf.sprintf "Unrecognized field: %s" any) + let _ = op_decoder let op_decoder = D.fix op_decoder let _ = op_decoder let expr_decoder = expr_decoder op_decoder let _ = expr_decoder + [@@@ocaml.warning "+27"] + [@@@deriving.end] ``` Notice that the mutual recursion is handled for you! +## Type vars +The `ppx` can also handle types with type variables: +```ocaml +type 'a wrapper = { wrapped : 'a } [@@deriving_inline decoders] +[@@@deriving.end] +``` +becomes (additionally with `ocamlformat`): + +```ocaml +type 'a record_wrapper = { wrapped : 'a } [@@deriving_inline decoders] + +let _ = fun (_ : 'a record_wrapper) -> () + +let record_wrapper_decoder a_decoder = + let open D in + let open D.Infix in + let* wrapped = field "wrapped" a_decoder in + succeed { wrapped } + +let _ = record_wrapper_decoder + +[@@@deriving.end] +``` +Notice that the decoder for the type variable becomes a parameter of the generated decoder! + +## Encoders +All of the above information also applies to generating encoders. Using the above type as an example: +```ocaml +type 'a wrapper = { wrapped : 'a } [@@deriving_inline decoders] +[@@@deriving.end] +``` +becomes (additionally with `ocamlformat`): + +```ocaml +type 'a wrapper = { wrapped : 'a } [@@deriving_inline encoders] + +let _ = fun (_ : 'a record_wrapper) -> () + +let wrapper_encoder a_encoder { wrapped } = + E.obj [ ("wrapped", a_encoder wrapped) ] + +let _ = record_wrapper_encoder + +[@@@deriving.end] +``` + +Of course, you can generate both by using `[@@deriving_inline decoders, encoders]` or `[@@deriving decoders, encoders]`. The corresponding pair will be inverses of one another provided that all prior referenced decoder/encoder pairs are inverses! + + ## Limitations - Some of the decoders can be quite complicated relative to what you would write by hand -- There is not great support for types which feature type variables - There are a lot of rough edges in places like: - Error reporting - Correctly handling `loc` +- In an ideal world, it would be nice to generate the corresponding decoders/encoders within their own submodule. It remains to be seen how this can be done. ## Future Work -- [ ] Automatically generate corresponding encoders which are inverses of the decoders -- [ ] Better handling of type variables - [ ] Simplify generated decoders - [ ] Generate decoders from a module +- [ ] How to handle types produced from functors inline ## Contributing diff --git a/src/expander.ml b/src/decoders_deriver.ml similarity index 80% rename from src/expander.ml rename to src/decoders_deriver.ml index 34111e1..e4c6b7c 100644 --- a/src/expander.ml +++ b/src/decoders_deriver.ml @@ -15,46 +15,28 @@ let apply_substitution ~orig ~substi = in mapper#expression -let generate_attribute v ~loc = - let open Ast_builder.Default in - pstr_attribute ~loc - (attribute ~loc - ~name:(Located.mk ~loc "ocaml.warning") - ~payload:(PStr [ pstr_eval ~loc (estring ~loc v) [] ])) - -let suppress_warning_27 ~loc = generate_attribute ~loc "-27" -let enforce_warning_27 ~loc = generate_attribute ~loc "+27" - -let wrap_27 xs = - (suppress_warning_27 ~loc:Location.none :: xs) - @ [ enforce_warning_27 ~loc:Location.none ] - -(* let suppress_warning_27 = *) -(* let suppress_warning_27 = *) -(* let loc = Location.none in *) -(* let payload = *) -(* PStr *) -(* [ *) -(* Ast_helper.Str.eval *) -(* (Ast_helper.Exp.constant (Pconst_string ("-27", loc, None))); *) -(* ] *) -(* in *) -(* let attr_name = "ocaml.warning" *) - -(* in *) -(* let attribute = Ast_builder.Default.attribute ~loc ~name:attr_name ~payload in *) -(* Ast_builder.Default.pstr_attribute ~loc attribute *) - -(* let enforce_warning_27 = _ *) let to_decoder_name i = i ^ "_decoder" +let rec flatten_longident ~loc = function + | Lident txt -> txt + | Ldot (longident, txt) -> flatten_longident ~loc longident ^ "." ^ txt + | Lapply (fst, snd) -> + Location.raise_errorf ~loc "Cannot handle functors:%s (%s)" + (flatten_longident ~loc fst) + (flatten_longident ~loc snd) + +let longident_to_decoder_name ~loc = + CCFun.(to_decoder_name % flatten_longident ~loc) + +let name_to_decoder_name (i : string loc) = to_decoder_name i.txt + let decoder_pvar_of_type_decl type_decl = Ast_builder.Default.pvar ~loc:type_decl.ptype_name.loc - (to_decoder_name type_decl.ptype_name.txt) + (name_to_decoder_name type_decl.ptype_name) let decoder_evar_of_type_decl type_decl = Ast_builder.Default.evar ~loc:type_decl.ptype_name.loc - (to_decoder_name type_decl.ptype_name.txt) + (name_to_decoder_name type_decl.ptype_name) (** We take an expr implementation with name NAME and turn it into: let rec NAME_AUX = fun () -> expr in NAME_AUX (). @@ -79,11 +61,6 @@ let pexp_fun_multiarg ~loc fun_imple (args : pattern list) = let args_rev = List.rev args in CCList.fold_left folder fun_imple args_rev -let lident_of_constructor_decl (cd : constructor_declaration) = - let loc = cd.pcd_name.loc in - let name = cd.pcd_name.txt in - Ast_builder.Default.Located.lident ~loc name - let rec expr_of_typ (typ : core_type) ~(substitutions : (core_type * expression) list) : expression = let loc = { typ.ptyp_loc with loc_ghost = true } in @@ -123,8 +100,8 @@ let rec expr_of_typ (typ : core_type) (* failwith *) (* (Format.sprintf "This alias was a failure...: %s\n" *) (* (string_of_core_type typ)) *) - | { ptyp_desc = Ptyp_constr ({ txt = Lident lid; _ }, []); _ } as other_type - -> ( + | { ptyp_desc = Ptyp_constr ({ txt = longident; loc = typ_loc }, []); _ } as + other_type -> ( (* In the case where our type is truly recursive, we need to instead do `type_aux ()` *) let eq (ct1 : core_type) (ct2 : core_type) = (* TODO: This is a terrible way to compare the types... *) @@ -132,7 +109,21 @@ let rec expr_of_typ (typ : core_type) in match CCList.assoc_opt ~eq other_type substitutions with | Some replacement -> replacement - | None -> Ast_builder.Default.evar ~loc (to_decoder_name lid)) + | None -> + Ast_builder.Default.evar ~loc + (longident_to_decoder_name ~loc:typ_loc longident)) + | { ptyp_desc = Ptyp_var var; _ } -> + Ast_builder.Default.evar ~loc @@ to_decoder_name var + | { ptyp_desc = Ptyp_constr ({ txt = longident; loc }, args); _ } -> + let cstr_dec = + Ast_builder.Default.evar ~loc + @@ longident_to_decoder_name ~loc longident + in + + let arg_decs = CCList.map (expr_of_typ ~substitutions) args in + Ast_builder.Default.eapply ~loc cstr_dec arg_decs + (* Location.raise_errorf ~loc "Cannot constructor decoder for %s" *) + (* (string_of_core_type typ) *) | _ -> Location.raise_errorf ~loc "Cannot construct decoder for %s" (string_of_core_type typ) @@ -155,7 +146,6 @@ and expr_of_tuple ~loc ~substitutions ?lift typs = string >>=:: fun arg2 -> bool >>=:: fun arg3 -> succeed (lift (arg1, arg2, arg3)) *) - let argn = Printf.sprintf "arg%d" in let typ_decoder_exprs = List.map (expr_of_typ ~substitutions) typs in let base = (* Consists of the initial setup partial function def, which is the inport and local definition, @@ -168,7 +158,7 @@ and expr_of_tuple ~loc ~substitutions ?lift typs = 0 ) in let fn_builder (partial_expr, i) next_decoder = - let var = argn i in + let var = Utils.argn i in let var_pat = Ast_builder.Default.pvar ~loc var in ( (fun body -> partial_expr @@ -178,7 +168,7 @@ and expr_of_tuple ~loc ~substitutions ?lift typs = let complete_partial_expr, var_count = List.fold_left fn_builder base typ_decoder_exprs in - let var_names = CCList.init var_count argn in + let var_names = CCList.init var_count Utils.argn in let var_tuple = let expr_list = List.map (fun s -> [%expr [%e Ast_builder.Default.evar ~loc s]]) var_names @@ -197,14 +187,22 @@ and expr_of_tuple ~loc ~substitutions ?lift typs = and expr_of_constr_decl ~substitutions ({ pcd_args; pcd_loc = loc; _ } as cstr_decl : constructor_declaration) = (* We assume at this point that the decomposition into indiviaul fields is handled by caller *) - if pcd_args = Pcstr_tuple [] then - let cstr = lident_of_constructor_decl cstr_decl in - let cstr = Ast_builder.Default.pexp_construct ~loc cstr None in - [%expr succeed [%e cstr]] - else - let cstr = lident_of_constructor_decl cstr_decl in - let sub_expr = expr_of_constr_arg ~substitutions ~loc ~cstr pcd_args in - sub_expr + match pcd_args with + | Pcstr_tuple [] -> + let cstr = Utils.lident_of_constructor_decl cstr_decl in + let cstr = Ast_builder.Default.pexp_construct ~loc cstr None in + [%expr succeed [%e cstr]] + | Pcstr_tuple [ single ] -> + let cstr = Utils.lident_of_constructor_decl cstr_decl in + let arg_e = Ast_builder.Default.evar ~loc "arg" in + let arg_p = Ast_builder.Default.pvar ~loc "arg" in + let cstr = Ast_builder.Default.pexp_construct ~loc cstr (Some arg_e) in + let single_dec = expr_of_typ single ~substitutions in + [%expr [%e single_dec] >|= fun [%p arg_p] -> [%e cstr]] + | _ -> + let cstr = Utils.lident_of_constructor_decl cstr_decl in + let sub_expr = expr_of_constr_arg ~substitutions ~loc ~cstr pcd_args in + sub_expr and expr_of_constr_arg ~loc ~cstr ~substitutions (arg : constructor_arguments) = match arg with @@ -296,7 +294,7 @@ let expr_of_variant ~loc ~substitutions cstrs = let implementation_generator ~(loc : location) ~rec_flag ~substitutions type_decl : expression = let rec_flag = really_recursive rec_flag [ type_decl ] in - let name = to_decoder_name type_decl.ptype_name.txt in + let name = name_to_decoder_name type_decl.ptype_name in let imple_expr = match (type_decl.ptype_kind, type_decl.ptype_manifest) with | Ptype_abstract, Some manifest -> expr_of_typ ~substitutions manifest @@ -313,7 +311,8 @@ let implementation_generator ~(loc : location) ~rec_flag ~substitutions let single_type_decoder_gen ~(loc : location) ~rec_flag type_decl : structure_item list = let rec_flag = really_recursive rec_flag [ type_decl ] in - let name = to_decoder_name type_decl.ptype_name.txt in + let name = name_to_decoder_name type_decl.ptype_name in + let substitutions = match rec_flag with | Nonrecursive -> [] @@ -326,7 +325,26 @@ let single_type_decoder_gen ~(loc : location) ~rec_flag type_decl : let imple = implementation_generator ~loc ~rec_flag ~substitutions type_decl in - let name = to_decoder_name type_decl.ptype_name.txt in + let name = name_to_decoder_name type_decl.ptype_name in + let params = + (* TODO: can we drop the non type vars? What are these? *) + CCList.filter_map + (fun (param, _) -> + match param.ptyp_desc with Ptyp_var var -> Some var | _ -> None) + type_decl.ptype_params + in + let args = + CCList.rev + @@ CCList.map + (fun param -> Ast_builder.Default.pvar ~loc (to_decoder_name param)) + params + in + let imple = + (* We need the type variables to become arguments *) + CCList.fold_left + (fun impl arg -> [%expr fun [%p arg] -> [%e impl]]) + imple args + in [%str let [%p Ast_builder.Default.pvar ~loc name] = [%e imple]] let rec mutual_rec_fun_gen ~loc @@ -338,12 +356,12 @@ let rec mutual_rec_fun_gen ~loc | type_decl :: rest -> let var = pvar ~loc:type_decl.ptype_name.loc - (to_decoder_name type_decl.ptype_name.txt) + (name_to_decoder_name type_decl.ptype_name) in let substitutions = match really_recursive Recursive [ type_decl ] with | Recursive -> - let name = to_decoder_name type_decl.ptype_name.txt in + let name = name_to_decoder_name type_decl.ptype_name in let substi = Ast_builder.Default.evar ~loc (name ^ "_aux") in let new_substitution = (core_type_of_type_declaration type_decl, substi) @@ -368,7 +386,7 @@ let rec mutual_rec_fun_gen ~loc else List.map (fun type_decl -> - let name = to_decoder_name type_decl.ptype_name.txt in + let name = name_to_decoder_name type_decl.ptype_name in pvar ~loc:type_decl.ptype_name.loc name) rest in @@ -376,10 +394,10 @@ let rec mutual_rec_fun_gen ~loc let dec = [%stri let [%p var] = [%e imple_as_lambda]] in let substi = pexp_apply ~loc - (evar ~loc (to_decoder_name type_decl.ptype_name.txt)) + (evar ~loc (name_to_decoder_name type_decl.ptype_name)) (List.map (fun decl -> - (Nolabel, evar ~loc (to_decoder_name decl.ptype_name.txt))) + (Nolabel, evar ~loc (name_to_decoder_name decl.ptype_name))) rest) in let new_substitution = @@ -421,10 +439,10 @@ let str_gens ~(loc : location) ~(path : label) let _path = path in match (really_recursive rec_flag type_decls, type_decls) with | Nonrecursive, _ -> - List.(flatten (map (single_type_decoder_gen ~loc ~rec_flag) type_decls)) + CCList.flat_map (single_type_decoder_gen ~loc ~rec_flag) type_decls | Recursive, [ type_decl ] -> - wrap_27 @@ single_type_decoder_gen ~loc ~rec_flag type_decl + Utils.wrap_27 @@ single_type_decoder_gen ~loc ~rec_flag type_decl | Recursive, _type_decls -> - wrap_27 + Utils.wrap_27 @@ mutual_rec_fun_gen ~substitutions:[] ~loc type_decls @ fix_mutual_rec_funs ~loc type_decls diff --git a/src/encoders_deriver.ml b/src/encoders_deriver.ml new file mode 100644 index 0000000..0692f0e --- /dev/null +++ b/src/encoders_deriver.ml @@ -0,0 +1,252 @@ +open Ppxlib + +let to_encoder_name i = i ^ "_encoder" + +let rec flatten_longident ~loc = function + | Lident txt -> txt + | Ldot (longident, txt) -> flatten_longident longident ~loc ^ "." ^ txt + | Lapply (fst, snd) -> + Location.raise_errorf ~loc "Cannot handle functors:%s (%s)" + (flatten_longident ~loc fst) + (flatten_longident ~loc snd) + +let longident_to_encoder_name ~loc = + CCFun.(to_encoder_name % flatten_longident ~loc) + +let name_to_encoder_name (i : string loc) = to_encoder_name i.txt + +let rec expr_of_typ (typ : core_type) : expression = + let loc = { typ.ptyp_loc with loc_ghost = true } in + match typ with + | [%type: unit] | [%type: unit] -> Ast_builder.Default.evar ~loc "E.null" + | [%type: int] -> Ast_builder.Default.evar ~loc "E.int" + | [%type: int32] + | [%type: Int32.t] + | [%type: int64] + | [%type: Int64.t] + | [%type: nativeint] + | [%type: Nativeint.t] -> + failwith "Cannot yet handle any int-like but int" + | [%type: float] -> Ast_builder.Default.evar ~loc "E.float" + | [%type: bool] -> Ast_builder.Default.evar ~loc "E.bool" + | [%type: char] -> + failwith "Cannot directly handle character; please cast to string first" + | [%type: string] | [%type: String.t] -> + Ast_builder.Default.evar ~loc "E.string" + | [%type: bytes] | [%type: Bytes.t] -> + failwith "Cannot handle Bytes" (* TODO: figure out strategy *) + | [%type: [%t? inner_typ] list] -> + let list_encoder = Ast_builder.Default.evar ~loc "E.list" in + let sub_expr = expr_of_typ inner_typ in + Ast_helper.Exp.apply ~loc list_encoder [ (Nolabel, sub_expr) ] + | [%type: [%t? inner_typ] array] -> + let array_encoder = Ast_builder.Default.evar ~loc "E.array" in + let sub_expr = expr_of_typ inner_typ in + Ast_helper.Exp.apply ~loc array_encoder [ (Nolabel, sub_expr) ] + | [%type: [%t? inner_typ] option] -> + let opt_encoder = Ast_builder.Default.evar ~loc "E.nullable" in + let sub_expr = expr_of_typ (* ~substitutions *) inner_typ in + Ast_helper.Exp.apply ~loc opt_encoder [ (Nolabel, sub_expr) ] + | { ptyp_desc = Ptyp_tuple typs; _ } -> expr_of_tuple ~loc typs + | { ptyp_desc = Ptyp_var var; _ } -> + Ast_builder.Default.evar ~loc @@ to_encoder_name var + | { ptyp_desc = Ptyp_constr ({ txt = Lident lid; _ }, []); _ } -> + (* The assumption here is that if we get to this point, this type is recursive, and + we just assume that we already have an encoder available. + TODO: Is this really the case? + *) + Ast_builder.Default.evar ~loc (to_encoder_name lid) + | { ptyp_desc = Ptyp_constr ({ txt = longident; loc }, args); _ } -> + let cstr_dec = + Ast_builder.Default.evar ~loc + @@ longident_to_encoder_name ~loc longident + in + + let arg_decs = CCList.map expr_of_typ args in + Ast_builder.Default.eapply ~loc cstr_dec arg_decs + | _ -> + Location.raise_errorf ~loc "Cannot construct encoder for %s" + (string_of_core_type typ) + +and expr_of_tuple ~loc (* ~substitutions ?lift *) typs = + (* Want to take type a * b * c and produce + fun (arg1,arg2,arg3) -> E.list E.value [E.a arg1; E.b arg2; E.c arg3] + *) + let typ_encoders_exprs = List.map expr_of_typ (* ~substitutions *) typs in + let eargs = + CCList.mapi + (fun idx _typ -> Ast_builder.Default.evar ~loc @@ Utils.argn idx) + typs + in + let encoded_args = + Ast_builder.Default.elist ~loc + @@ CCList.map2 + (fun encoder arg -> [%expr [%e encoder] [%e arg]]) + typ_encoders_exprs eargs + in + + let encoder_result = [%expr E.list E.value [%e encoded_args]] in + [%expr [%e encoder_result]] + +and expr_of_record ~loc (* ~substitutions ?lift *) label_decls = + (* To help understand what this function is doing, imagine we had + a type [type t = {i : int; s : string}]. Then this will render the encoder: + let t_encoder : t E.encoder = + fun {i; s} -> E.obj [("i", int i); ("s", string s)] + *) + let encode_field { pld_name; pld_type; _ } = + Ast_builder.Default.( + pexp_tuple ~loc + [ + estring ~loc pld_name.txt; + eapply ~loc (expr_of_typ pld_type) [ evar ~loc pld_name.txt ]; + ]) + in + let encode_all = + let open Ast_builder.Default in + eapply ~loc (evar ~loc "E.obj") + @@ [ elist ~loc (CCList.map encode_field label_decls) ] + in + encode_all + +and expr_of_constr_arg ~loc (arg : constructor_arguments) = + match arg with + | Pcstr_tuple tups -> expr_of_tuple ~loc tups + | Pcstr_record labl_decls -> expr_of_record ~loc labl_decls + +and expr_of_constr_decl + ({ pcd_args; pcd_loc = loc; _ } as cstr_decl : constructor_declaration) = + (* We assume at this point that the decomposition into indiviaul fields is handled by caller *) + let cstr_name = Ast_builder.Default.estring ~loc cstr_decl.pcd_name.txt in + let encoded_args = + match pcd_args with + | Pcstr_tuple [] -> [%expr E.null] + | Pcstr_tuple [ single ] -> + let enc = expr_of_typ single in + let on = Ast_builder.Default.evar ~loc (Utils.argn 0) in + [%expr [%e enc] [%e on]] + | _ -> expr_of_constr_arg ~loc pcd_args + in + + [%expr E.obj [ ([%e cstr_name], [%e encoded_args]) ]] + +and expr_of_variant ~loc cstrs = + (* Producing from type `A | B of b | C of c` + to + function + | A -> {"A":null} + | B b -> {"B": b_encoder b} + | C c - {"C": c_encoder c} + *) + let open Ast_builder.Default in + let to_case (cstr : constructor_declaration) = + let inner_pattern = + match cstr.pcd_args with + | Pcstr_tuple [] -> None + | Pcstr_tuple [ _tuple ] -> Some (pvar ~loc (Utils.argn 0)) + | Pcstr_tuple tuples -> + Some + (ppat_tuple ~loc + @@ CCList.mapi (fun i _tup -> pvar ~loc (Utils.argn i)) tuples) + | Pcstr_record lbl_decls -> + let arg_fields = + CCList.map + (fun { pld_name; _ } -> + ( { txt = Lident pld_name.txt; loc }, + Ast_builder.Default.pvar ~loc + (*TODO: is this right loc*) pld_name.txt )) + lbl_decls + in + Some (Ast_builder.Default.ppat_record ~loc arg_fields Closed) + in + + let vpat = + ppat_construct ~loc (Utils.lident_of_constructor_decl cstr) inner_pattern + in + let enc_expression = expr_of_constr_decl cstr in + case ~lhs:vpat ~guard:None ~rhs:enc_expression + in + let cases = List.map to_case cstrs in + pexp_function ~loc cases + +let implementation_generator ~(loc : location) type_decl : expression = + let _name = to_encoder_name type_decl.ptype_name.txt in + let imple_expr = + match (type_decl.ptype_kind, type_decl.ptype_manifest) with + | Ptype_abstract, Some manifest -> ( + let expr = expr_of_typ manifest in + match manifest with + | { ptyp_desc = Ptyp_tuple typs; _ } -> + (* In the case of a top level tuple, we need to explicitly wrap in a lambda with + the arguments + *) + let args = + Ast_builder.Default.ppat_tuple ~loc + @@ CCList.mapi + (fun i _typ -> Ast_builder.Default.pvar ~loc (Utils.argn i)) + typs + in + [%expr fun [%p args] -> [%e expr]] + | _ -> expr) + | Ptype_variant cstrs, None -> expr_of_variant ~loc cstrs + | Ptype_record label_decs, _ -> + (* And in the case of a top-level record, we also need to explicitly wrap in a lambda with args *) + let arg_fields = + CCList.map + (fun { pld_name; _ } -> + ( { txt = Lident pld_name.txt; loc }, + Ast_builder.Default.pvar ~loc + (*TODO: is this right loc*) pld_name.txt )) + label_decs + in + let args = Ast_builder.Default.ppat_record ~loc arg_fields Closed in + let expr = expr_of_record ~loc label_decs in + [%expr fun [%p args] -> [%e expr]] + | Ptype_open, _ -> Location.raise_errorf ~loc "Unhandled open" + | _ -> Location.raise_errorf ~loc "Unhandled mystery" + in + imple_expr + +let single_type_encoder_gen ~(loc : location) type_decl = + let imple = implementation_generator ~loc type_decl in + let name = to_encoder_name type_decl.ptype_name.txt in + let pat = Ast_builder.Default.pvar ~loc name in + let params = + (* TODO: can we drop the non type vars? What are these? *) + CCList.filter_map + (fun (param, _) -> + match param.ptyp_desc with Ptyp_var var -> Some var | _ -> None) + type_decl.ptype_params + in + let args = + CCList.rev + @@ CCList.map + (fun param -> Ast_builder.Default.pvar ~loc (to_encoder_name param)) + params + in + let imple = + (* We need the type variables to become arguments *) + CCList.fold_left + (fun impl arg -> [%expr fun [%p arg] -> [%e impl]]) + imple args + in + Ast_builder.Default.value_binding ~loc ~pat ~expr:imple +(* [%str let [%p Ast_builder.Default.pvar ~loc name] = [%e imple]] *) + +let str_gens ~(loc : location) ~(path : label) + ((rec_flag : rec_flag), type_decls) : structure_item list = + let _path = path in + let rec_flag = really_recursive rec_flag type_decls in + + (* CCList.flat_map (single_type_encoder_gen ~loc ~rec_flag) type_decls *) + match (really_recursive rec_flag type_decls, type_decls) with + | Nonrecursive, _ -> + [ + (Ast_builder.Default.pstr_value ~loc Nonrecursive + @@ List.(map (single_type_encoder_gen ~loc) type_decls)); + ] + | Recursive, type_decls -> + [ + (Ast_builder.Default.pstr_value ~loc Recursive + @@ List.(map (single_type_encoder_gen ~loc) type_decls)); + ] diff --git a/src/ppx_deriving_decoders.ml b/src/ppx_deriving_decoders.ml index f5ed423..0dfe62e 100644 --- a/src/ppx_deriving_decoders.ml +++ b/src/ppx_deriving_decoders.ml @@ -1,8 +1,13 @@ open Ppxlib -let name = "decoders" +let () = + let name = "decoders" in + let str_type_decl = Deriving.Generator.make_noarg Decoders_deriver.str_gens in + (* let sig_type_decl = Deriving.Generator.make_noarg sig_gen in *) + Deriving.add name ~str_type_decl (* ~sig_type_decl *) |> Deriving.ignore let () = - let str_type_decl = Deriving.Generator.make_noarg Expander.str_gens in + let name = "encoders" in + let str_type_decl = Deriving.Generator.make_noarg Encoders_deriver.str_gens in (* let sig_type_decl = Deriving.Generator.make_noarg sig_gen in *) Deriving.add name ~str_type_decl (* ~sig_type_decl *) |> Deriving.ignore diff --git a/src/utils.ml b/src/utils.ml new file mode 100644 index 0000000..5ef9ca6 --- /dev/null +++ b/src/utils.ml @@ -0,0 +1,22 @@ +open Ppxlib + +let generate_attribute v ~loc = + let open Ast_builder.Default in + pstr_attribute ~loc + (attribute ~loc + ~name:(Located.mk ~loc "ocaml.warning") + ~payload:(PStr [ pstr_eval ~loc (estring ~loc v) [] ])) + +let suppress_warning_27 ~loc = generate_attribute ~loc "-27" +let enforce_warning_27 ~loc = generate_attribute ~loc "+27" + +let wrap_27 xs = + (suppress_warning_27 ~loc:Location.none :: xs) + @ [ enforce_warning_27 ~loc:Location.none ] + +let lident_of_constructor_decl (cd : constructor_declaration) = + let loc = cd.pcd_name.loc in + let name = cd.pcd_name.txt in + Ast_builder.Default.Located.lident ~loc name + +let argn = Printf.sprintf "arg%d" diff --git a/test/dummy.ml b/test/dummy.ml index 27cd62d..05747b2 100644 --- a/test/dummy.ml +++ b/test/dummy.ml @@ -1,5 +1,12 @@ module D = Decoders_yojson.Safe.Decode +module Blah (E : Decoders.Encode.S) = struct + type int_wrap = int + and int_list = int list + and str = string (* and int_str = int * string *) [@@deriving encoders] + + [@@@deriving.end] +end (* type my_list = Null | L of my_list [@@deriving_inline decoders] *) (* [@@@deriving.end] *) diff --git a/test/dune b/test/dune index 3afb581..3cbeb54 100644 --- a/test/dune +++ b/test/dune @@ -1,5 +1,5 @@ (library (name test) (inline_tests) - (libraries decoders decoders-yojson) + (libraries decoders containers decoders-yojson) (preprocess (pps ppx_deriving_decoders ppx_inline_test ))) diff --git a/test/test.ml b/test/test_decoders.ml similarity index 79% rename from test/test.ml rename to test/test_decoders.ml index e00a9fd..169574c 100644 --- a/test/test.ml +++ b/test/test_decoders.ml @@ -49,6 +49,9 @@ type a1 = { l : b1 option; m : c1 option } and b1 = { n : c1 } and c1 = { o : a1 } [@@deriving decoders] +type 'a record_wrapper = { wrapped : 'a } [@@deriving decoders] +type int_record_wrapper = int record_wrapper [@@deriving decoders] + let%test "int" = match D.decode_string my_int_decoder "1234" with | Ok i -> i = 1234 @@ -118,7 +121,7 @@ let%test "deep tuple" = | Error _ -> false let%test "basic constructor" = - match D.decode_string my_basic_cstr_decoder {|{"Int": [10]}|} with + match D.decode_string my_basic_cstr_decoder {|{"Int": 10}|} with | Ok b -> b = Int 10 | Error _ -> false @@ -141,7 +144,7 @@ let%test "basic record" = let%test "complex record" = match D.decode_string my_complex_record_decoder - {|{"basic" : {"i": 10}, "cstr": {"Int": [10]}}|} + {|{"basic" : {"i": 10}, "cstr": {"Int": 10}}|} with | Ok b -> b = { basic = { i = 10 }; cstr = Int 10 } | Error _ -> false @@ -156,7 +159,7 @@ let%test "simple constructor-less variant" = | _ -> false let%test "mixed constructor/less variant" = - (match D.decode_string status_decoder {|{"Online": [10]}|} with + (match D.decode_string status_decoder {|{"Online": 10}|} with | Ok (Online 10) -> true | _ -> false) && @@ -169,9 +172,12 @@ let%test "my list" = | Ok Null -> true | _ -> false) && - match D.decode_string my_list_decoder {|{"L": [{"Null": {}}]}|} with + match D.decode_string my_list_decoder {|{"L": {"Null": {}}}|} with | Ok (L Null) -> true - | _ -> false + | Ok _ -> false + | Error e -> + print_endline @@ D.string_of_error e; + false let%test "variant w/ record constructor" = (match D.decode_string constr_w_rec_decoder {|{"Empty": null}|} with @@ -220,8 +226,8 @@ let%test "expression mutually-recursive decoder" = D.decode_string expr_decoder {|{"BinOp" : [ {"Add": {}}, - {"BinOp" : [{"Div": {}}, {"Num": [10]}, {"Num": [5]}]}, - {"BinOp" : [{"Mul": {}}, {"Num": [10]}, {"Num": [3]}]} + {"BinOp" : [{"Div": {}}, {"Num": 10}, {"Num": 5}]}, + {"BinOp" : [{"Mul": {}}, {"Num": 10}, {"Num": 3}]} ]}|} with | Ok (BinOp (Add, BinOp (Div, Num 10, Num 5), BinOp (Mul, Num 10, Num 3))) -> @@ -230,3 +236,40 @@ let%test "expression mutually-recursive decoder" = | Error e -> print_endline @@ D.string_of_error e; false + +let%test "simple type var" = + match D.decode_string int_record_wrapper_decoder {|{"wrapped":-2389}|} with + | Ok { wrapped = -2389 } -> true + | _ -> false + +module Blah = struct + type t = int [@@deriving decoders] +end + +type blah_wrapped = Blah.t record_wrapper [@@deriving decoders] + +let%test "basic module-wrapped type" = + match D.decode_string blah_wrapped_decoder {|{"wrapped":10110}|} with + | Ok { wrapped = 10110 } -> true + | _ -> false + +module Outer = struct + module Inner = struct + type t = string [@@deriving decoders] + end +end + +type outer_inner_wrapped = Outer.Inner.t record_wrapper [@@deriving decoders] + +let%test "basic module-wrapped type" = + match D.decode_string outer_inner_wrapped_decoder {|{"wrapped":"value"}|} with + | Ok { wrapped = "value" } -> true + | _ -> false + +type ('a, 'b) double_wrap = { fst : 'a; snd : 'b } [@@deriving decoders] +type double_wrapped = (string, int) double_wrap [@@deriving decoders] + +let%test "double type var" = + match D.decode_string double_wrapped_decoder {|{"fst":"99","snd":100}|} with + | Ok { fst = "99"; snd = 100 } -> true + | _ -> false diff --git a/test/test_encoders.ml b/test/test_encoders.ml new file mode 100644 index 0000000..aeab430 --- /dev/null +++ b/test/test_encoders.ml @@ -0,0 +1,100 @@ +module E = Decoders_yojson.Safe.Encode + +type int_wrap = int [@@deriving encoders] +type int_list = int list [@@deriving encoders] +type int_array = int array [@@deriving encoders] +type wrapped_int = { int : int } [@@deriving encoders] +type wrapped_int_string = { i : int; s : string } [@@deriving encoders] +type int_string = int * string [@@deriving encoders] +type basic_recur = Empty | Rec of basic_recur [@@deriving encoders] + +type expr = Num of int | BinOp of op * expr * expr +and op = Add | Sub | Mul | Div [@@deriving encoders] + +type vars = + | Int of int + | Str of string + | Tup of int * string + | Rec of { i : int; s : string } + | Nothing +[@@deriving encoders] + +let%test "int_wrap" = + match E.encode_string int_wrap_encoder 1234 with "1234" -> true | _ -> false + +let%test "int_list" = + match E.encode_string int_list_encoder [ 1; 2; 3; 4 ] with + | {|[1,2,3,4]|} -> true + | _ -> false + +let%test "int_array" = + match E.encode_string int_array_encoder [| 1; 2; 3; 4 |] with + | {|[1,2,3,4]|} -> true + | _ -> false + +let%test "wrapped_int" = + match E.encode_string wrapped_int_encoder { int = 101 } with + | {|{"int":101}|} -> true + | _ -> false + +let%test "wrapped_int_string_string" = + match E.encode_string wrapped_int_string_encoder { i = -10; s = "super" } with + | {|{"i":-10,"s":"super"}|} -> true + | _ -> false + +let%test "int_string" = + match E.encode_string int_string_encoder (15, "the string") with + | {|[15,"the string"]|} -> true + | _ -> false + +let%test "vars" = + (match E.encode_string vars_encoder (Int 10) with + | {|{"Int":10}|} -> true + | _ -> false) + && (match E.encode_string vars_encoder (Str "something") with + | {|{"Str":"something"}|} -> true + | _ -> false) + && (match E.encode_string vars_encoder (Tup (43, "another")) with + | {|{"Tup":[43,"another"]}|} -> true + | _ -> false) + && (match E.encode_string vars_encoder (Rec { i = -43; s = "inner" }) with + | {|{"Rec":{"i":-43,"s":"inner"}}|} -> true + | _ -> false) + && + match E.encode_string vars_encoder Nothing with + | {|{"Nothing":null}|} -> true + | _ -> false + +let%test "basic_recursion" = + match E.encode_string basic_recur_encoder (Rec (Rec Empty)) with + | {|{"Rec":{"Rec":{"Empty":null}}}|} -> true + | _ -> false + +type 'a record_wrapper = { wrapped : 'a } [@@deriving encoders] +type int_record_wrapper = int record_wrapper [@@deriving encoders] + +let%test "basic type var" = + match E.encode_string int_record_wrapper_encoder { wrapped = 9876 } with + | {|{"wrapped":9876}|} -> true + | _ -> false + +type ('a, 'b) double_wrap = { fst : 'a; snd : 'b } [@@deriving encoders] +type double_wrapped = (string, int) double_wrap [@@deriving encoders] + +let%test "double type var" = + match E.encode_string double_wrapped_encoder { fst = "9"; snd = 10 } with + | {|{"fst":"9","snd":10}|} -> true + | _ -> false + +module Outer = struct + module Inner = struct + type t = string [@@deriving encoders] + end +end + +type outer_inner_wrapped = Outer.Inner.t record_wrapper [@@deriving encoders] + +let%test "module wrapped" = + match E.encode_string outer_inner_wrapped_encoder { wrapped = "a thing" } with + | {|{"wrapped":"a thing"}|} -> true + | _ -> false diff --git a/test/test_invert.ml b/test/test_invert.ml new file mode 100644 index 0000000..6b69772 --- /dev/null +++ b/test/test_invert.ml @@ -0,0 +1,91 @@ +module D = Decoders_yojson.Safe.Decode +module E = Decoders_yojson.Safe.Encode + +type expr = + | Int of int + | Real of float + | Add of expr * expr + | Sub of expr * expr + | Mul of expr * expr + | Div of expr * expr +[@@deriving decoders, encoders] + +let make_id enc dec = + let str_enc = E.encode_string enc in + let str_dec = D.decode_string dec in + CCFun.(str_dec % str_enc) + +let check id v = match id v with Ok v' when v = v' -> true | _ -> false +let expr_id = make_id expr_encoder expr_decoder + +let expr_id_print v = + let str_enc = E.encode_string expr_encoder in + let str_dec = D.decode_string expr_decoder in + let s = str_enc v in + print_endline s; + str_dec s + +let%test "expr_inv:1" = check expr_id (Int 10) + +let%test "expr_inv:2" = + check expr_id + (Add + ( Add (Int 10, Int (-5)), + Div (Real 1.4929, Sub (Real (-5392.1239230), Int 58292349823)) )) + +type my_list = Null | L of my_list [@@deriving decoders, encoders] + +let my_list_id = make_id my_list_encoder my_list_decoder +let%test "my_list:1" = check my_list_id Null +let%test "my_list:2" = check my_list_id (L (L (L (L Null)))) + +type a_rec = { b : b_rec option } +and b_rec = { a : a_rec option } [@@deriving decoders, encoders] + +let a_rec_id = make_id a_rec_encoder a_rec_decoder +let%test "a_rec:1" = check a_rec_id { b = None } +let%test "a_rec:2" = check a_rec_id { b = Some { a = None } } + +let%test "a_rec:3" = + check a_rec_id { b = Some { a = Some { b = Some { a = None } } } } + +(* More complex mutual recursive type *) +type a1 = { l : b1 option; m : c1 option } +and b1 = { n : c1 } +and c1 = { o : a1 } [@@deriving decoders, encoders] + +let a1_id = make_id a1_encoder a1_decoder +let%test "a1:1" = check a1_id { l = None; m = None } + +let%test "a1:2" = + check a1_id + { + l = + Some + { n = { o = { l = None; m = Some { o = { l = None; m = None } } } } }; + m = None; + } + +(* Random usage of modules intermixed with type vars *) + +module Outer = struct + type ('a, 'b, 'c) v = Fst of 'a list | Snd of 'b option | Trd of 'c + [@@deriving decoders, encoders] + + [@@@deriving.end] + + module Inner = struct + type t = int [@@deriving decoders, encoders] + end +end + +type nesting = (string, Outer.Inner.t, bool) Outer.v +[@@deriving decoders, encoders] + +let nesting_id = make_id nesting_encoder nesting_decoder + +let%test "nesting:1" = + check nesting_id (Fst [ "there"; "is"; "some"; "string" ]) + +let%test "nesting:2" = check nesting_id (Snd (Some 10)) +let%test "nesting:3" = check nesting_id (Trd false)