Skip to content

Commit

Permalink
cond synth: collect features
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 10, 2024
1 parent 7a22776 commit 6f5e9b5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 12 deletions.
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ boolean_expression.workspace = true
indicatif = "0.17.9"
rayon = "1.10.0"
thread_local = "1.1.8"
bitvec = "1.0.1"
65 changes: 53 additions & 12 deletions tools/egraphs-cond-synth/src/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,34 @@
// author: Kevin Laeufer <[email protected]>

use crate::samples::{RuleInfo, Samples};
use bitvec::macros::internal::funty::Fundamental;
use bitvec::prelude as bv;
use egg::Var;
use patronus::expr::WidthInt;
use rustc_hash::FxHashMap;

/// generate a simplified re-write condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
// generate all labels and the corresponding BDD terminals
let labels = get_labels(rule);
let results = check_features(rule, samples);

// generate BDD terminals
let mut bdd = boolean_expression::BDD::<String>::new();
let vars: Vec<_> = labels.iter().map(|ii| bdd.terminal(ii.clone())).collect();
let vars: Vec<_> = results
.labels()
.iter()
.map(|ii| bdd.terminal(ii.clone()))
.collect();

println!("There are {} features", results.num_features());

// start condition as trivially `false`
let mut cond = boolean_expression::BDD_ZERO;
for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
let mut outputs = vec![];
for feature in FEATURES.iter() {
(feature.eval)(rule, &v, &mut outputs);
}
let lits = outputs
for (features, is_equal) in results.iter() {
let lits = features
.into_iter()
.enumerate()
.map(|(terminal, is_true)| {
if is_true {
if is_true.as_bool() {
vars[terminal]
} else {
bdd.not(vars[terminal])
Expand All @@ -43,6 +47,43 @@ pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
format!("{:?}", bdd.to_expr(cond))
}

pub fn check_features(rule: &RuleInfo, samples: &Samples) -> FeatureResult {
let labels = get_labels(rule);
let mut results = bv::BitVec::new();

for (assignment, is_equal) in samples.iter() {
let v = FxHashMap::from_iter(assignment);
results.push(is_equal);
for feature in FEATURES.iter() {
(feature.eval)(rule, &v, &mut results);
}
}

FeatureResult { labels, results }
}

pub struct FeatureResult {
labels: Vec<String>,
results: bv::BitVec,
}

impl FeatureResult {
pub fn num_features(&self) -> usize {
self.labels.len()
}
pub fn labels(&self) -> &[String] {
&self.labels
}
pub fn iter(&self) -> impl Iterator<Item = (&bv::BitSlice, bool)> + '_ {
let cs = self.num_features() + 1;
self.results.chunks(cs).map(|c| {
let is_equivalent = c[0];
let features = &c[1..];
(features, is_equivalent)
})
}
}

fn get_labels(rule: &RuleInfo) -> Vec<String> {
FEATURES
.iter()
Expand Down Expand Up @@ -167,5 +208,5 @@ const FEATURES: &[Feature] = &[
struct Feature {
name: &'static str,
labels: fn(rule: &RuleInfo) -> Vec<String>,
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut Vec<bool>),
eval: fn(rule: &RuleInfo, v: &FxHashMap<Var, WidthInt>, out: &mut bv::BitVec),
}

0 comments on commit 6f5e9b5

Please sign in to comment.