From 271fb0450dfa0db6900de23c1df213f2666adb48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Mon, 9 Dec 2024 13:55:01 -0500 Subject: [PATCH] add bdd based formula generator --- .github/workflows/test.yml | 2 +- tools/egraphs-cond-synth/src/main.rs | 8 ++ tools/egraphs-cond-synth/src/summarize.rs | 103 ++++++++++++++++++---- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a543463..48613bd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -174,7 +174,7 @@ jobs: run: cargo build --verbose --release -p patronus-egraphs-cond-synth - name: synthesize commute-add condition run: | - cargo run --release -p patronus-egraphs-cond-synth -- commute-add + cargo run --release -p patronus-egraphs-cond-synth -- --bdd-formula commute-add semver: name: Check Semantic Versioning of Patronus diff --git a/tools/egraphs-cond-synth/src/main.rs b/tools/egraphs-cond-synth/src/main.rs index f2442a2..b6e2148 100644 --- a/tools/egraphs-cond-synth/src/main.rs +++ b/tools/egraphs-cond-synth/src/main.rs @@ -8,6 +8,7 @@ mod samples; mod summarize; +use crate::summarize::bdd_summarize; use clap::Parser; use egg::*; use patronus::expr::*; @@ -80,4 +81,11 @@ fn main() { println!("{:?}", sample); } } + + if args.bdd_formula { + let summarize_start = std::time::Instant::now(); + let formula = bdd_summarize(&rule_info, &samples); + let summarize_delta_t = std::time::Instant::now() - summarize_start; + println!("Generated formula in {summarize_delta_t:?}:\n{}", formula); + } } diff --git a/tools/egraphs-cond-synth/src/summarize.rs b/tools/egraphs-cond-synth/src/summarize.rs index 6f6496d..b85df95 100644 --- a/tools/egraphs-cond-synth/src/summarize.rs +++ b/tools/egraphs-cond-synth/src/summarize.rs @@ -7,17 +7,63 @@ use egg::Var; use patronus::expr::WidthInt; use rustc_hash::FxHashMap; -/// generate a simplified re-writ condition from samples, using BDDs -pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) { +/// generate a simplified re-write condition from samples, using BDDs +pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { + // generate all labels and the corresponding BDD terminals + let labels = get_labels(rule); + let mut bdd = boolean_expression::BDD::::new(); + let vars: Vec<_> = (0..labels.len()).map(|ii| bdd.terminal(ii)).collect(); + + // start condition as trivially `true` + let mut cond = boolean_expression::BDD_ONE; for (assignment, is_equal) in samples.iter() { let v = FxHashMap::from_iter(assignment); + let mut outputs = vec![]; + for feature in FEATURES.iter() { + (feature.eval)(rule, &v, &mut outputs); + } + let lits = outputs + .into_iter() + .enumerate() + .map(|(terminal, is_true)| { + if is_true { + vars[terminal] + } else { + bdd.not(vars[terminal]) + } + }) + .collect::>(); + let term = lits.into_iter().reduce(|a, b| bdd.and(a, b)).unwrap(); + let term = if is_equal { term } else { bdd.not(term) }; + + cond = bdd.and(cond, term); } + + // extract simplified expression + format!("{:?}", bdd.to_expr(cond)) +} + +fn get_labels(rule: &RuleInfo) -> Vec { + FEATURES + .iter() + .map(|f| (f.labels)(rule)) + .reduce(|mut a, mut b| { + a.append(&mut b); + a + }) + .unwrap_or_default() } const FEATURES: &[Feature] = &[ Feature { name: "is_unsigned", // (13) - len: |r| Some(r.signs().count()), + labels: |r| { + let mut o = vec![]; + for sign in r.signs() { + o.push(format!("!{sign}")); + } + o + }, eval: |r, v, o| { for sign in r.signs() { // s_i == unsign @@ -27,12 +73,16 @@ const FEATURES: &[Feature] = &[ }, Feature { name: "is_width_equal", // (14) - len: |r| { - if r.widths().count() <= 0 { - None - } else { - Some(r.widths().count() * (r.widths().count() - 1)) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + o.push(format!("{w_i} == {w_j}")); + } + } } + o }, eval: |r, v, o| { for w_i in r.widths() { @@ -47,12 +97,18 @@ const FEATURES: &[Feature] = &[ }, Feature { name: "is_width_smaller", // (15) + (16) - len: |r| { - if r.widths().count() <= 0 { - None - } else { - Some(r.widths().count() * (r.widths().count() - 1) * 3) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + o.push(format!("{w_i} < {w_j}")); + o.push(format!("{w_i} + 1 < {w_j}")); + o.push(format!("{w_i} - 1 < {w_j}")); + } + } } + o }, eval: |r, v, o| { for w_i in r.widths() { @@ -72,12 +128,21 @@ const FEATURES: &[Feature] = &[ }, Feature { name: "is_width_sum_smaller", // (17) + (18) - len: |r| { - if r.widths().count() <= 1 { - None - } else { - Some(r.widths().count() * (r.widths().count() - 1) * (r.widths().count() - 2) * 2) + labels: |r| { + let mut o = vec![]; + for w_i in r.widths() { + for w_j in r.widths() { + if w_i != w_j { + for w_k in r.widths() { + if w_k != w_i && w_k != w_j { + o.push(format!("{w_i} + {w_j} < {w_k}")); + o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64")); + } + } + } + } } + o }, eval: |r, v, o| { for w_i in r.widths() { @@ -101,6 +166,6 @@ const FEATURES: &[Feature] = &[ struct Feature { name: &'static str, - len: fn(rule: &RuleInfo) -> Option, + labels: fn(rule: &RuleInfo) -> Vec, eval: fn(rule: &RuleInfo, v: &FxHashMap, out: &mut Vec), }