Skip to content

Commit

Permalink
CSV export
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 10, 2024
1 parent 8ad0cbd commit 7106815
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 7 deletions.
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/*.json
/*.csv
15 changes: 12 additions & 3 deletions tools/egraphs-cond-synth/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::samples::{RuleInfo, Samples};
use crate::samples::{get_var_name, RuleInfo, Samples};
use bitvec::prelude as bv;
use egg::Var;
use patronus::expr::WidthInt;
Expand Down Expand Up @@ -63,7 +63,7 @@ const FEATURES: &[Feature] = &[
labels: |r| {
let mut o = vec![];
for sign in r.signs() {
o.push(format!("!{sign}"));
o.push(format!("!{}", get_var_name(&sign).unwrap()));
}
o
},
Expand All @@ -81,7 +81,11 @@ const FEATURES: &[Feature] = &[
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
o.push(format!("{w_i} == {w_j}"));
o.push(format!(
"{} == {}",
get_var_name(&w_i).unwrap(),
get_var_name(&w_j).unwrap(),
));
}
}
}
Expand All @@ -105,6 +109,8 @@ const FEATURES: &[Feature] = &[
for w_i in r.widths() {
for w_j in r.widths() {
if w_i != w_j {
let w_i = get_var_name(&w_i).unwrap();
let w_j = get_var_name(&w_j).unwrap();
o.push(format!("{w_i} < {w_j}"));
o.push(format!("{w_i} + 1 < {w_j}"));
o.push(format!("{w_i} - 1 < {w_j}"));
Expand Down Expand Up @@ -138,6 +144,9 @@ const FEATURES: &[Feature] = &[
if w_i != w_j {
for w_k in r.widths() {
if w_k != w_i && w_k != w_j {
let w_i = get_var_name(&w_i).unwrap();
let w_j = get_var_name(&w_j).unwrap();
let w_k = get_var_name(&w_k).unwrap();
o.push(format!("{w_i} + {w_j} < {w_k}"));
o.push(format!("{w_i} as u64 + 2u64.pow({w_j}) < {w_k} as u64"));
}
Expand Down
52 changes: 49 additions & 3 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ mod rewrites;
mod samples;
mod summarize;

use crate::features::apply_features;
use crate::features::{apply_features, FeatureResult};
use crate::rewrites::create_rewrites;
use crate::samples::{get_rule_info, Samples};
use crate::samples::{get_rule_info, get_var_name, Samples};
use crate::summarize::bdd_summarize;
use clap::Parser;
use patronus::expr::*;
use std::path::PathBuf;
use std::io::Write;
use std::path::{Path, PathBuf};

#[derive(Parser, Debug)]
#[command(name = "patronus-egraphs-cond-synth")]
Expand Down Expand Up @@ -116,10 +117,55 @@ fn main() {
let feature_delta_t = std::time::Instant::now() - feature_start;
println!("{feature_delta_t:?} to apply all features");

if let Some(filename) = args.write_csv {
write_csv(filename, &samples, &features).expect("failed to write CSV");
}

if args.bdd_formula {
let summarize_start = std::time::Instant::now();
let formula = bdd_summarize(&features);
let summarize_delta_t = std::time::Instant::now() - summarize_start;
println!("Generated formula in {summarize_delta_t:?}:\n{}", formula);
}
}

fn write_csv(
filename: impl AsRef<Path>,
samples: &Samples,
features: &FeatureResult,
) -> std::io::Result<()> {
let mut o = std::io::BufWriter::new(std::fs::File::create(filename)?);

// header
write!(o, "equivalent?,")?;
for var in samples.vars() {
write!(o, "{},", get_var_name(&var).unwrap())?;
}
let num_features = features.num_features();
for (ii, feature) in features.labels().iter().enumerate() {
write!(o, "{}", feature)?;
if ii < num_features - 1 {
write!(o, ",")?;
}
}
writeln!(o)?;

// data
for ((a, a_is_eq), (s, s_is_eq)) in samples.iter().zip(features.iter()) {
assert_eq!(a_is_eq, s_is_eq);
write!(o, "{},", a_is_eq as u8)?;
for (_var, value) in a.iter() {
write!(o, "{},", *value)?;
}
for (ii, feature_on) in s.iter().enumerate() {
let feature_on = *feature_on;
write!(o, "{}", feature_on as u8)?;
if ii < num_features - 1 {
write!(o, ",")?;
}
}
writeln!(o)?;
}

Ok(())
}
16 changes: 15 additions & 1 deletion tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ impl Samples {
a
}

pub fn vars(&self) -> impl Iterator<Item = Var> + '_ {
self.vars.iter().cloned()
}

pub fn to_json(self, out: &mut impl std::io::Write) -> serde_json::error::Result<()> {
let s: SamplesSerde = self.into();
serde_json::to_writer_pretty(out, &s)
Expand Down Expand Up @@ -378,6 +382,16 @@ fn union_vecs<T: Clone + PartialEq + Ord>(a: &[T], b: &[T]) -> Vec<T> {
out
}

#[inline]
pub fn get_var_name(v: &Var) -> Option<String> {
let name = v.to_string();
if name.starts_with('?') {
Some(name.chars().skip(1).collect())
} else {
None
}
}

/// Extracts the output width and all children including width and sign from an [[`egg::PatternAst`]].
/// Requires that the output width is name `?wo` and that the child width and sign are named like:
/// `?w{name}` and `?s{name}`.
Expand Down Expand Up @@ -439,7 +453,7 @@ fn width_const_from_pattern(pat: &PatternAst<Arith>, id: Id) -> VarOrConst {
match &pat[id] {
ENodeOrVar::ENode(node) => match node {
&Arith::WidthConst(w) => VarOrConst::C(w),
_ => unreachable!("not a widht!"),
_ => unreachable!("not a width!"),
},
ENodeOrVar::Var(var) => VarOrConst::V(*var),
}
Expand Down

0 comments on commit 7106815

Please sign in to comment.