From 6f5e9b5505fd9d99343a9b9f093bbdf773f8ccbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Tue, 10 Dec 2024 10:26:18 -0500 Subject: [PATCH] cond synth: collect features --- tools/egraphs-cond-synth/Cargo.toml | 1 + tools/egraphs-cond-synth/src/summarize.rs | 65 ++++++++++++++++++----- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/tools/egraphs-cond-synth/Cargo.toml b/tools/egraphs-cond-synth/Cargo.toml index 54c2e89..864dbb3 100644 --- a/tools/egraphs-cond-synth/Cargo.toml +++ b/tools/egraphs-cond-synth/Cargo.toml @@ -19,3 +19,4 @@ boolean_expression.workspace = true indicatif = "0.17.9" rayon = "1.10.0" thread_local = "1.1.8" +bitvec = "1.0.1" diff --git a/tools/egraphs-cond-synth/src/summarize.rs b/tools/egraphs-cond-synth/src/summarize.rs index ab8dfe5..d5fadcd 100644 --- a/tools/egraphs-cond-synth/src/summarize.rs +++ b/tools/egraphs-cond-synth/src/summarize.rs @@ -3,30 +3,34 @@ // author: Kevin Laeufer use crate::samples::{RuleInfo, Samples}; +use bitvec::macros::internal::funty::Fundamental; +use bitvec::prelude as bv; use egg::Var; use patronus::expr::WidthInt; use rustc_hash::FxHashMap; /// generate a simplified re-write condition from samples, using BDDs pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { - // generate all labels and the corresponding BDD terminals - let labels = get_labels(rule); + let results = check_features(rule, samples); + + // generate BDD terminals let mut bdd = boolean_expression::BDD::::new(); - let vars: Vec<_> = labels.iter().map(|ii| bdd.terminal(ii.clone())).collect(); + let vars: Vec<_> = results + .labels() + .iter() + .map(|ii| bdd.terminal(ii.clone())) + .collect(); + + println!("There are {} features", results.num_features()); // start condition as trivially `false` let mut cond = boolean_expression::BDD_ZERO; - for (assignment, is_equal) in samples.iter() { - let v = FxHashMap::from_iter(assignment); - let mut outputs = vec![]; - for feature in FEATURES.iter() { - (feature.eval)(rule, &v, &mut outputs); - } - let lits = outputs + for (features, is_equal) in results.iter() { + let lits = features .into_iter() .enumerate() .map(|(terminal, is_true)| { - if is_true { + if is_true.as_bool() { vars[terminal] } else { bdd.not(vars[terminal]) @@ -43,6 +47,43 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { format!("{:?}", bdd.to_expr(cond)) } +pub fn check_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult { + let labels = get_labels(rule); + let mut results = bv::BitVec::new(); + + for (assignment, is_equal) in samples.iter() { + let v = FxHashMap::from_iter(assignment); + results.push(is_equal); + for feature in FEATURES.iter() { + (feature.eval)(rule, &v, &mut results); + } + } + + FeatureResult { labels, results } +} + +pub struct FeatureResult { + labels: Vec, + results: bv::BitVec, +} + +impl FeatureResult { + pub fn num_features(&self) -> usize { + self.labels.len() + } + pub fn labels(&self) -> &[String] { + &self.labels + } + pub fn iter(&self) -> impl Iterator + '_ { + let cs = self.num_features() + 1; + self.results.chunks(cs).map(|c| { + let is_equivalent = c[0]; + let features = &c[1..]; + (features, is_equivalent) + }) + } +} + fn get_labels(rule: &RuleInfo) -> Vec { FEATURES .iter() @@ -167,5 +208,5 @@ const FEATURES: &[Feature] = &[ struct Feature { name: &'static str, labels: fn(rule: &RuleInfo) -> Vec, - eval: fn(rule: &RuleInfo, v: &FxHashMap, out: &mut Vec), + eval: fn(rule: &RuleInfo, v: &FxHashMap, out: &mut bv::BitVec), }