diff --git a/lib/ast/astDef.ml b/lib/ast/astDef.ml index 745e1ad..d291c42 100644 --- a/lib/ast/astDef.ml +++ b/lib/ast/astDef.ml @@ -494,6 +494,10 @@ module Type = struct | _ -> Error.error type_attr.type_loc "Expected type variable" + let field_val = function + | App (Fld, [val_typ], _) -> val_typ + | _ -> failwith "Expected field type" + let set_elem = function | App (Map, [elem; App (Bool, _, _)], _) -> elem | _ -> failwith "Expected Set type" @@ -1088,6 +1092,39 @@ module Expr = struct | App (_, es, _) -> Set.union_list (module QualIdent) (List.map es ~f:au_preds) | Binder (_, _, _, e, _) -> au_preds e + + (** Lift quantifiers up, but only if no new quantifier alternations are introduced *) + (*let lift_quantifiers (expr: t) : t = + let rec merge sm zs xs ys ys2 = + match xs, ys with + | (x, typ1) :: xs1, (y, typ2) :: ys1 -> + if Type.(typ1 = typ2) + then merge (Map.add_exn ~key:x ~data:y sm) ((y, typ2) :: zs) xs1 (ys2 @ ys1) [] + else merge sm zs xs ys1 ((y, typ2) :: ys2) + | [], _ -> sm, ys @ ys2 @ zs + | _, [] -> + if List.is_empty ys2 then sm, xs @ zs + else merge sm (List.hd_exn xs :: zs) (List.tl_exn xs) ys2 [] + in + let rec lift_op_same loc tvs op b fs = + let fs_same, fs_diff = List.partition_map ~f:(function + | Binder (Exists, + let fs1, vs = + List.fold_right ~f:(fun f (fs2, vs2) -> + let f1, vs1 = lift tvs (mk_binder b tvs f) in + let sm, vs = merge (Map.empty (module QualIdent)) [] vs1 vs2 [] in + subst_idents sm f1 :: fs2, vs) + fs ~init:([], []) + in + match op with + | And -> mk_and ~loc fs1, vs + | Or -> mk_or ~loc fs1, vs + | _ -> assert false + in + + and lift tvs = function e -> e + in*) + let rec existential_vars_type ?(acc = Map.empty (module Ident)) ?(pol = true) (expr: t) : Type.t IdentMap.t = match expr with (* TODO: Biimplication? *) diff --git a/lib/ast/progUtils.ml b/lib/ast/progUtils.ml index 8d074bc..edf3ffe 100644 --- a/lib/ast/progUtils.ml +++ b/lib/ast/progUtils.ml @@ -206,45 +206,25 @@ let intros_type_module ~(loc : location) in *) introduce_typecheck_symbol ~loc ~f symbol -let rec does_symbol_implement_ra (symbol : AstDef.Module.symbol) : bool t = - (*Logs.debug (fun m -> - m "ProgUtils.does_symbol_implement_ra: symbol = %a" - AstDef.Ident.pr - (AstDef.Symbol.to_name symbol));*) - let open Syntax in - match symbol with - | ModDef mod_def -> - let mod_decl = mod_def.mod_decl in - return mod_decl.mod_decl_is_ra - | ModInst mod_inst -> ( - let* does_type_implement_ra = - let* mod_inst_type_symbol = - find_and_reify mod_inst.mod_inst_type - in - does_symbol_implement_ra mod_inst_type_symbol - in - - if does_type_implement_ra then return true - else - match mod_inst.mod_inst_def with - | None -> return false - | Some (mod_inst_def_funct, mod_inst_def_args) -> - let* mod_inst_def_funct_is_ra = - let* mod_inst_def_funct_symbol = - find_and_reify mod_inst_def_funct - in - does_symbol_implement_ra mod_inst_def_funct_symbol - in - - return mod_inst_def_funct_is_ra) - | _ -> return false - -let rec does_type_implement_ra (tp : AstDef.type_expr) : bool t = - let open Syntax in +let is_ra_type (tp : AstDef.type_expr) : bool t = + let open Syntax in + let rec does_ident_implement_ra qual_ident = + let* symbol = find qual_ident in + Symbol.extract symbol ~f:(fun subst -> function + | AstDef.Module.ModDef m -> return m.mod_decl.mod_decl_is_ra + | ModInst mod_inst -> + let* is_ra = does_ident_implement_ra mod_inst.mod_inst_type in + if is_ra then return true + else + (match mod_inst.mod_inst_def with + | None -> return false + | Some (mod_inst_def_funct, mod_inst_def_args) -> + does_ident_implement_ra mod_inst_def_funct) + | _ -> return false) + in match tp with | App (Var qi, [], _) -> - let* symbol = find_and_reify (QualIdent.pop qi) in - does_symbol_implement_ra symbol + does_ident_implement_ra (QualIdent.pop qi) | _ -> return false let field_get_ra_qual_iden (field : AstDef.Module.field_def) = @@ -255,12 +235,11 @@ let field_get_ra_qual_iden (field : AstDef.Module.field_def) = Error.error field.field_loc "ProgUtils.field_get_ra_module: Expected field definition" in - match field_type with | App (Var qual_iden, [], _) -> QualIdent.pop qual_iden | _ -> Error.error field.field_loc - "ProgUtils.field_get_ra_module: Expected field type to be a variable" + "ProgUtils.field_get_ra_module: Expected field type to be a type identifier" let pred_get_ra_qual_iden pred_qual_iden = let open Syntax in diff --git a/lib/frontend/rewrites/rewrites.ml b/lib/frontend/rewrites/rewrites.ml index 09fc051..e2e5740 100644 --- a/lib/frontend/rewrites/rewrites.ml +++ b/lib/frontend/rewrites/rewrites.ml @@ -1558,31 +1558,7 @@ let rec rewrite_frac_field_types (symbol : Module.symbol) : | CallDef _ -> Rewriter.return symbol | FieldDef f -> - let* is_field_an_ra = - match f.field_type with - | App (Fld, [ App (Var qual_iden, args, _) ], _) -> - assert (List.is_empty args); - - Logs.debug (fun m -> m - "Rewrites.rewrite_frac_field_types: \ - field_def: %a - field_qual_iden: %a" - Ident.pr f.field_name - QualIdent.pr qual_iden - ); - - let module_name = QualIdent.pop qual_iden in - - let* does_module_implement_ra = - let* module_symbol = - Rewriter.find_and_reify module_name - in - ProgUtils.does_symbol_implement_ra module_symbol - in - - Rewriter.return does_module_implement_ra - | _ -> Rewriter.return false - in + let* is_field_an_ra = ProgUtils.is_ra_type (Type.field_val f.field_type) in Logs.debug (fun m -> m "Rewrites.rewrite_frac_field_types: