Skip to content

Commit

Permalink
an improved fix for issue #42
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiek committed Dec 18, 2024
1 parent 4e814d3 commit 0e88ca9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 19 deletions.
13 changes: 13 additions & 0 deletions abstract_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,13 +648,17 @@ def is_match(pattern, arg, subst):
if constr == Var(loc3, ty3, name, rs) and len(params) == len(args):
for (k,v) in zip(params, args):
subst[k] = v
if isinstance(v, TermInst):
v.inferred = False
ret = True
else:
ret = False
case TermInst(loc4, tyof, Var(loc3, ty3, name, rs), tyargs):
if constr == Var(loc3, ty3, name, rs) and len(params) == len(args):
for (k,v) in zip(params, args):
subst[k] = v
if isinstance(v, TermInst):
v.inferred = False
ret = True
else:
ret = False
Expand Down Expand Up @@ -808,6 +812,9 @@ def reduce(self, env):
ret = Call(self.location, self.typeof, fun, args, self.infix)
case Lambda(loc, ty, vars, body):
subst = {k: v for ((k,t),v) in zip(vars, args)}
for (k,v) in subst.items():
if isinstance(v, TermInst):
v.inferred = False
body_env = env
new_body = body.substitute(subst)
old_defs = get_reduce_only()
Expand All @@ -827,6 +834,9 @@ def reduce(self, env):
subst[x] = ty
for (k,v) in zip(fun_case.parameters, rest_args):
subst[k] = v
for (k,v) in subst.items():
if isinstance(v, TermInst):
v.inferred = False
new_fun_case_body = fun_case.body.substitute(subst)
old_defs = get_reduce_only()
reduce_defs = [x for x in old_defs]
Expand Down Expand Up @@ -858,6 +868,9 @@ def reduce(self, env):
body_env = env
for (k,v) in zip(fun_case.parameters, rest_args):
subst[k] = v
for (k,v) in subst.items():
if isinstance(v, TermInst):
v.inferred = False
new_fun_case_body = fun_case.body.substitute(subst)
old_defs = get_reduce_only()
reduce_defs = [x for x in old_defs]
Expand Down
10 changes: 8 additions & 2 deletions deduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,15 @@ def deduce_directory(directory, recursive_directories):
print("Couldn't find a file to deduce!")
exit(1)

# Start deducing
sys.setrecursionlimit(5000) # We can probably use a loop for some tail recursive functions
sys.setrecursionlimit(10000)
# We can probably use a loop for some tail recursive functions
# And even the non-tail recursive functions can be turned into a
# loop by using an explicit stack. But these alternatives would
# hurt the readability of the code and increase the maintenance
# burden. So when you hit the recursion limit, just bump the number
# higher.

# Start deducing
parser.set_deduce_directory(os.path.dirname(sys.argv[0]))
rec_desc_parser.set_deduce_directory(os.path.dirname(sys.argv[0]))
parser.init_parser()
Expand Down
30 changes: 13 additions & 17 deletions proof_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ def check_proof(proof, env):
v, ty = var
try:
new_arg = type_check_term(arg, ty.substitute(sub), env, None, [])
if isinstance(new_arg, TermInst):
new_arg.inferred = False
except Exception as e:
if isinstance(ty, TypeType):
error(loc, f"In instantiation of\n\t{str(univ)} : {str(allfrm)}\n" \
Expand Down Expand Up @@ -1251,12 +1253,9 @@ def check_proof_of(proof, formula, env):

trm = pattern_to_term(indcase.pattern)
new_trm = type_check_term(trm, typ, body_env, None, [])
# The following type synthesis step is because the term may get
# inserted into a synthesis context, and if its
# a TermInst, it needs to be marked as not-inferred so that it
# gets printed. -Jeremy
newer_trm = type_synth_term(new_trm, body_env, None, [])
pre_goal = instantiate(loc, formula, newer_trm)
if isinstance(new_trm, TermInst):
new_trm.inferred = False
pre_goal = instantiate(loc, formula, new_trm)
goal = check_formula(pre_goal, body_env)

for ((x,frm1),frm2) in zip(indcase.induction_hypotheses, induction_hypotheses):
Expand Down Expand Up @@ -1349,7 +1348,8 @@ def check_proof_of(proof, formula, env):
constr_params))

new_subject_case = type_check_term(subject_case, ty, body_env, None, [])
new_subject_case = type_synth_term(new_subject_case, body_env, None, [])
if isinstance(new_subject_case, TermInst):
new_subject_case.inferred = False

assumptions = [(label,check_formula(asm, body_env) if asm else None) for (label,asm) in scase.assumptions]
if len(assumptions) == 1:
Expand Down Expand Up @@ -1732,7 +1732,7 @@ def type_first_letter(typ):
print('error in type_first_letter: unhandled type ' + repr(typ))
exit(-1)

def type_check_term_inst(loc, subject, tyargs, inferred, synth):
def type_check_term_inst(loc, subject, tyargs, inferred):
for ty in tyargs:
check_type(ty, env)
new_subject = type_synth_term(subject, env, recfun, subterms)
Expand All @@ -1747,13 +1747,11 @@ def type_check_term_inst(loc, subject, tyargs, inferred, synth):
retty = FunctionType(loc2, [], inst_param_types, inst_return_type)
case GenericUnknownInst(loc2, union_type):
retty = TypeInst(loc2, union_type, tyargs)
if synth:
inferred = False
case _:
error(loc, 'expected a type name, not ' + str(ty))
return TermInst(loc, retty, new_subject, tyargs, inferred)

def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env, synth):
def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env):
match subject_var:
case Var(loc2, tyof, name, rs):
for ty in tyargs:
Expand All @@ -1769,8 +1767,6 @@ def type_check_term_inst_var(loc, subject_var, tyargs, inferred, env, synth):
retty = FunctionType(loc3, [], inst_param_types, inst_return_type)
case GenericUnknownInst(loc3, union_type):
retty = TypeInst(loc3, union_type, tyargs)
if synth:
inferred = False
case _:
error(loc, 'cannot instantiate a term of type ' + str(ty))
return TermInst(loc, retty, Var(loc2, tyof, rs[0], [rs[0]]), tyargs, inferred)
Expand Down Expand Up @@ -1963,10 +1959,10 @@ def process_case(c, result_type, cases_present):

case TermInst(loc, _, Var(loc2, tyof, name, rs), tyargs, inferred):
ret = type_check_term_inst_var(loc, Var(loc2, tyof, name, rs), tyargs,
inferred, env, True)
inferred, env)

case TermInst(loc, _, subject, tyargs, inferred):
ret = type_check_term_inst(loc, subject, tyargs, inferred, True)
ret = type_check_term_inst(loc, subject, tyargs, inferred)

case TAnnote(loc, tyof, subject, typ):
check_type(typ, env)
Expand Down Expand Up @@ -2134,10 +2130,10 @@ def process_case(c, result_type, cases_present):

case TermInst(loc, _, Var(loc2, tyof, name, rs), tyargs, inferred):
return type_check_term_inst_var(loc, Var(loc2, tyof, name, rs), tyargs,
inferred, env, False)
inferred, env)

case TermInst(loc, _, subject, tyargs, inferred):
return type_check_term_inst(loc, subject, tyargs, inferred, False)
return type_check_term_inst(loc, subject, tyargs, inferred)

case _:
if get_verbose():
Expand Down

0 comments on commit 0e88ca9

Please sign in to comment.