From 3a86537d0827681e6f5fbd335773b12555020bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kevin=20L=C3=A4ufer?= Date: Tue, 10 Dec 2024 13:13:42 -0500 Subject: [PATCH] cond synth: save assignment to/from JSON --- tools/egraphs-cond-synth/.gitignore | 1 + tools/egraphs-cond-synth/Cargo.toml | 2 + tools/egraphs-cond-synth/src/main.rs | 63 +++++++++++++++++------ tools/egraphs-cond-synth/src/samples.rs | 62 ++++++++++++++++++---- tools/egraphs-cond-synth/src/summarize.rs | 7 +-- 5 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 tools/egraphs-cond-synth/.gitignore diff --git a/tools/egraphs-cond-synth/.gitignore b/tools/egraphs-cond-synth/.gitignore new file mode 100644 index 0000000..6830205 --- /dev/null +++ b/tools/egraphs-cond-synth/.gitignore @@ -0,0 +1 @@ +/*.json diff --git a/tools/egraphs-cond-synth/Cargo.toml b/tools/egraphs-cond-synth/Cargo.toml index 864dbb3..fa239f6 100644 --- a/tools/egraphs-cond-synth/Cargo.toml +++ b/tools/egraphs-cond-synth/Cargo.toml @@ -20,3 +20,5 @@ indicatif = "0.17.9" rayon = "1.10.0" thread_local = "1.1.8" bitvec = "1.0.1" +serde_json = "1.0.133" +serde = { version = "1.0.215", features = ["derive"] } diff --git a/tools/egraphs-cond-synth/src/main.rs b/tools/egraphs-cond-synth/src/main.rs index 0d9f89a..a7148db 100644 --- a/tools/egraphs-cond-synth/src/main.rs +++ b/tools/egraphs-cond-synth/src/main.rs @@ -9,11 +9,14 @@ mod features; mod samples; mod summarize; +use crate::features::apply_features; +use crate::samples::{get_rule_info, Samples}; use crate::summarize::bdd_summarize; use clap::Parser; use egg::*; use patronus::expr::*; use patronus_egraphs::*; +use std::path::PathBuf; #[derive(Parser, Debug)] #[command(name = "patronus-egraphs-cond-synth")] @@ -33,6 +36,13 @@ struct Args { help = "checks the current condition, prints out if it disagrees with the examples we generate" )] check_cond: bool, + #[arg(long, help = "write the generated assignments to a JSON file")] + write_assignments: Option, + #[arg( + long, + help = "read assignments from a JSON file instead of generating and checking them" + )] + read_assignments: Option, #[arg(value_name = "RULE", index = 1)] rule: String, } @@ -116,9 +126,6 @@ fn get_width_from_e_graph(egraph: &mut EGraph, subst: &Subst, v: Var) -> WidthIn fn main() { let args = Args::parse(); - // remember start time - let start = std::time::Instant::now(); - // find rule and extract both sides let rewrites = create_rewrites(); let rule = match rewrites.iter().find(|r| r.name.as_str() == args.rule) { @@ -132,25 +139,40 @@ fn main() { } }; - let (samples, rule_info) = samples::generate_samples( - &args.rule, - rule, - args.max_width, - true, - args.dump_smt, - args.check_cond, - ); - let delta_t = std::time::Instant::now() - start; + // generate and check samples + let samples = if let Some(in_filename) = args.read_assignments { + let file = std::fs::File::open(&in_filename).expect("failed to open input JSON"); + let mut reader = std::io::BufReader::new(file); + let samples = Samples::from_json(&mut reader).expect("failed to parse input JSON"); + println!("Assignments loaded from {:?}", in_filename); + samples + } else { + // remember start time + let start = std::time::Instant::now(); + let samples = + samples::generate_samples(rule, args.max_width, true, args.dump_smt, args.check_cond); + let delta_t = std::time::Instant::now() - start; + println!( + "Took {delta_t:?} on {} threads.", + rayon::current_num_threads() + ); + samples + }; println!("Found {} equivalent rewrites.", samples.num_equivalent()); println!( "Found {} unequivalent rewrites.", samples.num_unequivalent() ); - println!( - "Took {delta_t:?} on {} threads.", - rayon::current_num_threads() - ); + + if let Some(out_filename) = args.write_assignments { + let mut file = std::fs::File::create(&out_filename).expect("failed to open output JSON"); + samples + .clone() + .to_json(&mut file) + .expect("failed to write output JSON"); + println!("Wrote assignments to `{:?}`", out_filename); + } if args.print_samples { for sample in samples.iter() { @@ -158,9 +180,16 @@ fn main() { } } + // check features + let feature_start = std::time::Instant::now(); + let rule_info = get_rule_info(rule); + let features = apply_features(&rule_info, &samples); + let feature_delta_t = std::time::Instant::now() - feature_start; + println!("{feature_delta_t:?} to apply all features"); + if args.bdd_formula { let summarize_start = std::time::Instant::now(); - let formula = bdd_summarize(&rule_info, &samples); + let formula = bdd_summarize(&features); let summarize_delta_t = std::time::Instant::now() - summarize_start; println!("Generated formula in {summarize_delta_t:?}:\n{}", formula); } diff --git a/tools/egraphs-cond-synth/src/samples.rs b/tools/egraphs-cond-synth/src/samples.rs index 465745d..4c5e560 100644 --- a/tools/egraphs-cond-synth/src/samples.rs +++ b/tools/egraphs-cond-synth/src/samples.rs @@ -9,23 +9,26 @@ use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt}; use patronus_egraphs::*; use rayon::prelude::*; use rustc_hash::{FxHashMap, FxHashSet}; +use serde::{Deserialize, Serialize, Serializer}; + +pub fn get_rule_info(rule: &Rewrite) -> RuleInfo { + let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule"); + let lhs_info = analyze_pattern(lhs); + let rhs_info = analyze_pattern(rhs); + lhs_info.merge(&rhs_info) +} pub fn generate_samples( - rule_name: &str, rule: &Rewrite, max_width: WidthInt, show_progress: bool, dump_smt: bool, check_cond: bool, -) -> (Samples, RuleInfo) { +) -> Samples { let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule"); - println!("{}: {} => {}", rule_name, lhs, rhs); - - // analyze rule patterns let lhs_info = analyze_pattern(lhs); - let rhs_info = analyze_pattern(lhs); + let rhs_info = analyze_pattern(rhs); let rule_info = lhs_info.merge(&rhs_info); - println!("{:?}", rule_info); let num_assignments = rule_info.num_assignments(max_width); println!("There are {num_assignments} possible assignments for this rule."); @@ -81,10 +84,9 @@ pub fn generate_samples( .collect::>(); // merge results from different threads - let samples = samples + samples .into_par_iter() - .reduce(|| Samples::new(&rule_info), Samples::merge); - (samples, rule_info) + .reduce(|| Samples::new(&rule_info), Samples::merge) } fn start_solver(dump_smt: bool) -> easy_smt::Context { @@ -110,6 +112,36 @@ pub struct Samples { is_equivalent: Vec, } +/// Works around the fact that `Var` cannot be serialized/deserialized +#[derive(Serialize, Deserialize)] +struct SamplesSerde { + vars: Vec, + assignments: Vec, + is_equivalent: Vec, +} + +impl From for SamplesSerde { + fn from(value: Samples) -> Self { + let vars = value.vars.into_iter().map(|v| v.to_string()).collect(); + Self { + vars, + assignments: value.assignments, + is_equivalent: value.is_equivalent, + } + } +} + +impl From for Samples { + fn from(value: SamplesSerde) -> Self { + let vars = value.vars.into_iter().map(|v| v.parse().unwrap()).collect(); + Self { + vars, + assignments: value.assignments, + is_equivalent: value.is_equivalent, + } + } +} + impl Samples { fn new(rule: &RuleInfo) -> Self { let vars = rule.assignment_vars().collect(); @@ -149,6 +181,16 @@ impl Samples { a.is_equivalent.append(&mut b.is_equivalent); a } + + pub fn to_json(self, out: &mut impl std::io::Write) -> serde_json::error::Result<()> { + let s: SamplesSerde = self.into(); + serde_json::to_writer_pretty(out, &s) + } + + pub fn from_json(input: &mut impl std::io::Read) -> serde_json::error::Result { + let s: SamplesSerde = serde_json::from_reader(input)?; + Ok(s.into()) + } } pub struct SamplesIter<'a> { diff --git a/tools/egraphs-cond-synth/src/summarize.rs b/tools/egraphs-cond-synth/src/summarize.rs index ecd4cf8..140ef3e 100644 --- a/tools/egraphs-cond-synth/src/summarize.rs +++ b/tools/egraphs-cond-synth/src/summarize.rs @@ -2,13 +2,10 @@ // released under BSD 3-Clause License // author: Kevin Laeufer -use crate::features::apply_features; -use crate::samples::{RuleInfo, Samples}; +use crate::features::FeatureResult; /// generate a simplified re-write condition from samples, using BDDs -pub fn bdd_summarize(rule: &RuleInfo, samples: &Samples) -> String { - let results = apply_features(rule, samples); - +pub fn bdd_summarize(results: &FeatureResult) -> String { // generate BDD terminals let mut bdd = boolean_expression::BDD::::new(); let vars: Vec<_> = results