Skip to content

Commit

Permalink
expr: replace get with [..]
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 2, 2024
1 parent 38e9293 commit 8e7433c
Show file tree
Hide file tree
Showing 18 changed files with 78 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "patronus"
version = "0.24.0"
version = "0.25.0"
edition = "2021"
authors = ["Kevin Laeufer <[email protected]>"]
description = "Hardware bug-finding toolkit."
Expand Down
2 changes: 1 addition & 1 deletion src/btor2/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ fn improve_state_names(ctx: &mut Context, sys: &mut TransitionSystem) {
// since the alias signal refers to the same expression as the state symbol,
// it will generate a signal info with the better name
if let Some(name_ref) = sys.names[state.symbol] {
let old_name_ref = ctx.get(state.symbol).get_symbol_name_ref().unwrap();
let old_name_ref = ctx[state.symbol].get_symbol_name_ref().unwrap();
if old_name_ref != name_ref {
renames.insert(state.symbol, name_ref);
}
Expand Down
8 changes: 4 additions & 4 deletions src/egraphs/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pub fn to_arith(ctx: &Context, e: ExprRef) -> egg::RecExpr<Arith> {
children.push(remove_ext(ctx, *c).0);
});
},
|_ctx, expr, children| match ctx.get(expr).clone() {
|_ctx, expr, children| match ctx[expr].clone() {
Expr::BVSymbol { name, width } => out.add(Arith::Symbol(name, width)),
Expr::BVAdd(a, b, width) => out.add(convert_bin_op(
ctx,
Expand Down Expand Up @@ -205,9 +205,9 @@ fn convert_bin_op(

/// Removes any sign or zero extend expressions and returns whether the removed extension was signed.
fn remove_ext(ctx: &Context, e: ExprRef) -> (ExprRef, bool) {
match ctx.get(e) {
Expr::BVZeroExt { e, .. } => (remove_ext(ctx, *e).0, false),
Expr::BVSignExt { e, .. } => (remove_ext(ctx, *e).0, true),
match ctx[e] {
Expr::BVZeroExt { e, .. } => (remove_ext(ctx, e).0, false),
Expr::BVSignExt { e, .. } => (remove_ext(ctx, e).0, true),
_ => (e, false),
}
}
Expand Down
19 changes: 12 additions & 7 deletions src/expr/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::borrow::Borrow;
use std::cell::RefCell;
use std::fmt::{Debug, Formatter};
use std::num::{NonZeroU16, NonZeroU32};
use std::ops::Index;

/// Uniquely identifies a [`String`] stored in a [`Context`].
#[derive(PartialEq, Eq, Clone, Copy, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -95,14 +96,8 @@ impl Default for Context {

/// Adding and removing nodes.
impl Context {
pub fn get(&self, reference: ExprRef) -> &Expr {
self.exprs
.get_index((reference.0.get() as usize) - 1)
.expect("Invalid ExprRef!")
}

pub fn get_symbol_name(&self, reference: ExprRef) -> Option<&str> {
self.get(reference).get_symbol_name(self)
self[reference].get_symbol_name(self)
}

pub(crate) fn add_expr(&mut self, value: Expr) -> ExprRef {
Expand Down Expand Up @@ -130,6 +125,16 @@ impl Context {
}
}

impl Index<ExprRef> for Context {
type Output = Expr;

fn index(&self, index: ExprRef) -> &Self::Output {
self.exprs
.get_index(index.index())
.expect("Invalid ExprRef!")
}
}

/// Convenience methods to construct IR nodes.
impl Context {
// helper functions to construct expressions
Expand Down
16 changes: 8 additions & 8 deletions src/expr/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ pub fn eval_bv_expr(
expr: ExprRef,
) -> BitVecValue {
debug_assert!(
ctx.get(expr).get_bv_type(ctx).is_some(),
ctx[expr].get_bv_type(ctx).is_some(),
"Not a bit-vector expression: {:?}",
ctx.get(expr)
ctx[expr]
);
let (mut bv_stack, array_stack) = eval_expr_internal(ctx, symbols, expr);
debug_assert!(array_stack.is_empty());
Expand All @@ -183,9 +183,9 @@ pub fn eval_array_expr(
expr: ExprRef,
) -> ArrayValue {
debug_assert!(
ctx.get(expr).get_array_type(ctx).is_some(),
ctx[expr].get_array_type(ctx).is_some(),
"Not an array expression: {:?}",
ctx.get(expr)
ctx[expr]
);
let (bv_stack, mut array_stack) = eval_expr_internal(ctx, symbols, expr);
debug_assert!(bv_stack.is_empty());
Expand All @@ -197,11 +197,11 @@ pub fn eval_expr(ctx: &Context, symbols: &(impl GetExprValue + ?Sized), expr: Ex
let (mut bv_stack, mut array_stack) = eval_expr_internal(ctx, symbols, expr);
debug_assert_eq!(bv_stack.len() + array_stack.len(), 1);
if let Some(value) = bv_stack.pop() {
debug_assert!(ctx.get(expr).is_bv_type());
debug_assert!(ctx[expr].is_bv_type());
Value::BitVec(value)
} else {
let value = array_stack.pop().unwrap();
debug_assert!(ctx.get(expr).is_array_type());
debug_assert!(ctx[expr].is_array_type());
Value::Array(value)
}
}
Expand All @@ -217,7 +217,7 @@ fn eval_expr_internal(

todo.push((expr, false));
while let Some((e, args_available)) = todo.pop() {
let expr = ctx.get(e);
let expr = &ctx[e];

// Check if there are children that we need to compute first.
if !args_available {
Expand Down Expand Up @@ -300,7 +300,7 @@ fn eval_expr_internal(
| Expr::BVSignedMod(_, _, _)
| Expr::BVSignedRem(_, _, _)
| Expr::BVUnsignedRem(_, _, _) => {
todo!("implement eval support for {:?}", ctx.get(e))
todo!("implement eval support for {:?}", ctx[e])
}
Expr::BVSub(_, _, _) => bin_op(&mut bv_stack, |a, b| a.sub(&b)),
// BVArrayRead needs array support!
Expand Down
4 changes: 2 additions & 2 deletions src/expr/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl<'a> Parser<'a> {
let new_sym = self.ctx.bv_symbol(name, width);
// compare width
if let Some(other) = self.symbols.get(name) {
let other_width = self.ctx.get(*other).get_bv_type(self.ctx).unwrap();
let other_width = self.ctx[*other].get_bv_type(self.ctx).unwrap();
assert_eq!(
width, other_width,
"Two symbols with same name {name} have different widths!"
Expand All @@ -147,7 +147,7 @@ impl<'a> Parser<'a> {
.symbols
.get(name)
.unwrap_or_else(|| panic!("symbol of unknown type: `{name}` @ {}", self.inp));
let width = self.ctx.get(other).get_bv_type(self.ctx).unwrap();
let width = self.ctx[other].get_bv_type(self.ctx).unwrap();
self.consume_c(&c);
Some(self.ctx.bv_symbol(name, width))
}
Expand Down
4 changes: 2 additions & 2 deletions src/expr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,12 @@ where
F: Fn(&ExprRef, &mut W) -> std::io::Result<bool>,
W: Write,
{
serialize_expr(ctx.get(*expr), ctx, writer, serialize_child)
serialize_expr(&ctx[*expr], ctx, writer, serialize_child)
}

impl SerializableIrNode for ExprRef {
fn serialize<W: Write>(&self, ctx: &Context, writer: &mut W) -> std::io::Result<()> {
ctx.get(*self).serialize(ctx, writer)
ctx[*self].serialize(ctx, writer)
}
}

Expand Down
49 changes: 23 additions & 26 deletions src/expr/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<T: ExprMap<Option<ExprRef>>> Simplifier<T> {

/// Simplifies one expression (not its children)
pub(crate) fn simplify(ctx: &mut Context, expr: ExprRef, children: &[ExprRef]) -> Option<ExprRef> {
match (ctx.get(expr).clone(), children) {
match (ctx[expr].clone(), children) {
(Expr::BVNot(_, _), [e]) => simplify_bv_not(ctx, *e),
(Expr::BVZeroExt { by, .. }, [e]) => simplify_bv_zero_ext(ctx, *e, by),
(Expr::BVSlice { lo, hi, .. }, [e]) => simplify_bv_slice(ctx, *e, hi, lo),
Expand Down Expand Up @@ -71,7 +71,7 @@ fn simplify_ite(ctx: &mut Context, cond: ExprRef, tru: ExprRef, fals: ExprRef) -
}

// constant condition
if let Expr::BVLiteral(value) = ctx.get(cond) {
if let Expr::BVLiteral(value) = ctx[cond] {
if value.get(ctx).is_fals() {
// ite(false, a, b) -> b
return Some(fals);
Expand All @@ -82,14 +82,11 @@ fn simplify_ite(ctx: &mut Context, cond: ExprRef, tru: ExprRef, fals: ExprRef) -
}

// boolean return type
let value_width = ctx.get(tru).get_bv_type(ctx).unwrap();
debug_assert_eq!(
ctx.get(fals).get_bv_type(ctx),
ctx.get(tru).get_bv_type(ctx)
);
let value_width = ctx[tru].get_bv_type(ctx).unwrap();
debug_assert_eq!(ctx[fals].get_bv_type(ctx), ctx[tru].get_bv_type(ctx));
if value_width == 1 {
// boolean value simplifications
match (ctx.get(tru), ctx.get(fals)) {
match (&ctx[tru], &ctx[fals]) {
(Expr::BVLiteral(vt), Expr::BVLiteral(vf)) => {
let res = match (
vt.get(ctx).to_bool().unwrap(),
Expand Down Expand Up @@ -137,7 +134,7 @@ enum Lits {
/// Finds the maximum number of literals. Only works on commutative operations.
#[inline]
fn find_lits_commutative(ctx: &Context, a: ExprRef, b: ExprRef) -> Lits {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
(Expr::BVLiteral(va), Expr::BVLiteral(vb)) => Lits::Two(*va, *vb),
(Expr::BVLiteral(va), _) => Lits::One((*va, a), b),
(_, Expr::BVLiteral(vb)) => Lits::One((*vb, b), a),
Expand All @@ -147,7 +144,7 @@ fn find_lits_commutative(ctx: &Context, a: ExprRef, b: ExprRef) -> Lits {

#[inline]
fn find_one_concat(ctx: &Context, a: ExprRef, b: ExprRef) -> Option<(ExprRef, ExprRef, ExprRef)> {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
(Expr::BVConcat(c_a, c_b, _), _) => Some((*c_a, *c_b, b)),
(_, Expr::BVConcat(c_a, c_b, _)) => Some((*c_a, *c_b, a)),
_ => None,
Expand Down Expand Up @@ -180,8 +177,8 @@ fn simplify_bv_equal(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRe

// check to see if we are comparing to a concat
if let Some((concat_a, concat_b, other)) = find_one_concat(ctx, a, b) {
let a_width = ctx.get(concat_a).get_bv_type(ctx).unwrap();
let b_width = ctx.get(concat_b).get_bv_type(ctx).unwrap();
let a_width = ctx[concat_a].get_bv_type(ctx).unwrap();
let b_width = ctx[concat_b].get_bv_type(ctx).unwrap();
let width = a_width + b_width;
debug_assert_eq!(width, other.get_bv_type(ctx).unwrap());
let eq_a = ctx.build(|c| c.bv_equal(concat_a, c.slice(other, width - 1, width - a_width)));
Expand Down Expand Up @@ -213,7 +210,7 @@ fn simplify_bv_and(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
Some(expr)
} else {
// (a # b) & mask -> ((a & mask_upper) # (b & mask_lower))
if let Expr::BVConcat(a, b, width) = ctx.get(expr).clone() {
if let Expr::BVConcat(a, b, width) = ctx[expr].clone() {
let b_width = b.get_bv_type(ctx).unwrap();
debug_assert_eq!(width, b_width + a.get_bv_type(ctx).unwrap());
let a_mask = ctx.bv_lit(&lit.get(ctx).slice(width - 1, b_width));
Expand Down Expand Up @@ -253,7 +250,7 @@ fn simplify_bv_and(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
}
}
Lits::None => {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
// a & !a -> 0
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.zero(*w)),
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.zero(*w)),
Expand Down Expand Up @@ -289,7 +286,7 @@ fn simplify_bv_or(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
}
}
Lits::None => {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
// a | !a -> 1
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.ones(*w)),
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.ones(*w)),
Expand All @@ -304,7 +301,7 @@ fn simplify_bv_or(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
fn simplify_bv_xor(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
// a xor a -> 0
if a == b {
let width = ctx.get(a).get_bv_type(ctx).unwrap();
let width = ctx[a].get_bv_type(ctx).unwrap();
return Some(ctx.zero(width));
}

Expand All @@ -326,7 +323,7 @@ fn simplify_bv_xor(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
}
}
Lits::None => {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
// a xor !a -> 1
(Expr::BVNot(inner, w), _) if *inner == b => Some(ctx.ones(*w)),
(_, Expr::BVNot(inner, w)) if *inner == a => Some(ctx.ones(*w)),
Expand All @@ -337,7 +334,7 @@ fn simplify_bv_xor(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef>
}

fn simplify_bv_not(ctx: &mut Context, e: ExprRef) -> Option<ExprRef> {
match ctx.get(e) {
match &ctx[e] {
Expr::BVNot(inner, _) => Some(*inner), // double negation
Expr::BVLiteral(value) => Some(ctx.bv_lit(&value.get(ctx).not())),
_ => None,
Expand All @@ -348,7 +345,7 @@ fn simplify_bv_zero_ext(ctx: &mut Context, e: ExprRef, by: WidthInt) -> Option<E
if by == 0 {
Some(e)
} else {
match ctx.get(e) {
match &ctx[e] {
// zero extend constant
Expr::BVLiteral(value) => Some(ctx.bv_lit(&value.get(ctx).zero_extend(by))),
// normalize to concat(${by}'d0, e);
Expand All @@ -361,7 +358,7 @@ fn simplify_bv_sign_ext(ctx: &mut Context, e: ExprRef, by: WidthInt) -> Option<E
if by == 0 {
Some(e)
} else {
match ctx.get(e) {
match &ctx[e] {
Expr::BVLiteral(value) => Some(ctx.bv_lit(&value.get(ctx).sign_extend(by))),
Expr::BVSignExt {
e: inner_e,
Expand All @@ -374,14 +371,14 @@ fn simplify_bv_sign_ext(ctx: &mut Context, e: ExprRef, by: WidthInt) -> Option<E
}

fn simplify_bv_concat(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprRef> {
match (ctx.get(a).clone(), ctx.get(b).clone()) {
match (ctx[a].clone(), ctx[b].clone()) {
// normalize concat to be right recursive
(Expr::BVConcat(a_a, a_b, _), _) => Some(ctx.build(|c| c.concat(a_a, c.concat(a_b, b)))),
(Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
Some(ctx.bv_lit(&va.get(ctx).concat(&vb.get(ctx))))
}
(Expr::BVLiteral(va), Expr::BVConcat(b_a, b_b, _)) => {
if let Expr::BVLiteral(v_b_a) = *ctx.get(b_a) {
if let Expr::BVLiteral(v_b_a) = ctx[b_a] {
let lit = ctx.bv_lit(&va.get(ctx).concat(&v_b_a.get(ctx)));
Some(ctx.concat(lit, b_b))
} else {
Expand Down Expand Up @@ -413,7 +410,7 @@ fn simplify_bv_concat(ctx: &mut Context, a: ExprRef, b: ExprRef) -> Option<ExprR

fn simplify_bv_slice(ctx: &mut Context, e: ExprRef, hi: WidthInt, lo: WidthInt) -> Option<ExprRef> {
debug_assert!(hi >= lo);
match ctx.get(e).clone() {
match ctx[e].clone() {
// combine slices
Expr::BVSlice {
lo: inner_lo,
Expand Down Expand Up @@ -481,7 +478,7 @@ fn simplify_bv_shift_left(
b: ExprRef,
width: WidthInt,
) -> Option<ExprRef> {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
(Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
Some(ctx.bv_lit(&va.get(ctx).shift_left(&vb.get(ctx))))
}
Expand Down Expand Up @@ -511,7 +508,7 @@ fn simplify_bv_shift_right(
b: ExprRef,
width: WidthInt,
) -> Option<ExprRef> {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
(Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
Some(ctx.bv_lit(&va.get(ctx).shift_right(&vb.get(ctx))))
}
Expand Down Expand Up @@ -542,7 +539,7 @@ fn simplify_bv_arithmetic_shift_right(
b: ExprRef,
width: WidthInt,
) -> Option<ExprRef> {
match (ctx.get(a), ctx.get(b)) {
match (&ctx[a], &ctx[b]) {
(Expr::BVLiteral(va), Expr::BVLiteral(vb)) => {
Some(ctx.bv_lit(&va.get(ctx).arithmetic_shift_right(&vb.get(ctx))))
}
Expand Down
4 changes: 2 additions & 2 deletions src/expr/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub(crate) fn do_transform_expr<T: ExprMap<Option<ExprRef>>>(
children.clear();
let mut children_changed = false; // track whether any of the children changed
let mut all_transformed = true; // tracks whether all children have been transformed or if there is more work to do
ctx.get(expr_ref).for_each_child(|c| {
ctx[expr_ref].for_each_child(|c| {
let transformed_child = if mode == ExprTransformMode::FixedPoint {
get_fixed_point(transformed, *c)
} else {
Expand Down Expand Up @@ -82,7 +82,7 @@ pub(crate) fn do_transform_expr<T: ExprMap<Option<ExprRef>>>(
}

fn update_expr_children(ctx: &mut Context, expr_ref: ExprRef, children: &[ExprRef]) -> ExprRef {
let new_expr = match (ctx.get(expr_ref), children) {
let new_expr = match (&ctx[expr_ref], children) {
(Expr::BVSymbol { .. }, _) => panic!("No children, should never get here."),
(Expr::BVLiteral { .. }, _) => panic!("No children, should never get here."),
(Expr::BVZeroExt { by, width, .. }, [e]) => Expr::BVZeroExt {
Expand Down
4 changes: 2 additions & 2 deletions src/expr/traversal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub fn bottom_up_multi_pat<R>(
let mut child_vec = Vec::with_capacity(4);

while let Some((e, bottom_up)) = todo.pop() {
let expr = ctx.get(e);
let expr = &ctx[e];

// Check if there are children that we need to compute first.
if !bottom_up {
Expand Down Expand Up @@ -82,7 +82,7 @@ pub fn top_down(
while let Some(e) = todo.pop() {
let do_continue = f(ctx, e) == TraversalCmd::Continue;
if do_continue {
ctx.get(e).for_each_child(|&c| todo.push(c));
ctx[e].for_each_child(|&c| todo.push(c));
}
}
}
Loading

0 comments on commit 8e7433c

Please sign in to comment.