-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
337 additions
and
302 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,10 +3,7 @@ | |
// author: Kevin Laeufer <[email protected]> | ||
|
||
use baa::BitVecOps; | ||
use egg::{ | ||
define_language, ConditionalApplier, ENodeOrVar, Id, Language, Pattern, PatternAst, RecExpr, | ||
Rewrite, Subst, Var, | ||
}; | ||
use egg::{define_language, Id, Language, RecExpr}; | ||
use patronus::expr::*; | ||
use rustc_hash::FxHashMap; | ||
use std::cmp::{max, Ordering}; | ||
|
@@ -36,64 +33,6 @@ define_language! { | |
} | ||
} | ||
|
||
/// our version of the egg re-write macro | ||
macro_rules! arith_rewrite { | ||
( | ||
$name:expr; | ||
$lhs:expr => $rhs:expr | ||
) => {{ | ||
ArithRewrite::new::<&str>($name, $lhs, $rhs, [], None) | ||
}}; | ||
( | ||
$name:expr; | ||
$lhs:expr => $rhs:expr; | ||
if $vars:expr, $cond:expr | ||
) => {{ | ||
ArithRewrite::new($name, $lhs, $rhs, $vars, Some($cond)) | ||
}}; | ||
} | ||
|
||
/// Generate our ROVER inspired rewrite rules. | ||
pub fn create_rewrites() -> Vec<ArithRewrite> { | ||
vec![ | ||
// a + b => b + a | ||
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"), | ||
// (a << b) << x => a << (b + c) | ||
arith_rewrite!("merge-left-shift"; | ||
// we require that b, c and (b + c) are all unsigned | ||
// we do not want (b + c) to wrap, because in that case the result would always be zero | ||
// the value being shifted has to be consistently signed or unsigned | ||
"(<< ?wo ?wab ?sa (<< ?wab ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)" => | ||
"(<< ?wo ?wa ?sa ?a (max+1 ?wb ?wc) unsign (+ (max+1 ?wb ?wc) ?wb unsign ?b ?wc unsign ?c))"; | ||
// wab >= wo | ||
if["?wo", "?wab"], |w| w[1] >= w[0]), | ||
// a << (b + c) => (a << b) << x | ||
arith_rewrite!("unmerge-left-shift"; | ||
// we require that b, c and (b + c) are all unsigned | ||
// we do not want (b + c) to wrap, because in that case the result would always be zero | ||
// the value being shifted has to be consistently signed or unsigned | ||
"(<< ?wo ?wa ?sa ?a ?wbc unsign (+ ?wbc ?wb unsign ?b ?wc unsign ?c))" => | ||
"(<< ?wo ?wo ?sa (<< ?wo ?wa ?sa ?a ?wb unsign ?b) ?wc unsign ?c)"; | ||
// ?wbc >= max(wb, wc) + 1 | ||
if["?wbc", "?wb", "?wc"], |w| w[0] >= (max(w[1], w[2]) + 1)), | ||
// a * 2 <=> a + a | ||
arith_rewrite!("mult-to-add"; | ||
"(* ?wo ?wa ?sa ?a ?wb ?sb 2)" => | ||
"(+ ?wo ?wa ?sa ?a ?wa ?sa ?a)"; | ||
// (!sb && wb > 1) || (sb && wb > 2) || (wo <= wb) | ||
if["?wb", "?sb", "?wo"], | ||
|w| (w[1] == 0 && w[0] > 1) || (w[1] == 1 && w[0] > 2) || w[2] <= w[0]), | ||
// (a * b) << c => (a << c) * b | ||
arith_rewrite!("left-shift-mult"; | ||
// TODO: currently all signs are forced to unsigned | ||
"(<< ?wo ?wab unsign (* ?wab ?wa unsign ?a ?wb unsign ?b) ?wc unsign ?c)" => | ||
// we set the width of (a << c) to the result width to satisfy wac >= wo | ||
"(* ?wo ?wo unsign (<< ?wo ?wa unsign ?a ?wc unsign ?c) ?wb unsign ?b)"; | ||
// wab >= wo && all_signs_the_same | ||
if["?wab", "?wo"], |w| w[0] >= w[1]), | ||
] | ||
} | ||
|
||
/// Wrapper struct in order to do custom parsing. | ||
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] | ||
pub struct WidthValue(WidthInt); | ||
|
@@ -418,7 +357,7 @@ fn get_width(root: usize, expressions: &[Arith]) -> WidthInt { | |
} | ||
} | ||
|
||
fn is_bin_op(a: &Arith) -> bool { | ||
pub fn is_bin_op(a: &Arith) -> bool { | ||
matches!( | ||
a, | ||
Arith::Add(_) | ||
|
@@ -489,172 +428,6 @@ fn extend( | |
} | ||
} | ||
|
||
pub struct ArithRewrite { | ||
name: String, | ||
/// most general lhs pattern | ||
lhs: Pattern<Arith>, | ||
/// rhs pattern with all widths derived from the lhs, maybe be the same as rhs | ||
rhs_derived: Pattern<Arith>, | ||
/// variables use by the condition | ||
cond_vars: Vec<Var>, | ||
/// condition of the re_write | ||
cond: Option<fn(&[WidthInt]) -> bool>, | ||
} | ||
|
||
impl ArithRewrite { | ||
fn new<S: AsRef<str>>( | ||
name: &str, | ||
lhs: &str, | ||
rhs_derived: &str, | ||
cond_vars: impl IntoIterator<Item = S>, | ||
cond: Option<fn(&[WidthInt]) -> bool>, | ||
) -> Self { | ||
let cond_vars = cond_vars | ||
.into_iter() | ||
.map(|n| n.as_ref().parse().unwrap()) | ||
.collect(); | ||
let lhs = lhs.parse::<_>().unwrap(); | ||
check_width_consistency(&lhs); | ||
let rhs_derived = rhs_derived.parse::<_>().unwrap(); | ||
check_width_consistency(&rhs_derived); | ||
Self { | ||
name: name.to_string(), | ||
lhs, | ||
rhs_derived, | ||
cond, | ||
cond_vars, | ||
} | ||
} | ||
|
||
pub fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
||
pub fn patterns(&self) -> (&PatternAst<Arith>, &PatternAst<Arith>) { | ||
(&self.lhs.ast, &self.rhs_derived.ast) | ||
} | ||
|
||
pub fn to_egg(&self) -> Vec<Rewrite<Arith, ()>> { | ||
// TODO: support bi-directional rules | ||
if let Some(cond) = self.cond { | ||
let vars: Vec<Var> = self.cond_vars.clone(); | ||
let condition = move |egraph: &mut EGraph, _, subst: &Subst| { | ||
let values: Vec<WidthInt> = vars | ||
.iter() | ||
.map(|v| { | ||
get_const_width_or_sign(egraph, subst[*v]) | ||
.expect("failed to find constant width") | ||
}) | ||
.collect(); | ||
cond(values.as_slice()) | ||
}; | ||
let cond_app = ConditionalApplier { | ||
condition, | ||
applier: self.rhs_derived.clone(), | ||
}; | ||
vec![Rewrite::new(self.name.clone(), self.lhs.clone(), cond_app).unwrap()] | ||
} else { | ||
vec![Rewrite::new( | ||
self.name.clone(), | ||
self.lhs.clone(), | ||
self.rhs_derived.clone(), | ||
) | ||
.unwrap()] | ||
} | ||
} | ||
|
||
pub fn eval_condition(&self, a: &[(Var, WidthInt)]) -> bool { | ||
if let Some(cond) = self.cond { | ||
let values: Vec<WidthInt> = self | ||
.cond_vars | ||
.iter() | ||
.map(|v| a.iter().find(|(k, _)| k == v).unwrap().1) | ||
.collect(); | ||
cond(values.as_slice()) | ||
} else { | ||
// unconditional rewrite | ||
true | ||
} | ||
} | ||
} | ||
|
||
type EGraph = egg::EGraph<Arith, ()>; | ||
|
||
/// Finds a width or sign constant in the e-class referred to by the substitution | ||
/// and returns its value. Errors if no such constant can be found. | ||
fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option<WidthInt> { | ||
egraph[id] | ||
.nodes | ||
.iter() | ||
.flat_map(|n| match n { | ||
Arith::Width(w) => Some((*w).into()), | ||
Arith::Sign(s) => Some((*s).into()), | ||
Arith::WidthMaxPlus1([a, b]) => { | ||
let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width"); | ||
let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width"); | ||
Some(max(a, b) + 1) | ||
} | ||
_ => None, | ||
}) | ||
.next() | ||
} | ||
|
||
/// Checks that input and output widths of operations are consistent. | ||
fn check_width_consistency(pattern: &Pattern<Arith>) { | ||
let exprs = pattern.ast.as_ref(); | ||
for e_node_or_var in exprs.iter() { | ||
if let ENodeOrVar::ENode(expr) = e_node_or_var { | ||
if is_bin_op(expr) { | ||
// w, w_a, s_a, a, w_b, s_b, b | ||
let a_width_id = usize::from(expr.children()[1]); | ||
let a_id = usize::from(expr.children()[3]); | ||
if let Some(a_op_out_width_id) = get_output_width_id(&exprs[a_id]) { | ||
assert_eq!( | ||
a_width_id, a_op_out_width_id, | ||
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}", | ||
&exprs[a_id], &exprs[a_width_id], &exprs[a_op_out_width_id] | ||
); | ||
} | ||
let b_width_id = usize::from(expr.children()[4]); | ||
let b_id = usize::from(expr.children()[6]); | ||
if let Some(b_op_out_width_id) = get_output_width_id(&exprs[b_id]) { | ||
assert_eq!( | ||
b_width_id, b_op_out_width_id, | ||
"In `{expr}`, subexpression `{}` has inconsistent width: {} != {}", | ||
&exprs[b_id], &exprs[b_width_id], &exprs[b_op_out_width_id] | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// returns the egg id of the output width, if `expr` has one | ||
fn get_output_width_id(expr: &ENodeOrVar<Arith>) -> Option<usize> { | ||
if let ENodeOrVar::ENode(expr) = expr { | ||
if is_bin_op(expr) { | ||
// w, w_a, s_a, a, w_b, s_b, b | ||
Some(usize::from(expr.children()[0])) | ||
} else { | ||
None | ||
} | ||
} else { | ||
None | ||
} | ||
} | ||
|
||
/// returns all our rewrites in a format that can be directly used by egg | ||
pub fn create_egg_rewrites() -> Vec<Rewrite<Arith, ()>> { | ||
create_rewrites() | ||
.into_iter() | ||
.map(|r| r.to_egg()) | ||
.reduce(|mut a, mut b| { | ||
a.append(&mut b); | ||
a | ||
}) | ||
.unwrap_or(vec![]) | ||
} | ||
|
||
fn to_pdf(filename: &str, egraph: &EGraph) -> std::io::Result<()> { | ||
use std::process::{Command, Stdio}; | ||
let mut child = Command::new("dot") | ||
|
@@ -747,30 +520,52 @@ fn write_to_dot(out: &mut impl Write, egraph: &EGraph) -> std::io::Result<()> { | |
write!(out, "}}") | ||
} | ||
|
||
pub type EGraph = egg::EGraph<Arith, ()>; | ||
|
||
/// Finds a width or sign constant in the e-class referred to by the substitution | ||
/// and returns its value. Errors if no such constant can be found. | ||
pub fn get_const_width_or_sign(egraph: &EGraph, id: Id) -> Option<WidthInt> { | ||
egraph[id] | ||
.nodes | ||
.iter() | ||
.flat_map(|n| match n { | ||
Arith::Width(w) => Some((*w).into()), | ||
Arith::Sign(s) => Some((*s).into()), | ||
Arith::WidthMaxPlus1([a, b]) => { | ||
let a = get_const_width_or_sign(egraph, *a).expect("failed to find constant width"); | ||
let b = get_const_width_or_sign(egraph, *b).expect("failed to find constant width"); | ||
Some(max(a, b) + 1) | ||
} | ||
_ => None, | ||
}) | ||
.next() | ||
} | ||
|
||
#[cfg(test)] | ||
pub(crate) fn verification_fig_1(ctx: &mut Context) -> (ExprRef, ExprRef) { | ||
let a = ctx.bv_symbol("A", 16); | ||
let b = ctx.bv_symbol("B", 16); | ||
let m = ctx.bv_symbol("M", 4); | ||
let n = ctx.bv_symbol("N", 4); | ||
let spec = ctx.build(|c| { | ||
c.mul( | ||
c.zero_extend(c.shift_left(c.zero_extend(a, 15), c.zero_extend(m, 27)), 32), | ||
c.zero_extend(c.shift_left(c.zero_extend(b, 15), c.zero_extend(n, 27)), 32), | ||
) | ||
}); | ||
let implementation = ctx.build(|c| { | ||
c.shift_left( | ||
c.zero_extend(c.mul(c.zero_extend(a, 16), c.zero_extend(b, 16)), 31), | ||
c.zero_extend(c.add(c.zero_extend(m, 1), c.zero_extend(n, 1)), 58), | ||
) | ||
}); | ||
(spec, implementation) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
fn verification_fig_1(ctx: &mut Context) -> (ExprRef, ExprRef) { | ||
let a = ctx.bv_symbol("A", 16); | ||
let b = ctx.bv_symbol("B", 16); | ||
let m = ctx.bv_symbol("M", 4); | ||
let n = ctx.bv_symbol("N", 4); | ||
let spec = ctx.build(|c| { | ||
c.mul( | ||
c.zero_extend(c.shift_left(c.zero_extend(a, 15), c.zero_extend(m, 27)), 32), | ||
c.zero_extend(c.shift_left(c.zero_extend(b, 15), c.zero_extend(n, 27)), 32), | ||
) | ||
}); | ||
let implementation = ctx.build(|c| { | ||
c.shift_left( | ||
c.zero_extend(c.mul(c.zero_extend(a, 16), c.zero_extend(b, 16)), 31), | ||
c.zero_extend(c.add(c.zero_extend(m, 1), c.zero_extend(n, 1)), 58), | ||
) | ||
}); | ||
(spec, implementation) | ||
} | ||
|
||
#[test] | ||
fn test_data_path_verification_fig_1_conversion() { | ||
let mut ctx = Context::default(); | ||
|
@@ -787,56 +582,4 @@ mod tests { | |
assert_eq!(spec_back, spec); | ||
assert_eq!(impl_back, implementation); | ||
} | ||
|
||
#[test] | ||
fn test_data_path_verification_fig_1_rewrites() { | ||
let mut ctx = Context::default(); | ||
let (spec, implementation) = verification_fig_1(&mut ctx); | ||
let spec_e = to_arith(&ctx, spec); | ||
let impl_e = to_arith(&ctx, implementation); | ||
|
||
println!("{spec_e}"); | ||
println!("{impl_e}"); | ||
|
||
// run egraph operations | ||
let egg_rewrites = create_egg_rewrites(); | ||
let runner = egg::Runner::default() | ||
.with_expr(&spec_e) | ||
.with_expr(&impl_e) | ||
.run(&egg_rewrites); | ||
|
||
runner.print_report(); | ||
|
||
let spec_class = runner.roots[0]; | ||
let impl_class = runner.roots[1]; | ||
println!("{spec_class} {impl_class}"); | ||
|
||
// to_pdf("graph.pdf", &runner.egraph).unwrap(); | ||
// runner.egraph.dot().to_pdf("full_graph.pdf").unwrap(); | ||
} | ||
|
||
#[test] | ||
fn test_rewrites() { | ||
let mut ctx = Context::default(); | ||
let a = ctx.bv_symbol("A", 16); | ||
let b = ctx.bv_symbol("B", 16); | ||
let in_smt_expr = ctx.add(a, b); | ||
assert_eq!(in_smt_expr.serialize_to_str(&ctx), "add(A, B)"); | ||
|
||
// run egraph operations | ||
let egg_expr_in = to_arith(&ctx, in_smt_expr); | ||
let egg_rewrites = create_egg_rewrites(); | ||
let runner = egg::Runner::default() | ||
.with_expr(&egg_expr_in) | ||
.run(&egg_rewrites); | ||
|
||
// check how many different nodes are representing the root node now | ||
let root = runner.roots[0]; | ||
let root_nodes = &runner.egraph[root].nodes; | ||
assert_eq!( | ||
root_nodes.len(), | ||
2, | ||
"there should be two nodes if the rule has been applied" | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,9 @@ | |
// released under BSD 3-Clause License | ||
// author: Kevin Laeufer <[email protected]> | ||
mod arithmetic; | ||
pub use arithmetic::{create_rewrites, from_arith, to_arith, Arith, ArithRewrite, Sign}; | ||
mod rewrites; | ||
|
||
pub use arithmetic::{ | ||
from_arith, get_const_width_or_sign, is_bin_op, to_arith, Arith, EGraph, Sign, | ||
}; | ||
pub use rewrites::{create_egg_rewrites, create_rewrites, ArithRewrite}; |
Oops, something went wrong.