From 0e88ca917e471eadf085805b71f9de83a0d760d7 Mon Sep 17 00:00:00 2001 From: "Jeremy G. Siek" Date: Wed, 18 Dec 2024 15:23:37 -0500 Subject: [PATCH] an improved fix for issue #42 --- abstract_syntax.py | 13 +++++++++++++ deduce.py | 10 ++++++++-- proof_checker.py | 30 +++++++++++++----------------- 3 files changed, 34 insertions(+), 19 deletions(-) diff --git a/abstract_syntax.py b/abstract_syntax.py index c6b62cd..6560932 100644 --- a/abstract_syntax.py +++ b/abstract_syntax.py @@ -648,6 +648,8 @@ def is_match(pattern, arg, subst): if constr == Var(loc3, ty3, name, rs) and len(params) == len(args): for (k,v) in zip(params, args): subst[k] = v + if isinstance(v, TermInst): + v.inferred = False ret = True else: ret = False @@ -655,6 +657,8 @@ def is_match(pattern, arg, subst): if constr == Var(loc3, ty3, name, rs) and len(params) == len(args): for (k,v) in zip(params, args): subst[k] = v + if isinstance(v, TermInst): + v.inferred = False ret = True else: ret = False @@ -808,6 +812,9 @@ def reduce(self, env): ret = Call(self.location, self.typeof, fun, args, self.infix) case Lambda(loc, ty, vars, body): subst = {k: v for ((k,t),v) in zip(vars, args)} + for (k,v) in subst.items(): + if isinstance(v, TermInst): + v.inferred = False body_env = env new_body = body.substitute(subst) old_defs = get_reduce_only() @@ -827,6 +834,9 @@ def reduce(self, env): subst[x] = ty for (k,v) in zip(fun_case.parameters, rest_args): subst[k] = v + for (k,v) in subst.items(): + if isinstance(v, TermInst): + v.inferred = False new_fun_case_body = fun_case.body.substitute(subst) old_defs = get_reduce_only() reduce_defs = [x for x in old_defs] @@ -858,6 +868,9 @@ def reduce(self, env): body_env = env for (k,v) in zip(fun_case.parameters, rest_args): subst[k] = v + for (k,v) in subst.items(): + if isinstance(v, TermInst): + v.inferred = False new_fun_case_body = fun_case.body.substitute(subst) old_defs = get_reduce_only() reduce_defs = [x for x in old_defs] diff --git a/deduce.py b/deduce.py index 6c463bb..90f6719 100644 --- a/deduce.py +++ b/deduce.py @@ -113,9 +113,15 @@ def deduce_directory(directory, recursive_directories): print("Couldn't find a file to deduce!") exit(1) - # Start deducing - sys.setrecursionlimit(5000) # We can probably use a loop for some tail recursive functions + sys.setrecursionlimit(10000) + # We can probably use a loop for some tail recursive functions + # And even the non-tail recursive functions can be turned into a + # loop by using an explicit stack. But these alternatives would + # hurt the readability of the code and increase the maintenance + # burden. So when you hit the recursion limit, just bump the number + # higher. + # Start deducing parser.set_deduce_directory(os.path.dirname(sys.argv[0])) rec_desc_parser.set_deduce_directory(os.path.dirname(sys.argv[0])) parser.init_parser() diff --git a/proof_checker.py b/proof_checker.py index 92a036f..1755168 100644 --- a/proof_checker.py +++ b/proof_checker.py @@ -504,6 +504,8 @@ def check_proof(proof, env): v, ty = var try: new_arg = type_check_term(arg, ty.substitute(sub), env, None, []) + if isinstance(new_arg, TermInst): + new_arg.inferred = False except Exception as e: if isinstance(ty, TypeType): error(loc, f"In instantiation of\n\t{str(univ)} : {str(allfrm)}\n" \ @@ -1251,12 +1253,9 @@ def check_proof_of(proof, formula, env): trm = pattern_to_term(indcase.pattern) new_trm = type_check_term(trm, typ, body_env, None, []) - # The following type synthesis step is because the term may get - # inserted into a synthesis context, and if its - # a TermInst, it needs to be marked as not-inferred so that it - # gets printed. -Jeremy - newer_trm = type_synth_term(new_trm, body_env, None, []) - pre_goal = instantiate(loc, formula, newer_trm) + if isinstance(new_trm, TermInst): + new_trm.inferred = False + pre_goal = instantiate(loc, formula, new_trm) goal = check_formula(pre_goal, body_env) for ((x,frm1),frm2) in zip(indcase.induction_hypotheses, induction_hypotheses): @@ -1349,7 +1348,8 @@ def check_proof_of(proof, formula, env): constr_params)) new_subject_case = type_check_term(subject_case, ty, body_env, None, []) - new_subject_case = type_synth_term(new_subject_case, body_env, None, []) + if isinstance(new_subject_case, TermInst): + new_subject_case.inferred = False assumptions = [(label,check_formula(asm, body_env) if asm else None) for (label,asm) in scase.assumptions] if len(assumptions) == 1: @@ -1732,7 +1732,7 @@ def type_first_letter(typ): print('error in type_first_letter: unhandled type ' + repr(typ)) exit(-1) -def type_check_term_inst(loc, subject, tyargs, inferred, synth): +def type_check_term_inst(loc, subject, tyargs, inferred): for ty in tyargs: check_type(ty, env) new_subject = type_synth_term(subject, env, recfun, subterms) @@ -1747,13 +1747,11 @@ def type_check_term_inst(loc, subject, tyargs, inferred, synth): retty = FunctionType(loc2, [], inst_param_types, inst_return_type) case GenericUnknownInst(loc2, union_type): retty = TypeInst(loc2, union_type, tyargs) - if synth: - inferred = False case _: error(loc, 'expected a type name, not ' + str(ty)) return TermInst(loc, retty, new_subject, tyargs, inferred) -def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env, synth): +def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env): match subject_var: case Var(loc2, tyof, name, rs): for ty in tyargs: @@ -1769,8 +1767,6 @@ def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env, synth): retty = FunctionType(loc3, [], inst_param_types, inst_return_type) case GenericUnknownInst(loc3, union_type): retty = TypeInst(loc3, union_type, tyargs) - if synth: - inferred = False case _: error(loc, 'cannot instantiate a term of type ' + str(ty)) return TermInst(loc, retty, Var(loc2, tyof, rs[0], [rs[0]]), tyargs, inferred) @@ -1963,10 +1959,10 @@ def process_case(c, result_type, cases_present): case TermInst(loc, _, Var(loc2, tyof, name, rs), tyargs, inferred): ret = type_check_term_inst_var(loc, Var(loc2, tyof, name, rs), tyargs, - inferred, env, True) + inferred, env) case TermInst(loc, _, subject, tyargs, inferred): - ret = type_check_term_inst(loc, subject, tyargs, inferred, True) + ret = type_check_term_inst(loc, subject, tyargs, inferred) case TAnnote(loc, tyof, subject, typ): check_type(typ, env) @@ -2134,10 +2130,10 @@ def process_case(c, result_type, cases_present): case TermInst(loc, _, Var(loc2, tyof, name, rs), tyargs, inferred): return type_check_term_inst_var(loc, Var(loc2, tyof, name, rs), tyargs, - inferred, env, False) + inferred, env) case TermInst(loc, _, subject, tyargs, inferred): - return type_check_term_inst(loc, subject, tyargs, inferred, False) + return type_check_term_inst(loc, subject, tyargs, inferred) case _: if get_verbose():