Skip to content

Commit

Permalink
fixes to array get
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiek committed Jan 7, 2025
1 parent 2c3bde7 commit 8e7c500
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Deduce.lark
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ ident: IDENT -> ident
| "@" term_hi "<" type_list ">" -> term_inst
| "array" "(" term ")" -> make_array
| term_hi "(" term_list ")" -> call
| term_hi "[" INT "]" -> array_get
| term_hi "[" term "]" -> array_get
| "λ" var_list "{" term "}" -> lambda
| "fun" var_list "{" term "}" -> lambda
| "generic" ident_list "{" term "}" -> generic
Expand Down
23 changes: 12 additions & 11 deletions abstract_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def copy(self):
return ArrayType(self.location, self.elt_type.copy())

def __str__(self):
return '[' + (self.elt_type) + ']'
return '[' + str(self.elt_type) + ']'

def __eq__(self, other):
match other:
Expand Down Expand Up @@ -567,7 +567,7 @@ def reduce(self, env):
print('\t var ' + self.name + ' ===> ' + str(res))
return res.reduce(env)
else:
return self
return self
else:
return self

Expand Down Expand Up @@ -1182,44 +1182,45 @@ def uniquify(self, env):
@dataclass
class ArrayGet(Term):
subject: Term
index: int
position: Term

def __eq__(self, other):
if isinstance(other, ArrayGet):
return self.subject == other.subject \
and self.index == other.index
and self.position == other.position
else:
return False

def copy(self):
return ArrayGet(self.location, self.typeof,
self.subject.copy(), self.index)
self.subject.copy(), self.position.copy())

def __str__(self):
return str(self.subject) + '[' + str(self.index) + ']'
return str(self.subject) + '[' + str(self.position) + ']'

def reduce(self, env):
subject_red = self.subject.reduce(env)
index_red = self.index.reduce(env)
position_red = self.position.reduce(env)
match subject_red:
case Array(loc2, _, elements):
if isNat(index_red):
index = natToInt(index_red)
if isNat(position_red):
index = natToInt(position_red)
if 0 <= index and index < len(elements):
return elements[index].reduce(env)
else:
error(self.location, 'array index out of bounds\n' \
+ 'index: ' + str(index) + '\n' \
+ 'array length: ' + str(len(elements)))
return ArrayGet(self.location, self.typeof, subject_red, index_red)
return ArrayGet(self.location, self.typeof, subject_red, position_red)

def substitute(self, sub):
return ArrayGet(self.location, self.typeof,
self.subject.substitute(sub),
self.index)
self.position.substitute(sub))

def uniquify(self, env):
self.subject.uniquify(env)
self.position.uniquify(env)

@dataclass
class TLet(Term):
Expand Down
2 changes: 1 addition & 1 deletion parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def parse_tree_to_ast(e, parent):
elif e.data == 'array_get':
return ArrayGet(e.meta, None,
parse_tree_to_ast(e.children[0], e),
intToNat(e.meta, int(e.children[1])))
parse_tree_to_ast(e.children[1], e))
elif e.data == 'make_array':
return MakeArray(e.meta, None,
parse_tree_to_ast(e.children[0], e))
Expand Down
10 changes: 7 additions & 3 deletions proof_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,8 @@ def check_type(typ, env):
check_type(ty, env)
case GenericUnknownInst(loc, typ):
check_type(typ, env)
case ArrayType(loc, elt_type):
check_type(elt_type, env)
case _:
print('error in check_type: unhandled type ' + repr(typ) + ' ' + str(type(typ)))
exit(-1)
Expand Down Expand Up @@ -1932,9 +1934,10 @@ def type_synth_term(term, env, recfun, subterms):

case ArrayGet(loc, _, array, index):
new_array = type_synth_term(array, env, recfun, subterms)
new_index = type_synth_term(index, env, recfun, subterms)
match new_array.typeof:
case ArrayType(loc2, elt_type):
ret = ArrayGet(loc, elt_type, new_array, index)
ret = ArrayGet(loc, elt_type, new_array, new_index)
case _:
error(loc, 'expected an array, not ' + str(new_array.typeof))

Expand Down Expand Up @@ -2558,11 +2561,12 @@ def check_deduce(ast, module_name):
print(s)

if get_verbose():
print('env:\n' + str(env))
print('--------- Proof Checking ------------------------')
for s in ast3:
env = collect_env(s, env)

if get_verbose():
print('env:\n' + str(env))

if module_name not in checked_modules:
for s in ast3:
check_proofs(s, env)
Expand Down
6 changes: 2 additions & 4 deletions rec_desc_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,14 @@ def parse_term_hi():

def parse_array_get():
while_parsing = 'while parsing array access\n' \
+ '\tterm ::= term "[" integer "]"\n'
+ '\tterm ::= term "[" term "]"\n'
term = parse_term_hi()

while (not end_of_file()) and current_token().type == 'LSQB':
try:
start_token = current_token()
advance()
index = intToNat(meta_from_tokens(current_token(),current_token()),
int(current_token().value))
advance()
index = parse_term()
if current_token().type != 'RSQB':
error(meta_from_tokens(start_token, current_token()),
'expected closing "]", not\n\t' \
Expand Down
11 changes: 11 additions & 0 deletions test/should-pass/array3.pf
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import Nat
import List

define L = [1,2,3]
define A = array(L)
define i = 0
define j = 1
define k = 2
assert A[i] = 1
assert A[j] = 2
assert A[k] = 3

0 comments on commit 8e7c500

Please sign in to comment.