Skip to content

Commit

Permalink
cond synth: save assignment to/from JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 10, 2024
1 parent 535eeef commit 3a86537
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 32 deletions.
1 change: 1 addition & 0 deletions tools/egraphs-cond-synth/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/*.json
2 changes: 2 additions & 0 deletions tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ indicatif = "0.17.9"
rayon = "1.10.0"
thread_local = "1.1.8"
bitvec = "1.0.1"
serde_json = "1.0.133"
serde = { version = "1.0.215", features = ["derive"] }
63 changes: 46 additions & 17 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ mod features;
mod samples;
mod summarize;

use crate::features::apply_features;
use crate::samples::{get_rule_info, Samples};
use crate::summarize::bdd_summarize;
use clap::Parser;
use egg::*;
use patronus::expr::*;
use patronus_egraphs::*;
use std::path::PathBuf;

#[derive(Parser, Debug)]
#[command(name = "patronus-egraphs-cond-synth")]
Expand All @@ -33,6 +36,13 @@ struct Args {
help = "checks the current condition, prints out if it disagrees with the examples we generate"
)]
check_cond: bool,
#[arg(long, help = "write the generated assignments to a JSON file")]
write_assignments: Option<PathBuf>,
#[arg(
long,
help = "read assignments from a JSON file instead of generating and checking them"
)]
read_assignments: Option<PathBuf>,
#[arg(value_name = "RULE", index = 1)]
rule: String,
}
Expand Down Expand Up @@ -116,9 +126,6 @@ fn get_width_from_e_graph(egraph: &mut EGraph, subst: &Subst, v: Var) -> WidthIn
fn main() {
let args = Args::parse();

// remember start time
let start = std::time::Instant::now();

// find rule and extract both sides
let rewrites = create_rewrites();
let rule = match rewrites.iter().find(|r| r.name.as_str() == args.rule) {
Expand All @@ -132,35 +139,57 @@ fn main() {
}
};

let (samples, rule_info) = samples::generate_samples(
&args.rule,
rule,
args.max_width,
true,
args.dump_smt,
args.check_cond,
);
let delta_t = std::time::Instant::now() - start;
// generate and check samples
let samples = if let Some(in_filename) = args.read_assignments {
let file = std::fs::File::open(&in_filename).expect("failed to open input JSON");
let mut reader = std::io::BufReader::new(file);
let samples = Samples::from_json(&mut reader).expect("failed to parse input JSON");
println!("Assignments loaded from {:?}", in_filename);
samples
} else {
// remember start time
let start = std::time::Instant::now();
let samples =
samples::generate_samples(rule, args.max_width, true, args.dump_smt, args.check_cond);
let delta_t = std::time::Instant::now() - start;
println!(
"Took {delta_t:?} on {} threads.",
rayon::current_num_threads()
);
samples
};

println!("Found {} equivalent rewrites.", samples.num_equivalent());
println!(
"Found {} unequivalent rewrites.",
samples.num_unequivalent()
);
println!(
"Took {delta_t:?} on {} threads.",
rayon::current_num_threads()
);

if let Some(out_filename) = args.write_assignments {
let mut file = std::fs::File::create(&out_filename).expect("failed to open output JSON");
samples
.clone()
.to_json(&mut file)
.expect("failed to write output JSON");
println!("Wrote assignments to `{:?}`", out_filename);
}

if args.print_samples {
for sample in samples.iter() {
println!("{:?}", sample);
}
}

// check features
let feature_start = std::time::Instant::now();
let rule_info = get_rule_info(rule);
let features = apply_features(&rule_info, &samples);
let feature_delta_t = std::time::Instant::now() - feature_start;
println!("{feature_delta_t:?} to apply all features");

if args.bdd_formula {
let summarize_start = std::time::Instant::now();
let formula = bdd_summarize(&rule_info, &samples);
let formula = bdd_summarize(&features);
let summarize_delta_t = std::time::Instant::now() - summarize_start;
println!("Generated formula in {summarize_delta_t:?}:\n{}", formula);
}
Expand Down
62 changes: 52 additions & 10 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,26 @@ use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt};
use patronus_egraphs::*;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize, Serializer};

pub fn get_rule_info(rule: &Rewrite<Arith, ()>) -> RuleInfo {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
let lhs_info = analyze_pattern(lhs);
let rhs_info = analyze_pattern(rhs);
lhs_info.merge(&rhs_info)
}

pub fn generate_samples(
rule_name: &str,
rule: &Rewrite<Arith, ()>,
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
check_cond: bool,
) -> (Samples, RuleInfo) {
) -> Samples {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
println!("{}: {} => {}", rule_name, lhs, rhs);

// analyze rule patterns
let lhs_info = analyze_pattern(lhs);
let rhs_info = analyze_pattern(lhs);
let rhs_info = analyze_pattern(rhs);
let rule_info = lhs_info.merge(&rhs_info);
println!("{:?}", rule_info);

let num_assignments = rule_info.num_assignments(max_width);
println!("There are {num_assignments} possible assignments for this rule.");
Expand Down Expand Up @@ -81,10 +84,9 @@ pub fn generate_samples(
.collect::<Vec<_>>();

// merge results from different threads
let samples = samples
samples
.into_par_iter()
.reduce(|| Samples::new(&rule_info), Samples::merge);
(samples, rule_info)
.reduce(|| Samples::new(&rule_info), Samples::merge)
}

fn start_solver(dump_smt: bool) -> easy_smt::Context {
Expand All @@ -110,6 +112,36 @@ pub struct Samples {
is_equivalent: Vec<bool>,
}

/// Works around the fact that `Var` cannot be serialized/deserialized
#[derive(Serialize, Deserialize)]
struct SamplesSerde {
vars: Vec<String>,
assignments: Vec<WidthInt>,
is_equivalent: Vec<bool>,
}

impl From<Samples> for SamplesSerde {
fn from(value: Samples) -> Self {
let vars = value.vars.into_iter().map(|v| v.to_string()).collect();
Self {
vars,
assignments: value.assignments,
is_equivalent: value.is_equivalent,
}
}
}

impl From<SamplesSerde> for Samples {
fn from(value: SamplesSerde) -> Self {
let vars = value.vars.into_iter().map(|v| v.parse().unwrap()).collect();
Self {
vars,
assignments: value.assignments,
is_equivalent: value.is_equivalent,
}
}
}

impl Samples {
fn new(rule: &RuleInfo) -> Self {
let vars = rule.assignment_vars().collect();
Expand Down Expand Up @@ -149,6 +181,16 @@ impl Samples {
a.is_equivalent.append(&mut b.is_equivalent);
a
}

pub fn to_json(self, out: &mut impl std::io::Write) -> serde_json::error::Result<()> {
let s: SamplesSerde = self.into();
serde_json::to_writer_pretty(out, &s)
}

pub fn from_json(input: &mut impl std::io::Read) -> serde_json::error::Result<Self> {
let s: SamplesSerde = serde_json::from_reader(input)?;
Ok(s.into())
}
}

pub struct SamplesIter<'a> {
Expand Down
7 changes: 2 additions & 5 deletions tools/egraphs-cond-synth/src/summarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
// released under BSD 3-Clause License
// author: Kevin Laeufer <[email protected]>

use crate::features::apply_features;
use crate::samples::{RuleInfo, Samples};
use crate::features::FeatureResult;

/// generate a simplified re-write condition from samples, using BDDs
pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String {
let results = apply_features(rule, samples);

pub fn bdd_summarize(results: &FeatureResult) -> String {
// generate BDD terminals
let mut bdd = boolean_expression::BDD::<String>::new();
let vars: Vec<_> = results
Expand Down

0 comments on commit 3a86537

Please sign in to comment.