diff --git a/Deduce.lark b/Deduce.lark index 9bfaca3..cd3d00a 100644 --- a/Deduce.lark +++ b/Deduce.lark @@ -85,7 +85,7 @@ ident: IDENT -> ident | "@" term_hi "<" type_list ">" -> term_inst | "array" "(" term ")" -> make_array | term_hi "(" term_list ")" -> call - | term_hi "[" INT "]" -> array_get + | term_hi "[" term "]" -> array_get | "λ" var_list "{" term "}" -> lambda | "fun" var_list "{" term "}" -> lambda | "generic" ident_list "{" term "}" -> generic diff --git a/abstract_syntax.py b/abstract_syntax.py index 12d2bab..d9933a1 100644 --- a/abstract_syntax.py +++ b/abstract_syntax.py @@ -254,7 +254,7 @@ def copy(self): return ArrayType(self.location, self.elt_type.copy()) def __str__(self): - return '[' + (self.elt_type) + ']' + return '[' + str(self.elt_type) + ']' def __eq__(self, other): match other: @@ -567,7 +567,7 @@ def reduce(self, env): print('\t var ' + self.name + ' ===> ' + str(res)) return res.reduce(env) else: - return self + return self else: return self @@ -1182,44 +1182,45 @@ def uniquify(self, env): @dataclass class ArrayGet(Term): subject: Term - index: int + position: Term def __eq__(self, other): if isinstance(other, ArrayGet): return self.subject == other.subject \ - and self.index == other.index + and self.position == other.position else: return False def copy(self): return ArrayGet(self.location, self.typeof, - self.subject.copy(), self.index) + self.subject.copy(), self.position.copy()) def __str__(self): - return str(self.subject) + '[' + str(self.index) + ']' + return str(self.subject) + '[' + str(self.position) + ']' def reduce(self, env): subject_red = self.subject.reduce(env) - index_red = self.index.reduce(env) + position_red = self.position.reduce(env) match subject_red: case Array(loc2, _, elements): - if isNat(index_red): - index = natToInt(index_red) + if isNat(position_red): + index = natToInt(position_red) if 0 <= index and index < len(elements): return elements[index].reduce(env) else: error(self.location, 'array index out of bounds\n' \ + 'index: ' + str(index) + '\n' \ + 'array length: ' + str(len(elements))) - return ArrayGet(self.location, self.typeof, subject_red, index_red) + return ArrayGet(self.location, self.typeof, subject_red, position_red) def substitute(self, sub): return ArrayGet(self.location, self.typeof, self.subject.substitute(sub), - self.index) + self.position.substitute(sub)) def uniquify(self, env): self.subject.uniquify(env) + self.position.uniquify(env) @dataclass class TLet(Term): diff --git a/parser.py b/parser.py index cb414a0..cdbfa4e 100644 --- a/parser.py +++ b/parser.py @@ -222,7 +222,7 @@ def parse_tree_to_ast(e, parent): elif e.data == 'array_get': return ArrayGet(e.meta, None, parse_tree_to_ast(e.children[0], e), - intToNat(e.meta, int(e.children[1]))) + parse_tree_to_ast(e.children[1], e)) elif e.data == 'make_array': return MakeArray(e.meta, None, parse_tree_to_ast(e.children[0], e)) diff --git a/proof_checker.py b/proof_checker.py index 33d5a90..e3953be 100644 --- a/proof_checker.py +++ b/proof_checker.py @@ -1736,6 +1736,8 @@ def check_type(typ, env): check_type(ty, env) case GenericUnknownInst(loc, typ): check_type(typ, env) + case ArrayType(loc, elt_type): + check_type(elt_type, env) case _: print('error in check_type: unhandled type ' + repr(typ) + ' ' + str(type(typ))) exit(-1) @@ -1932,9 +1934,10 @@ def type_synth_term(term, env, recfun, subterms): case ArrayGet(loc, _, array, index): new_array = type_synth_term(array, env, recfun, subterms) + new_index = type_synth_term(index, env, recfun, subterms) match new_array.typeof: case ArrayType(loc2, elt_type): - ret = ArrayGet(loc, elt_type, new_array, index) + ret = ArrayGet(loc, elt_type, new_array, new_index) case _: error(loc, 'expected an array, not ' + str(new_array.typeof)) @@ -2558,11 +2561,12 @@ def check_deduce(ast, module_name): print(s) if get_verbose(): - print('env:\n' + str(env)) print('--------- Proof Checking ------------------------') for s in ast3: env = collect_env(s, env) - + if get_verbose(): + print('env:\n' + str(env)) + if module_name not in checked_modules: for s in ast3: check_proofs(s, env) diff --git a/rec_desc_parser.py b/rec_desc_parser.py index bf94739..963eca9 100644 --- a/rec_desc_parser.py +++ b/rec_desc_parser.py @@ -362,16 +362,14 @@ def parse_term_hi(): def parse_array_get(): while_parsing = 'while parsing array access\n' \ - + '\tterm ::= term "[" integer "]"\n' + + '\tterm ::= term "[" term "]"\n' term = parse_term_hi() while (not end_of_file()) and current_token().type == 'LSQB': try: start_token = current_token() advance() - index = intToNat(meta_from_tokens(current_token(),current_token()), - int(current_token().value)) - advance() + index = parse_term() if current_token().type != 'RSQB': error(meta_from_tokens(start_token, current_token()), 'expected closing "]", not\n\t' \ diff --git a/test/should-pass/array3.pf b/test/should-pass/array3.pf new file mode 100644 index 0000000..6dcde77 --- /dev/null +++ b/test/should-pass/array3.pf @@ -0,0 +1,11 @@ +import Nat +import List + +define L = [1,2,3] +define A = array(L) +define i = 0 +define j = 1 +define k = 2 +assert A[i] = 1 +assert A[j] = 2 +assert A[k] = 3