Skip to content

Commit

Permalink
[egrapj] wip: change how rewrites are stored
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 9, 2024
1 parent f2e162b commit 7a22776
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 39 deletions.
31 changes: 12 additions & 19 deletions patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ define_language! {
">>" = RightShift([Id; 7]),
">>>" = ArithmeticRightShift([Id; 7]),
Symbol(ArithSymbol),
Width(WidthInt),
Signed(bool),
// used for signedness (0 = unsigned, 1 = signed) or width parameters
WidthConst(WidthInt),
}
}

Expand Down Expand Up @@ -139,6 +139,7 @@ pub fn to_arith(ctx: &Context, e: ExprRef) -> egg::RecExpr<Arith> {
out
}

#[allow(clippy::too_many_arguments)]
fn convert_bin_op(
ctx: &Context,
out: &mut RecExpr<Arith>,
Expand All @@ -157,11 +158,11 @@ fn convert_bin_op(
debug_assert_eq!(width_out, a.get_bv_type(ctx).unwrap());
debug_assert_eq!(width_out, b.get_bv_type(ctx).unwrap());
// convert signedness and widths into e-nodes
let width_out = out.add(Arith::Width(width_out));
let width_a = out.add(Arith::Width(width_a));
let width_b = out.add(Arith::Width(width_b));
let sign_a = out.add(Arith::Signed(sign_a));
let sign_b = out.add(Arith::Signed(sign_b));
let width_out = out.add(Arith::WidthConst(width_out));
let width_a = out.add(Arith::WidthConst(width_a));
let width_b = out.add(Arith::WidthConst(width_b));
let sign_a = out.add(Arith::WidthConst(sign_a as WidthInt));
let sign_b = out.add(Arith::WidthConst(sign_b as WidthInt));
out.add(op([
width_out,
width_a,
Expand Down Expand Up @@ -215,8 +216,7 @@ pub fn from_arith(ctx: &mut Context, expr: &RecExpr<Arith>) -> ExprRef {
Arith::ArithmeticRightShift(_) => patronus_bin_op(ctx, &mut stack, |ctx, a, b| {
ctx.arithmetic_shift_right(a, b)
}),
Arith::Width(width) => ctx.bit_vec_val(*width, 32),
Arith::Signed(is_minus) => ctx.bit_vec_val(*is_minus, 1),
Arith::WidthConst(width) => ctx.bit_vec_val(*width, 32),
};
stack.push(result);
}
Expand Down Expand Up @@ -303,27 +303,20 @@ fn commute_add_condition(
let wo = get_width_from_e_graph(egraph, subst, wo);
let wa = get_width_from_e_graph(egraph, subst, wa);
let wb = get_width_from_e_graph(egraph, subst, wb);
let sa = get_signed_from_e_graph(egraph, subst, sa);
let sb = get_signed_from_e_graph(egraph, subst, sb);
let sa = get_width_from_e_graph(egraph, subst, sa);
let sb = get_width_from_e_graph(egraph, subst, sb);
// actual condition
wa == wb && wo >= wa
}
}

fn get_width_from_e_graph(egraph: &mut EGraph, subst: &egg::Subst, v: Var) -> WidthInt {
match egraph[subst[v]].nodes.as_slice() {
[Arith::Width(w)] => *w,
[Arith::WidthConst(w)] => *w,
_ => unreachable!("expected a width!"),
}
}

fn get_signed_from_e_graph(egraph: &mut EGraph, subst: &egg::Subst, v: Var) -> bool {
match egraph[subst[v]].nodes.as_slice() {
[Arith::Signed(s)] => *s,
_ => unreachable!("expected a signed!"),
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
89 changes: 82 additions & 7 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,89 @@ struct Args {
dump_smt: bool,
#[arg(long)]
bdd_formula: bool,
#[arg(
long,
help = "checks the current condition, prints out if it disagrees with the examples we generate"
)]
check_cond: bool,
#[arg(value_name = "RULE", index = 1)]
rule: String,
}

/// our version of the egg re-write macro
macro_rules! arith_rewrite {
(
$name:expr;
$lhs:expr => $rhs:expr;
if $cond:expr
) => {{
ArithRewrite::new($name, $lhs, $rhs)
}};
}

struct ArithRewrite {
name: String,
lhs: Pattern<Arith>,
rhs: Pattern<Arith>,
}

impl ArithRewrite {
fn new(name: &str, lhs: &str, rhs: &str) -> Self {
Self {
name: name.to_string(),
lhs: lhs.parse::<_>().unwrap(),
rhs: rhs.parse::<_>().unwrap(),
}
}

fn to_egg(&self) -> Rewrite<Arith, ()> {
Rewrite::new(self.name.clone(), self.lhs.clone(), self.rhs.clone()).unwrap()
}
}

fn create_rewrites() -> Vec<Rewrite<Arith, ()>> {
vec![
rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
rewrite!("merge-left-shift";
let rewrites = vec![
arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"; if true),
arith_rewrite!("merge-left-shift";
// we require that b, c and (b + c) are all unsigned
"(<< ?wo ?wab ?sab (<< ?wab ?wa ?sa ?a ?wb 0 ?b) ?wc 0 ?c)" =>
// note: in this version we set the width of (b + c) on the RHS to be the width of the
// result (w_o)
"(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"),
]
"(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"; if merge_left_shift_cond("?wo", "?wa", "?sa", "?wb", "?wc")),
];
rewrites.into_iter().map(|r| r.to_egg()).collect()
}

type EGraph = egg::EGraph<Arith, ()>;

fn merge_left_shift_cond(
wo: &'static str,
wa: &'static str,
sa: &'static str,
wb: &'static str,
wc: &'static str,
) -> impl Fn(&mut EGraph, Id, &Subst) -> bool {
let wo = wo.parse().unwrap();
let wa = wa.parse().unwrap();
let sa = sa.parse().unwrap();
let wb = wb.parse().unwrap();
let wc = wc.parse().unwrap();
move |egraph, _, subst| {
let wo = get_width_from_e_graph(egraph, subst, wo);
let wa = get_width_from_e_graph(egraph, subst, wa);
let sa = get_width_from_e_graph(egraph, subst, sa);
let wb = get_width_from_e_graph(egraph, subst, wb);
let wc = get_width_from_e_graph(egraph, subst, wc);
// actual condition
wa == wb && wo >= wa
}
}

fn get_width_from_e_graph(egraph: &mut EGraph, subst: &Subst, v: Var) -> WidthInt {
match egraph[subst[v]].nodes.as_slice() {
[Arith::WidthConst(w)] => *w,
_ => unreachable!("expected a width!"),
}
}

fn main() {
Expand All @@ -62,8 +131,14 @@ fn main() {
}
};

let (samples, rule_info) =
samples::generate_samples(&args.rule, rule, args.max_width, true, args.dump_smt);
let (samples, rule_info) = samples::generate_samples(
&args.rule,
rule,
args.max_width,
true,
args.dump_smt,
args.check_cond,
);
let delta_t = std::time::Instant::now() - start;

println!("Found {} equivalent rewrites.", samples.num_equivalent());
Expand Down
39 changes: 26 additions & 13 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub fn generate_samples(
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
check_cond: bool,
) -> (Samples, RuleInfo) {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
println!("{}: {} => {}", rule_name, lhs, rhs);
Expand Down Expand Up @@ -209,6 +210,12 @@ fn extract_patterns<L: Language>(
Some((left, right))
}

// fn extract_condition<L: Language>(
// rule: &Rewrite<L, ()>,
// ) {
// rule.applier.apply_matches()
// }

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RuleInfo {
/// width parameters
Expand Down Expand Up @@ -378,19 +385,18 @@ fn analyze_pattern(pat: &PatternAst<Arith>) -> RuleInfo {

fn symbol_from_pattern(pat: &PatternAst<Arith>, a: Id, w: Id, s: Id) -> Option<RuleSymbol> {
if let ENodeOrVar::Var(var) = pat[a] {
let width = width_or_sign_from_pattern(pat, w);
let sign = width_or_sign_from_pattern(pat, s);
let width = width_const_from_pattern(pat, w);
let sign = width_const_from_pattern(pat, s);
Some(RuleSymbol { var, width, sign })
} else {
None
}
}

fn width_or_sign_from_pattern(pat: &PatternAst<Arith>, id: Id) -> VarOrConst {
fn width_const_from_pattern(pat: &PatternAst<Arith>, id: Id) -> VarOrConst {
match &pat[id] {
ENodeOrVar::ENode(node) => match node {
&Arith::Width(w) => VarOrConst::C(w),
&Arith::Signed(s) => VarOrConst::C(s as WidthInt),
&Arith::WidthConst(w) => VarOrConst::C(w),
_ => unreachable!("not a widht!"),
},
ENodeOrVar::Var(var) => VarOrConst::V(*var),
Expand Down Expand Up @@ -418,10 +424,11 @@ fn gen_substitution(
let assignment = FxHashMap::from_iter(assignment.clone());
let mut out = FxHashMap::default();
for &width_var in rule.widths.iter() {
out.insert(width_var, Arith::Width(assignment[&width_var]));
out.insert(width_var, Arith::WidthConst(assignment[&width_var]));
}
for &sign_var in rule.signs.iter() {
out.insert(sign_var, Arith::Signed(assignment[&sign_var] != 0));
debug_assert!(assignment[&sign_var] <= 1);
out.insert(sign_var, Arith::WidthConst(assignment[&sign_var]));
}
for child in rule.symbols.iter() {
let width = match child.width {
Expand Down Expand Up @@ -497,20 +504,26 @@ mod tests {
width: b.get_bv_type(&ctx).unwrap(),
};
let subst = FxHashMap::from_iter([
(Var::from_str("?wo").unwrap(), Arith::Width(2)),
(Var::from_str("?wa").unwrap(), Arith::Width(a_arith.width)),
(Var::from_str("?sa").unwrap(), Arith::Signed(true)),
(Var::from_str("?wo").unwrap(), Arith::WidthConst(2)),
(
Var::from_str("?wa").unwrap(),
Arith::WidthConst(a_arith.width),
),
(Var::from_str("?sa").unwrap(), Arith::WidthConst(1)),
(Var::from_str("?a").unwrap(), Arith::Symbol(a_arith)),
(Var::from_str("?wb").unwrap(), Arith::Width(b_arith.width)),
(Var::from_str("?sb").unwrap(), Arith::Signed(true)),
(
Var::from_str("?wb").unwrap(),
Arith::WidthConst(b_arith.width),
),
(Var::from_str("?sb").unwrap(), Arith::WidthConst(1)),
(Var::from_str("?b").unwrap(), Arith::Symbol(b_arith)),
]);

assert_eq!(lhs.to_string(), "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)");
let lhs_sub = instantiate_pattern(lhs, &subst);
assert_eq!(
lhs_sub.to_string(),
"(+ 2 1 true \"StringRef(0):bv<1>\" 1 true \"StringRef(1):bv<1>\")"
"(+ 2 1 1 \"StringRef(0):bv<1>\" 1 1 \"StringRef(1):bv<1>\")"
);
}
}

0 comments on commit 7a22776

Please sign in to comment.