From 7a227763d9289c16e0ce87dc2d152594f926856e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Mon, 9 Dec 2024 15:04:48 -0500 Subject: [PATCH] [egrapj] wip: change how rewrites are stored --- patronus-egraphs/src/arithmetic.rs | 31 ++++----- tools/egraphs-cond-synth/src/main.rs | 89 +++++++++++++++++++++++-- tools/egraphs-cond-synth/src/samples.rs | 39 +++++++---- 3 files changed, 120 insertions(+), 39 deletions(-) diff --git a/patronus-egraphs/src/arithmetic.rs b/patronus-egraphs/src/arithmetic.rs index be3b3f5..57de7d7 100644 --- a/patronus-egraphs/src/arithmetic.rs +++ b/patronus-egraphs/src/arithmetic.rs @@ -23,8 +23,8 @@ define_language! { ">>" = RightShift([Id; 7]), ">>>" = ArithmeticRightShift([Id; 7]), Symbol(ArithSymbol), - Width(WidthInt), - Signed(bool), + // used for signedness (0 = unsigned, 1 = signed) or width parameters + WidthConst(WidthInt), } } @@ -139,6 +139,7 @@ pub fn to_arith(ctx: &Context, e: ExprRef) -> egg::RecExpr { out } +#[allow(clippy::too_many_arguments)] fn convert_bin_op( ctx: &Context, out: &mut RecExpr, @@ -157,11 +158,11 @@ fn convert_bin_op( debug_assert_eq!(width_out, a.get_bv_type(ctx).unwrap()); debug_assert_eq!(width_out, b.get_bv_type(ctx).unwrap()); // convert signedness and widths into e-nodes - let width_out = out.add(Arith::Width(width_out)); - let width_a = out.add(Arith::Width(width_a)); - let width_b = out.add(Arith::Width(width_b)); - let sign_a = out.add(Arith::Signed(sign_a)); - let sign_b = out.add(Arith::Signed(sign_b)); + let width_out = out.add(Arith::WidthConst(width_out)); + let width_a = out.add(Arith::WidthConst(width_a)); + let width_b = out.add(Arith::WidthConst(width_b)); + let sign_a = out.add(Arith::WidthConst(sign_a as WidthInt)); + let sign_b = out.add(Arith::WidthConst(sign_b as WidthInt)); out.add(op([ width_out, width_a, @@ -215,8 +216,7 @@ pub fn from_arith(ctx: &mut Context, expr: &RecExpr) -> ExprRef { Arith::ArithmeticRightShift(_) => patronus_bin_op(ctx, &mut stack, |ctx, a, b| { ctx.arithmetic_shift_right(a, b) }), - Arith::Width(width) => ctx.bit_vec_val(*width, 32), - Arith::Signed(is_minus) => ctx.bit_vec_val(*is_minus, 1), + Arith::WidthConst(width) => ctx.bit_vec_val(*width, 32), }; stack.push(result); } @@ -303,8 +303,8 @@ fn commute_add_condition( let wo = get_width_from_e_graph(egraph, subst, wo); let wa = get_width_from_e_graph(egraph, subst, wa); let wb = get_width_from_e_graph(egraph, subst, wb); - let sa = get_signed_from_e_graph(egraph, subst, sa); - let sb = get_signed_from_e_graph(egraph, subst, sb); + let sa = get_width_from_e_graph(egraph, subst, sa); + let sb = get_width_from_e_graph(egraph, subst, sb); // actual condition wa == wb && wo >= wa } @@ -312,18 +312,11 @@ fn commute_add_condition( fn get_width_from_e_graph(egraph: &mut EGraph, subst: &egg::Subst, v: Var) -> WidthInt { match egraph[subst[v]].nodes.as_slice() { - [Arith::Width(w)] => *w, + [Arith::WidthConst(w)] => *w, _ => unreachable!("expected a width!"), } } -fn get_signed_from_e_graph(egraph: &mut EGraph, subst: &egg::Subst, v: Var) -> bool { - match egraph[subst[v]].nodes.as_slice() { - [Arith::Signed(s)] => *s, - _ => unreachable!("expected a signed!"), - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/tools/egraphs-cond-synth/src/main.rs b/tools/egraphs-cond-synth/src/main.rs index b6e2148..14b2e1f 100644 --- a/tools/egraphs-cond-synth/src/main.rs +++ b/tools/egraphs-cond-synth/src/main.rs @@ -27,20 +27,89 @@ struct Args { dump_smt: bool, #[arg(long)] bdd_formula: bool, + #[arg( + long, + help = "checks the current condition, prints out if it disagrees with the examples we generate" + )] + check_cond: bool, #[arg(value_name = "RULE", index = 1)] rule: String, } +/// our version of the egg re-write macro +macro_rules! arith_rewrite { + ( + $name:expr; + $lhs:expr => $rhs:expr; + if $cond:expr + ) => {{ + ArithRewrite::new($name, $lhs, $rhs) + }}; +} + +struct ArithRewrite { + name: String, + lhs: Pattern, + rhs: Pattern, +} + +impl ArithRewrite { + fn new(name: &str, lhs: &str, rhs: &str) -> Self { + Self { + name: name.to_string(), + lhs: lhs.parse::<_>().unwrap(), + rhs: rhs.parse::<_>().unwrap(), + } + } + + fn to_egg(&self) -> Rewrite { + Rewrite::new(self.name.clone(), self.lhs.clone(), self.rhs.clone()).unwrap() + } +} + fn create_rewrites() -> Vec> { - vec![ - rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"), - rewrite!("merge-left-shift"; + let rewrites = vec![ + arith_rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"; if true), + arith_rewrite!("merge-left-shift"; // we require that b, c and (b + c) are all unsigned "(<< ?wo ?wab ?sab (<< ?wab ?wa ?sa ?a ?wb 0 ?b) ?wc 0 ?c)" => // note: in this version we set the width of (b + c) on the RHS to be the width of the // result (w_o) - "(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"), - ] + "(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"; if merge_left_shift_cond("?wo", "?wa", "?sa", "?wb", "?wc")), + ]; + rewrites.into_iter().map(|r| r.to_egg()).collect() +} + +type EGraph = egg::EGraph; + +fn merge_left_shift_cond( + wo: &'static str, + wa: &'static str, + sa: &'static str, + wb: &'static str, + wc: &'static str, +) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let wo = wo.parse().unwrap(); + let wa = wa.parse().unwrap(); + let sa = sa.parse().unwrap(); + let wb = wb.parse().unwrap(); + let wc = wc.parse().unwrap(); + move |egraph, _, subst| { + let wo = get_width_from_e_graph(egraph, subst, wo); + let wa = get_width_from_e_graph(egraph, subst, wa); + let sa = get_width_from_e_graph(egraph, subst, sa); + let wb = get_width_from_e_graph(egraph, subst, wb); + let wc = get_width_from_e_graph(egraph, subst, wc); + // actual condition + wa == wb && wo >= wa + } +} + +fn get_width_from_e_graph(egraph: &mut EGraph, subst: &Subst, v: Var) -> WidthInt { + match egraph[subst[v]].nodes.as_slice() { + [Arith::WidthConst(w)] => *w, + _ => unreachable!("expected a width!"), + } } fn main() { @@ -62,8 +131,14 @@ fn main() { } }; - let (samples, rule_info) = - samples::generate_samples(&args.rule, rule, args.max_width, true, args.dump_smt); + let (samples, rule_info) = samples::generate_samples( + &args.rule, + rule, + args.max_width, + true, + args.dump_smt, + args.check_cond, + ); let delta_t = std::time::Instant::now() - start; println!("Found {} equivalent rewrites.", samples.num_equivalent()); diff --git a/tools/egraphs-cond-synth/src/samples.rs b/tools/egraphs-cond-synth/src/samples.rs index 6f25952..465745d 100644 --- a/tools/egraphs-cond-synth/src/samples.rs +++ b/tools/egraphs-cond-synth/src/samples.rs @@ -16,6 +16,7 @@ pub fn generate_samples( max_width: WidthInt, show_progress: bool, dump_smt: bool, + check_cond: bool, ) -> (Samples, RuleInfo) { let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule"); println!("{}: {} => {}", rule_name, lhs, rhs); @@ -209,6 +210,12 @@ fn extract_patterns( Some((left, right)) } +// fn extract_condition( +// rule: &Rewrite, +// ) { +// rule.applier.apply_matches() +// } + #[derive(Debug, Clone, Eq, PartialEq)] pub struct RuleInfo { /// width parameters @@ -378,19 +385,18 @@ fn analyze_pattern(pat: &PatternAst) -> RuleInfo { fn symbol_from_pattern(pat: &PatternAst, a: Id, w: Id, s: Id) -> Option { if let ENodeOrVar::Var(var) = pat[a] { - let width = width_or_sign_from_pattern(pat, w); - let sign = width_or_sign_from_pattern(pat, s); + let width = width_const_from_pattern(pat, w); + let sign = width_const_from_pattern(pat, s); Some(RuleSymbol { var, width, sign }) } else { None } } -fn width_or_sign_from_pattern(pat: &PatternAst, id: Id) -> VarOrConst { +fn width_const_from_pattern(pat: &PatternAst, id: Id) -> VarOrConst { match &pat[id] { ENodeOrVar::ENode(node) => match node { - &Arith::Width(w) => VarOrConst::C(w), - &Arith::Signed(s) => VarOrConst::C(s as WidthInt), + &Arith::WidthConst(w) => VarOrConst::C(w), _ => unreachable!("not a widht!"), }, ENodeOrVar::Var(var) => VarOrConst::V(*var), @@ -418,10 +424,11 @@ fn gen_substitution( let assignment = FxHashMap::from_iter(assignment.clone()); let mut out = FxHashMap::default(); for &width_var in rule.widths.iter() { - out.insert(width_var, Arith::Width(assignment[&width_var])); + out.insert(width_var, Arith::WidthConst(assignment[&width_var])); } for &sign_var in rule.signs.iter() { - out.insert(sign_var, Arith::Signed(assignment[&sign_var] != 0)); + debug_assert!(assignment[&sign_var] <= 1); + out.insert(sign_var, Arith::WidthConst(assignment[&sign_var])); } for child in rule.symbols.iter() { let width = match child.width { @@ -497,12 +504,18 @@ mod tests { width: b.get_bv_type(&ctx).unwrap(), }; let subst = FxHashMap::from_iter([ - (Var::from_str("?wo").unwrap(), Arith::Width(2)), - (Var::from_str("?wa").unwrap(), Arith::Width(a_arith.width)), - (Var::from_str("?sa").unwrap(), Arith::Signed(true)), + (Var::from_str("?wo").unwrap(), Arith::WidthConst(2)), + ( + Var::from_str("?wa").unwrap(), + Arith::WidthConst(a_arith.width), + ), + (Var::from_str("?sa").unwrap(), Arith::WidthConst(1)), (Var::from_str("?a").unwrap(), Arith::Symbol(a_arith)), - (Var::from_str("?wb").unwrap(), Arith::Width(b_arith.width)), - (Var::from_str("?sb").unwrap(), Arith::Signed(true)), + ( + Var::from_str("?wb").unwrap(), + Arith::WidthConst(b_arith.width), + ), + (Var::from_str("?sb").unwrap(), Arith::WidthConst(1)), (Var::from_str("?b").unwrap(), Arith::Symbol(b_arith)), ]); @@ -510,7 +523,7 @@ mod tests { let lhs_sub = instantiate_pattern(lhs, &subst); assert_eq!( lhs_sub.to_string(), - "(+ 2 1 true \"StringRef(0):bv<1>\" 1 true \"StringRef(1):bv<1>\")" + "(+ 2 1 1 \"StringRef(0):bv<1>\" 1 1 \"StringRef(1):bv<1>\")" ); } }