Skip to content

Commit

Permalink
try to correctly implement fixed point simplification for child expre…
Browse files Browse the repository at this point in the history
…ssions
  • Loading branch information
ekiwi committed Nov 21, 2024
1 parent f3dad07 commit 041f99c
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ easy-smt = "0.2.1"
# used for simulator initialization
rand = { version = "0.8.5", default-features = false }
rand_xoshiro = "0.6.0"
smallvec = "1.11.2"
smallvec = { version = "1.x", features = ["union"] }
baa = "0.14.4"
boolean_expression = "0.4.4"
egg = "0.9.5"
Expand Down
66 changes: 58 additions & 8 deletions src/expr/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,30 @@ where
fn insert(&mut self, e: ExprRef, data: T);
}

pub fn extract_fixed_point(data: &impl ExprMetaData<Option<ExprRef>>, mut key: ExprRef) -> ExprRef {
// TODO: actually update data in order to speed up future lookups, similar to union find
loop {
let value = data[key].unwrap();
if value == key {
return value;
}
key = value;
/// finds the fixed point value and updates values it discovers along the way
pub fn get_fixed_point<T: ExprMetaData<Option<ExprRef>>>(
m: &mut T,
key: ExprRef,
) -> Option<ExprRef> {
// fast path without updating any pointers
if key == m[key]? {
return Some(key);
}

// pointer chasing, similar to union find, but not the asymptotically fast path halving version
let mut value = key;
while value != m[value]? {
value = m[value]?;
}
// update pointers
let final_value = value;
value = key;
while value != final_value {
let next = m[value]?;
m.insert(value, Some(final_value));
value = next;
}
Some(value)
}

/// A sparse hash map to stare meta-data related to each expression
Expand All @@ -45,19 +60,22 @@ pub struct SparseExprMetaData<T: Default + Clone + Debug> {
impl<T: Default + Clone + Debug> Index<ExprRef> for SparseExprMetaData<T> {
type Output = T;

#[inline]
fn index(&self, e: ExprRef) -> &Self::Output {
self.inner.get(&e).unwrap_or(&self.default)
}
}

impl<T: Default + Clone + Debug> ExprMetaData<T> for SparseExprMetaData<T> {
#[inline]
fn iter<'a>(&'a self) -> impl Iterator<Item = (ExprRef, &'a T)>
where
T: 'a,
{
self.inner.iter().map(|(k, v)| (*k, v))
}

#[inline]
fn insert(&mut self, e: ExprRef, data: T) {
self.inner.insert(e, data);
}
Expand All @@ -72,6 +90,7 @@ pub struct DenseExprMetaData<T: Default + Clone + Debug> {
}

impl<T: Default + Clone + Debug> DenseExprMetaData<T> {
#[inline]
pub fn into_vec(self) -> Vec<T> {
self.inner
}
Expand All @@ -80,12 +99,14 @@ impl<T: Default + Clone + Debug> DenseExprMetaData<T> {
impl<T: Default + Clone + Debug> Index<ExprRef> for DenseExprMetaData<T> {
type Output = T;

#[inline]
fn index(&self, e: ExprRef) -> &Self::Output {
self.inner.get(e.index()).unwrap_or(&self.default)
}
}

impl<T: Default + Clone + Debug> ExprMetaData<T> for DenseExprMetaData<T> {
#[inline]
fn iter<'a>(&'a self) -> impl Iterator<Item = (ExprRef, &'a T)>
where
T: 'a,
Expand All @@ -96,6 +117,7 @@ impl<T: Default + Clone + Debug> ExprMetaData<T> for DenseExprMetaData<T> {
}
}

#[inline]
fn insert(&mut self, e: ExprRef, data: T) {
if self.inner.len() <= e.index() {
self.inner.resize(e.index(), T::default());
Expand All @@ -114,6 +136,7 @@ struct ExprMetaDataIter<'a, T> {
impl<'a, T> Iterator for ExprMetaDataIter<'a, T> {
type Item = (ExprRef, &'a T);

#[inline]
fn next(&mut self) -> Option<Self::Item> {
match self.inner.next() {
None => None,
Expand All @@ -132,6 +155,7 @@ pub struct DenseExprMetaDataBool {
inner: Vec<u64>,
}

#[inline]
fn index_to_word_and_bit(index: ExprRef) -> (usize, u32) {
let index = index.index();
let word = index / Word::BITS as usize;
Expand All @@ -142,6 +166,7 @@ fn index_to_word_and_bit(index: ExprRef) -> (usize, u32) {
impl Index<ExprRef> for DenseExprMetaDataBool {
type Output = bool;

#[inline]
fn index(&self, index: ExprRef) -> &Self::Output {
let (word_idx, bit) = index_to_word_and_bit(index);
let word = self.inner.get(word_idx).cloned().unwrap_or_default();
Expand All @@ -154,6 +179,7 @@ impl Index<ExprRef> for DenseExprMetaDataBool {
}

impl ExprMetaData<bool> for DenseExprMetaDataBool {
#[inline]
fn iter<'a>(&'a self) -> impl Iterator<Item = (ExprRef, &'a bool)>
where
bool: 'a,
Expand All @@ -165,6 +191,7 @@ impl ExprMetaData<bool> for DenseExprMetaDataBool {
}
}

#[inline]
fn insert(&mut self, e: ExprRef, data: bool) {
let (word_idx, bit) = index_to_word_and_bit(e);
if self.inner.len() <= word_idx {
Expand Down Expand Up @@ -211,3 +238,26 @@ impl<'a> Iterator for ExprMetaBoolIter<'a> {

const TRU: bool = true;
const FALS: bool = false;

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_get_fixed_point() {
let mut m = DenseExprMetaData::default();
let zero = ExprRef::from_index(0);
let one = ExprRef::from_index(1);
let two = ExprRef::from_index(2);
m.insert(zero, Some(one));
m.insert(one, Some(two));
m.insert(two, Some(two));

assert_eq!(get_fixed_point(&mut m, two), Some(two));
assert_eq!(get_fixed_point(&mut m, one), Some(two));
assert_eq!(get_fixed_point(&mut m, zero), Some(two));
// our current implementation updates the whole path
assert_eq!(m[zero], Some(two));
assert_eq!(m[one], Some(two));
}
}
4 changes: 2 additions & 2 deletions src/expr/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
do_transform_expr, BVLitValue, Context, Expr, ExprMetaData, ExprRef, SparseExprMetaData,
TypeCheck, WidthInt,
};
use crate::expr::meta::extract_fixed_point;
use crate::expr::meta::get_fixed_point;
use crate::expr::transform::ExprTransformMode;
use baa::BitVecOps;

Expand Down Expand Up @@ -35,7 +35,7 @@ impl<T: ExprMetaData<Option<ExprRef>>> Simplifier<T> {
vec![e],
simplify,
);
extract_fixed_point(&self.cache, e)
get_fixed_point(&mut self.cache, e).unwrap()
}
}

Expand Down
28 changes: 20 additions & 8 deletions src/expr/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::expr::meta::get_fixed_point;
use crate::expr::*;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
Expand All @@ -10,6 +11,7 @@ pub enum ExprTransformMode {
FixedPoint,
}

#[inline]
pub(crate) fn do_transform_expr<T: ExprMetaData<Option<ExprRef>>>(
ctx: &mut Context,
mode: ExprTransformMode,
Expand All @@ -25,7 +27,12 @@ pub(crate) fn do_transform_expr<T: ExprMetaData<Option<ExprRef>>>(
let mut children_changed = false; // track whether any of the children changed
let mut all_transformed = true; // tracks whether all children have been transformed or if there is more work to do
ctx.get(expr_ref).for_each_child(|c| {
match transformed[*c] {
let transformed_child = if mode == ExprTransformMode::FixedPoint {
get_fixed_point(transformed, *c)
} else {
transformed[*c]
};
match transformed_child {
Some(new_child_expr) => {
if new_child_expr != *c {
children_changed = true; // child changed
Expand All @@ -46,14 +53,9 @@ pub(crate) fn do_transform_expr<T: ExprMetaData<Option<ExprRef>>>(
}

// call out to the transform
let tran_res = (tran)(ctx, expr_ref, &children);
let tran_res = tran(ctx, expr_ref, &children);
let new_expr_ref = match tran_res {
Some(e) => {
if mode == ExprTransformMode::FixedPoint && transformed[e].is_none() {
todo.push(e);
}
e
}
Some(e) => e,
None => {
if children_changed {
update_expr_children(ctx, expr_ref, &children)
Expand All @@ -66,6 +68,16 @@ pub(crate) fn do_transform_expr<T: ExprMetaData<Option<ExprRef>>>(
};
// remember the transformed version
transformed.insert(expr_ref, Some(new_expr_ref));

// in fixed point mode, we might not be done yet
let is_at_fixed_point = expr_ref == new_expr_ref;
if mode == ExprTransformMode::FixedPoint
&& !is_at_fixed_point
&& transformed[new_expr_ref].is_none()
{
// see if we can further simplify the new expression
todo.push(new_expr_ref);
}
}
}

Expand Down

0 comments on commit 041f99c

Please sign in to comment.