Skip to content

Commit

Permalink
Lazy evaluation of implies operator
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirsch committed Jan 15, 2025
1 parent 0eea271 commit a085665
Showing 1 changed file with 74 additions and 23 deletions.
97 changes: 74 additions & 23 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,23 @@ def get_bitwuzla(self, tm):
return self.bitwuzla

class Values:
false = None
true = None

def __init__(self, sid_line):
self.sid_line = sid_line
self.values = {}

def FALSE():
if Values.false is None:
Values.false = Values(Bool.boolean).set_value(Bool.boolean, 0, Constant.true)
return Values.false

def TRUE():
if Values.true is None:
Values.true = Values(Bool.boolean).set_value(Bool.boolean, 1, Constant.true)
return Values.true

def AND(arg1_line, arg2_line):
if arg1_line == Constant.true and arg2_line == Constant.true:
return Constant.true
Expand Down Expand Up @@ -405,6 +418,17 @@ def NOT(arg1_line):
return Unary(next_nid(), OP_NOT, Bool.boolean,
arg1_line, arg1_line.comment, arg1_line.line_no)

def IMPLIES(arg1_line, arg2_line):
if arg1_line == Constant.false or arg2_line == Constant.true:
return Constant.true
elif arg1_line == Constant.true:
return arg2_line
elif arg2_line == Constant.false:
return Values.NOT(arg1_line)
else:
return Implies(next_nid(), OP_IMPLIES, Bool.boolean,
arg1_line, arg2_line, arg1_line.comment, arg1_line.line_no)

def constrain(self, constraining_line):
if constraining_line == Constant.true:
return self
Expand All @@ -431,21 +455,26 @@ def merge(self, values):
results.set_value(values.sid_line, value, constraint)
return results

def get_boolean_constraints(self):
assert isinstance(self.sid_line, Bool)
assert len(self.values) <= 2
false_line = Constant.false
true_line = Constant.false
for value in self.values:
constraint_line = self.values[value]
if value == 0:
false_line = constraint_line
else:
assert value == 1
true_line = constraint_line
return false_line, true_line

def get_expression(self):
# naive transition from domain propagation to bit blasting
assert len(self.values) > 0
if isinstance(self.sid_line, Bool):
assert len(self.values) <= 2
false_line = Constant.false
true_line = Constant.false
for value in self.values:
constraint_line = self.values[value]
if value == 0:
false_line = Values.NOT(constraint_line)
else:
assert value == 1
true_line = constraint_line
exp_line = Values.OR(false_line, true_line)
# constraint on false value implies constraint on true value
return Values.IMPLIES(*self.get_boolean_constraints())
else:
exp_line = None
for value in self.values:
Expand All @@ -464,8 +493,8 @@ def get_expression(self):
else:
exp_line = Constd(next_nid(), self.sid_line, value,
constraint_line.comment, constraint_line.line_no)
assert exp_line is not None
return exp_line
assert exp_line is not None
return exp_line

def set_value(self, sid_line, value, constraint_line):
assert self.sid_line == sid_line
Expand Down Expand Up @@ -900,6 +929,7 @@ class Ext(Indexed):

def __init__(self, nid, op, sid_line, arg1_line, w, comment, line_no):
super().__init__(nid, sid_line, arg1_line, comment, line_no)
assert op in Ext.keywords
self.op = op
self.w = w
if sid_line.size != arg1_line.sid_line.size + w:
Expand Down Expand Up @@ -1011,6 +1041,7 @@ class Unary(Expression):

def __init__(self, nid, op, sid_line, arg1_line, comment, line_no):
super().__init__(nid, sid_line, arg1_line.domain, comment, line_no)
assert op in Unary.keywords
self.op = op
self.arg1_line = arg1_line
if not isinstance(arg1_line, Expression):
Expand Down Expand Up @@ -1110,6 +1141,7 @@ class Binary(Expression):

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
super().__init__(nid, sid_line, arg1_line.domain | arg2_line.domain, comment, line_no)
assert op in Binary.keywords
self.op = op
self.arg1_line = arg1_line
self.arg2_line = arg2_line
Expand Down Expand Up @@ -1155,6 +1187,7 @@ class Implies(Binary):
keyword = OP_IMPLIES

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op == Implies.keyword
super().__init__(nid, Implies.keyword, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(sid_line, Bool):
raise model_error("Boolean result", line_no)
Expand All @@ -1174,6 +1207,27 @@ def propagate(self, arg1_value, arg2_value):
Values.AND(constraint1_line, constraint2_line))
return results

def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
if isinstance(arg1_value, Values):
false_line, true_line = arg1_value.get_boolean_constraints()
if false_line == Constant.true:
self.cache_values[step] = Values.TRUE()
return self.cache_values[step]
else:
# lazy evaluation of implied value
arg2_value = self.arg2_line.get_values(step)
if isinstance(arg2_value, Values):
self.cache_values[step] = self.propagate(arg1_value, arg2_value)
return self.cache_values[step]
else:
arg2_value = self.arg2_line.get_values(step)
arg1_value = arg1_value.get_expression()
arg2_value = arg2_value.get_expression()
self.cache_values[step] = self.copy(arg1_value, arg2_value)
return self.cache_values[step]

def get_z3(self):
if self.z3 is None:
self.z3 = z3.Implies(self.arg1_line.get_z3(), self.arg2_line.get_z3())
Expand All @@ -1189,6 +1243,7 @@ class Comparison(Binary):
keywords = {OP_EQ, OP_NEQ, OP_SGT, OP_UGT, OP_SGTE, OP_UGTE, OP_SLT, OP_ULT, OP_SLTE, OP_ULTE}

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op in Comparison.keywords
super().__init__(nid, op, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(sid_line, Bool):
raise model_error("Boolean result", line_no)
Expand Down Expand Up @@ -1309,6 +1364,7 @@ class Logical(Binary):
keywords = {OP_AND, OP_OR, OP_XOR}

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op in Logical.keywords
super().__init__(nid, op, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(sid_line, Bitvector):
raise model_error("Boolean or bitvector result", line_no)
Expand Down Expand Up @@ -1363,6 +1419,7 @@ class Computation(Binary):
keywords = {OP_SLL, OP_SRL, OP_SRA, OP_ADD, OP_SUB, OP_MUL, OP_SDIV, OP_UDIV, OP_SREM, OP_UREM}

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op in Computation.keywords
super().__init__(nid, op, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(sid_line, Bitvec):
raise model_error("bitvector result", line_no)
Expand Down Expand Up @@ -1427,6 +1484,7 @@ class Concat(Binary):
keyword = OP_CONCAT

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op == Concat.keyword
super().__init__(nid, Concat.keyword, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(sid_line, Bitvec):
raise model_error("bitvector result", line_no)
Expand Down Expand Up @@ -1464,6 +1522,7 @@ class Read(Binary):
READ_ARRAY_ITERATIVELY = True

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, comment, line_no):
assert op == Read.keyword
super().__init__(nid, Read.keyword, sid_line, arg1_line, arg2_line, comment, line_no)
if not isinstance(arg1_line.sid_line, Array):
raise model_error("array first operand", line_no)
Expand Down Expand Up @@ -1551,6 +1610,7 @@ class Ternary(Expression):

def __init__(self, nid, op, sid_line, arg1_line, arg2_line, arg3_line, comment, line_no):
super().__init__(nid, sid_line, arg1_line.domain | arg2_line.domain | arg3_line.domain, comment, line_no)
assert op in Ternary.keywords
self.op = op
self.arg1_line = arg1_line
self.arg2_line = arg2_line
Expand Down Expand Up @@ -1618,16 +1678,7 @@ def get_values(self, step):
if step not in self.cache_values:
arg1_value = self.arg1_line.get_values(step)
if isinstance(arg1_value, Values):
assert len(arg1_value.values) <= 2
false_line = Constant.false
true_line = Constant.false
for value in arg1_value.values:
constraint_line = arg1_value.values[value]
if value == 0:
false_line = constraint_line
else:
assert value == 1
true_line = constraint_line
false_line, true_line = arg1_value.get_boolean_constraints()
if false_line == Constant.false:
arg2_value = self.arg2_line.get_values(step)
if isinstance(arg2_value, Values):
Expand Down

0 comments on commit a085665

Please sign in to comment.