Skip to content

Commit

Permalink
egraph: split of rewrite code
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 20, 2024
1 parent 566d5c7 commit 086cfae
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 302 deletions.
345 changes: 44 additions & 301 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
// author: Kevin Laeufer <[email protected]>

use baa::BitVecOps;
use egg::{
define_language, ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr,
Rewrite, Subst, Var,
};
use egg::{define_language, Id, Language, RecExpr};
use patronus::expr::*;
use rustc_hash::FxHashMap;
use std::cmp::{max, Ordering};
Expand Down Expand Up @@ -36,64 +33,6 @@ define_language! {
}
}

/// our version of the egg re-write macro
macro_rules! arith_rewrite {
(
$name:expr;
$lhs:expr => $rhs:expr
) => {{
ArithRewrite::new::<&str>($name, $lhs, $rhs, [], None)
}};
(
$name:expr;
$lhs:expr => $rhs:expr;
if $vars:expr, $cond:expr
) => {{
ArithRewrite::new($name, $lhs, $rhs, $vars, Some($cond))
}};
}

/// Generate our ROVER inspired rewrite rules.
pub fn create_rewrites() -> Vec<ArithRewrite> {
vec![
// a + b => b + a
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
// (a << b) << x => a << (b + c)
arith_rewrite!("merge-left-shift";
// we require that b, c and (b + c) are all unsigned
// we do not want (b + c) to wrap, because in that case the result would always be zero
// the value being shifted has to be consistently signed or unsigned
"(<< ?wo ?wab ?sa (<< ?wab ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)" =>
"(<< ?wo ?wa ?sa ?a (max+1 ?wb ?wc) unsign (+ (max+1 ?wb ?wc) ?wb unsign ?b ?wc unsign ?c))";
// wab >= wo
if["?wo", "?wab"], |w| w[1] >= w[0]),
// a << (b + c) => (a << b) << x
arith_rewrite!("unmerge-left-shift";
// we require that b, c and (b + c) are all unsigned
// we do not want (b + c) to wrap, because in that case the result would always be zero
// the value being shifted has to be consistently signed or unsigned
"(<< ?wo ?wa ?sa ?a ?wbc unsign (+ ?wbc ?wb unsign ?b ?wc unsign ?c))" =>
"(<< ?wo ?wo ?sa (<< ?wo ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)";
// ?wbc >= max(wb, wc) + 1
if["?wbc", "?wb", "?wc"], |w| w[0] >= (max(w[1], w[2]) + 1)),
// a * 2 <=> a + a
arith_rewrite!("mult-to-add";
"(* ?wo ?wa ?sa ?a ?wb ?sb 2)" =>
"(+ ?wo ?wa ?sa ?a ?wa ?sa ?a)";
// (!sb && wb > 1) || (sb && wb > 2) || (wo <= wb)
if["?wb", "?sb", "?wo"],
|w| (w[1] == 0 && w[0] > 1) || (w[1] == 1 && w[0] > 2) || w[2] <= w[0]),
// (a * b) << c => (a << c) * b
arith_rewrite!("left-shift-mult";
// TODO: currently all signs are forced to unsigned
"(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)" =>
// we set the width of (a << c) to the result width to satisfy wac >= wo
"(* ?wo ?wo unsign (<< ?wo ?wa unsign ?a ?wc unsign ?c) ?wb unsign ?b)";
// wab >= wo && all_signs_the_same
if["?wab", "?wo"], |w| w[0] >= w[1]),
]
}

/// Wrapper struct in order to do custom parsing.
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct WidthValue(WidthInt);
Expand Down Expand Up @@ -418,7 +357,7 @@ fn get_width(root: usize, expressions: &[Arith]) -> WidthInt {
}
}

fn is_bin_op(a: &Arith) -> bool {
pub fn is_bin_op(a: &Arith) -> bool {
matches!(
a,
Arith::Add(_)
Expand Down Expand Up @@ -489,172 +428,6 @@ fn extend(
}
}

pub struct ArithRewrite {
name: String,
/// most general lhs pattern
lhs: Pattern<Arith>,
/// rhs pattern with all widths derived from the lhs, maybe be the same as rhs
rhs_derived: Pattern<Arith>,
/// variables use by the condition
cond_vars: Vec<Var>,
/// condition of the re_write
cond: Option<fn(&[WidthInt]) -> bool>,
}

impl ArithRewrite {
fn new<S: AsRef<str>>(
name: &str,
lhs: &str,
rhs_derived: &str,
cond_vars: impl IntoIterator<Item = S>,
cond: Option<fn(&[WidthInt]) -> bool>,
) -> Self {
let cond_vars = cond_vars
.into_iter()
.map(|n| n.as_ref().parse().unwrap())
.collect();
let lhs = lhs.parse::<_>().unwrap();
check_width_consistency(&lhs);
let rhs_derived = rhs_derived.parse::<_>().unwrap();
check_width_consistency(&rhs_derived);
Self {
name: name.to_string(),
lhs,
rhs_derived,
cond,
cond_vars,
}
}

pub fn name(&self) -> &str {
&self.name
}

pub fn patterns(&self) -> (&PatternAst<Arith>, &PatternAst<Arith>) {
(&self.lhs.ast, &self.rhs_derived.ast)
}

pub fn to_egg(&self) -> Vec<Rewrite<Arith, ()>> {
// TODO: support bi-directional rules
if let Some(cond) = self.cond {
let vars: Vec<Var> = self.cond_vars.clone();
let condition = move |egraph: &mut EGraph, _, subst: &Subst| {
let values: Vec<WidthInt> = vars
.iter()
.map(|v| {
get_const_width_or_sign(egraph, subst[*v])
.expect("failed to find constant width")
})
.collect();
cond(values.as_slice())
};
let cond_app = ConditionalApplier {
condition,
applier: self.rhs_derived.clone(),
};
vec![Rewrite::new(self.name.clone(), self.lhs.clone(), cond_app).unwrap()]
} else {
vec![Rewrite::new(
self.name.clone(),
self.lhs.clone(),
self.rhs_derived.clone(),
)
.unwrap()]
}
}

pub fn eval_condition(&self, a: &[(Var, WidthInt)]) -> bool {
if let Some(cond) = self.cond {
let values: Vec<WidthInt> = self
.cond_vars
.iter()
.map(|v| a.iter().find(|(k, _)| k == v).unwrap().1)
.collect();
cond(values.as_slice())
} else {
// unconditional rewrite
true
}
}
}

type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
/// and returns its value. Errors if no such constant can be found.
fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option<WidthInt> {
egraph[id]
.nodes
.iter()
.flat_map(|n| match n {
Arith::Width(w) => Some((*w).into()),
Arith::Sign(s) => Some((*s).into()),
Arith::WidthMaxPlus1([a, b]) => {
let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width");
let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width");
Some(max(a, b) + 1)
}
_ => None,
})
.next()
}

/// Checks that input and output widths of operations are consistent.
fn check_width_consistency(pattern: &Pattern<Arith>) {
let exprs = pattern.ast.as_ref();
for e_node_or_var in exprs.iter() {
if let ENodeOrVar::ENode(expr) = e_node_or_var {
if is_bin_op(expr) {
// w, w_a, s_a, a, w_b, s_b, b
let a_width_id = usize::from(expr.children()[1]);
let a_id = usize::from(expr.children()[3]);
if let Some(a_op_out_width_id) = get_output_width_id(&exprs[a_id]) {
assert_eq!(
a_width_id, a_op_out_width_id,
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}",
&exprs[a_id], &exprs[a_width_id], &exprs[a_op_out_width_id]
);
}
let b_width_id = usize::from(expr.children()[4]);
let b_id = usize::from(expr.children()[6]);
if let Some(b_op_out_width_id) = get_output_width_id(&exprs[b_id]) {
assert_eq!(
b_width_id, b_op_out_width_id,
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}",
&exprs[b_id], &exprs[b_width_id], &exprs[b_op_out_width_id]
);
}
}
}
}
}

/// returns the egg id of the output width, if `expr` has one
fn get_output_width_id(expr: &ENodeOrVar<Arith>) -> Option<usize> {
if let ENodeOrVar::ENode(expr) = expr {
if is_bin_op(expr) {
// w, w_a, s_a, a, w_b, s_b, b
Some(usize::from(expr.children()[0]))
} else {
None
}
} else {
None
}
}

/// returns all our rewrites in a format that can be directly used by egg
pub fn create_egg_rewrites() -> Vec<Rewrite<Arith, ()>> {
create_rewrites()
.into_iter()
.map(|r| r.to_egg())
.reduce(|mut a, mut b| {
a.append(&mut b);
a
})
.unwrap_or(vec![])
}

fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> {
use std::process::{Command, Stdio};
let mut child = Command::new("dot")
Expand Down Expand Up @@ -747,30 +520,52 @@ fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> {
write!(out, "}}")
}

pub type EGraph = egg::EGraph<Arith, ()>;

/// Finds a width or sign constant in the e-class referred to by the substitution
/// and returns its value. Errors if no such constant can be found.
pub fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option<WidthInt> {
egraph[id]
.nodes
.iter()
.flat_map(|n| match n {
Arith::Width(w) => Some((*w).into()),
Arith::Sign(s) => Some((*s).into()),
Arith::WidthMaxPlus1([a, b]) => {
let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width");
let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width");
Some(max(a, b) + 1)
}
_ => None,
})
.next()
}

#[cfg(test)]
pub(crate) fn verification_fig_1(ctx: &mut Context) -> (ExprRef, ExprRef) {
let a = ctx.bv_symbol("A", 16);
let b = ctx.bv_symbol("B", 16);
let m = ctx.bv_symbol("M", 4);
let n = ctx.bv_symbol("N", 4);
let spec = ctx.build(|c| {
c.mul(
c.zero_extend(c.shift_left(c.zero_extend(a, 15), c.zero_extend(m, 27)), 32),
c.zero_extend(c.shift_left(c.zero_extend(b, 15), c.zero_extend(n, 27)), 32),
)
});
let implementation = ctx.build(|c| {
c.shift_left(
c.zero_extend(c.mul(c.zero_extend(a, 16), c.zero_extend(b, 16)), 31),
c.zero_extend(c.add(c.zero_extend(m, 1), c.zero_extend(n, 1)), 58),
)
});
(spec, implementation)
}

#[cfg(test)]
mod tests {
use super::*;

fn verification_fig_1(ctx: &mut Context) -> (ExprRef, ExprRef) {
let a = ctx.bv_symbol("A", 16);
let b = ctx.bv_symbol("B", 16);
let m = ctx.bv_symbol("M", 4);
let n = ctx.bv_symbol("N", 4);
let spec = ctx.build(|c| {
c.mul(
c.zero_extend(c.shift_left(c.zero_extend(a, 15), c.zero_extend(m, 27)), 32),
c.zero_extend(c.shift_left(c.zero_extend(b, 15), c.zero_extend(n, 27)), 32),
)
});
let implementation = ctx.build(|c| {
c.shift_left(
c.zero_extend(c.mul(c.zero_extend(a, 16), c.zero_extend(b, 16)), 31),
c.zero_extend(c.add(c.zero_extend(m, 1), c.zero_extend(n, 1)), 58),
)
});
(spec, implementation)
}

#[test]
fn test_data_path_verification_fig_1_conversion() {
let mut ctx = Context::default();
Expand All @@ -787,56 +582,4 @@ mod tests {
assert_eq!(spec_back, spec);
assert_eq!(impl_back, implementation);
}

#[test]
fn test_data_path_verification_fig_1_rewrites() {
let mut ctx = Context::default();
let (spec, implementation) = verification_fig_1(&mut ctx);
let spec_e = to_arith(&ctx, spec);
let impl_e = to_arith(&ctx, implementation);

println!("{spec_e}");
println!("{impl_e}");

// run egraph operations
let egg_rewrites = create_egg_rewrites();
let runner = egg::Runner::default()
.with_expr(&spec_e)
.with_expr(&impl_e)
.run(&egg_rewrites);

runner.print_report();

let spec_class = runner.roots[0];
let impl_class = runner.roots[1];
println!("{spec_class} {impl_class}");

// to_pdf("graph.pdf", &runner.egraph).unwrap();
// runner.egraph.dot().to_pdf("full_graph.pdf").unwrap();
}

#[test]
fn test_rewrites() {
let mut ctx = Context::default();
let a = ctx.bv_symbol("A", 16);
let b = ctx.bv_symbol("B", 16);
let in_smt_expr = ctx.add(a, b);
assert_eq!(in_smt_expr.serialize_to_str(&ctx), "add(A, B)");

// run egraph operations
let egg_expr_in = to_arith(&ctx, in_smt_expr);
let egg_rewrites = create_egg_rewrites();
let runner = egg::Runner::default()
.with_expr(&egg_expr_in)
.run(&egg_rewrites);

// check how many different nodes are representing the root node now
let root = runner.roots[0];
let root_nodes = &runner.egraph[root].nodes;
assert_eq!(
root_nodes.len(),
2,
"there should be two nodes if the rule has been applied"
);
}
}
7 changes: 6 additions & 1 deletion patronus-egraphs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>
mod arithmetic;
pub use arithmetic::{create_rewrites, from_arith, to_arith, Arith, ArithRewrite, Sign};
mod rewrites;

pub use arithmetic::{
from_arith, get_const_width_or_sign, is_bin_op, to_arith, Arith, EGraph, Sign,
};
pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite};
Loading

0 comments on commit 086cfae

Please sign in to comment.