Skip to content

Commit

Permalink
cond synth: first attempt at multi threading
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 9, 2024
1 parent eed22c7 commit bda9bfa
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 48 deletions.
2 changes: 2 additions & 0 deletions tools/egraphs-cond-synth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ clap.workspace = true
rustc-hash.workspace = true
easy-smt.workspace = true
indicatif = "0.17.9"
rayon = "1.10.0"
thread_local = "1.1.8"
13 changes: 12 additions & 1 deletion tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ struct Args {
max_width: WidthInt,
#[arg(long)]
print_samples: bool,
#[arg(long)]
dump_smt: bool,
#[arg(value_name = "RULE", index = 1)]
rule: String,
}
Expand All @@ -40,6 +42,9 @@ fn create_rewrites() -> Vec<Rewrite<Arith, ()>> {
fn main() {
let args = Args::parse();

// remember start time
let start = std::time::Instant::now();

// find rule and extract both sides
let rewrites = create_rewrites();
let rule = match rewrites.iter().find(|r| r.name.as_str() == args.rule) {
Expand All @@ -53,12 +58,18 @@ fn main() {
}
};

let samples = samples::generate_samples(&args.rule, rule, args.max_width, true);
let samples = samples::generate_samples(&args.rule, rule, args.max_width, true, args.dump_smt);
let delta_t = std::time::Instant::now() - start;

println!("Found {} equivalent rewrites.", samples.num_equivalent());
println!(
"Found {} unequivalent rewrites.",
samples.num_unequivalent()
);
println!(
"Took {delta_t:?} on {} threads.",
rayon::current_num_threads()
);

if args.print_samples {
for sample in samples.iter() {
Expand Down
140 changes: 93 additions & 47 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ use indicatif::ProgressBar;
use patronus::expr::traversal::TraversalCmd;
use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt};
use patronus_egraphs::*;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};

pub fn generate_samples(
rule_name: &str,
rule: &Rewrite<Arith, ()>,
max_width: WidthInt,
show_progress: bool,
dump_smt: bool,
) -> Samples {
let (lhs, rhs) = extract_patterns(rule).expect("failed to extract patterns from rewrite rule");
println!("{}: {} => {}", rule_name, lhs, rhs);
Expand All @@ -27,46 +29,76 @@ pub fn generate_samples(
let num_assignments = rule_info.num_assignments(max_width);
println!("There are {num_assignments} possible assignments for this rule.");

// create context and start smt solver
let mut ctx = Context::default();
let solver: patronus::mc::SmtSolverCmd = patronus::mc::BITWUZLA_CMD;
let mut smt_ctx = easy_smt::ContextBuilder::new()
.solver(solver.name, solver.args)
.replay_file(Some(std::fs::File::create("replay.smt").unwrap()))
.build()
.unwrap();
smt_ctx.set_logic("QF_ABV").unwrap();

// check all rewrites
let mut samples = Samples::new(&rule_info);
// progress indicator
let prog = if show_progress {
Some(ProgressBar::new(rule_info.num_assignments(max_width)))
} else {
None
};
for assignment in rule_info.iter_assignments(max_width) {
if let Some(p) = &prog {
p.inc(1);
}
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);

smt_ctx.push_many(1).unwrap();
declare_vars(&mut smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
let resp = smt_ctx.check().unwrap();
smt_ctx.pop_many(1).unwrap();

match resp {
easy_smt::Response::Sat => samples.add(assignment, false),
easy_smt::Response::Unsat => samples.add(assignment, true),
easy_smt::Response::Unknown => println!("{assignment:?} => Unknown!"),
}
}

// split up work across threads
let num_threads = rayon::current_num_threads();
let assignment_range = 0..rule_info.num_assignments(max_width);
let assignments_per_thread = assignment_range.end as usize / num_threads;

// check all rewrites in parallel
let samples = assignment_range
.collect::<Vec<_>>()
.par_chunks(assignments_per_thread)
.map(|assignment_indices| {
// create context and samples
let mut ctx = Context::default();
let mut smt_ctx = start_solver(dump_smt);
let mut samples = Samples::new(&rule_info);

for &assignment_index in assignment_indices.iter() {
if let Some(p) = &prog {
p.inc(1);
}
let assignment = rule_info.get_assignment(max_width, assignment_index);
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);

smt_ctx.push_many(1).unwrap();
declare_vars(&mut smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
let resp = smt_ctx.check().unwrap();
smt_ctx.pop_many(1).unwrap();

match resp {
easy_smt::Response::Sat => samples.add(assignment, false),
easy_smt::Response::Unsat => samples.add(assignment, true),
easy_smt::Response::Unknown => println!("{assignment:?} => Unknown!"),
}
}

samples
})
.collect::<Vec<_>>();

// merge results from different threads
samples
.into_par_iter()
.reduce(|| Samples::new(&rule_info), Samples::merge)
}

fn start_solver(dump_smt: bool) -> easy_smt::Context {
let solver: patronus::mc::SmtSolverCmd = patronus::mc::BITWUZLA_CMD;
let dump_file = if dump_smt {
Some(std::fs::File::create("replay.smt").unwrap())
} else {
None
};
let mut smt_ctx = easy_smt::ContextBuilder::new()
.solver(solver.name, solver.args)
.replay_file(dump_file)
.build()
.unwrap();
smt_ctx.set_logic("QF_ABV").unwrap();
smt_ctx
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -109,6 +141,12 @@ impl Samples {
index: 0,
}
}
pub fn merge(mut a: Self, mut b: Self) -> Self {
debug_assert_eq!(a.vars, b.vars);
a.assignments.append(&mut b.assignments);
a.is_equivalent.append(&mut b.is_equivalent);
a
}
}

pub struct SamplesIter<'a> {
Expand Down Expand Up @@ -209,6 +247,7 @@ impl RuleInfo {
}
}

#[allow(dead_code)]
fn iter_assignments(&self, max_width: WidthInt) -> impl Iterator<Item = Assignment> + '_ {
AssignmentIter {
rule: self,
Expand All @@ -222,6 +261,24 @@ impl RuleInfo {
2u64.pow(self.signs.len() as u32) * width_values.pow(self.widths.len() as u32)
}

/// gets the assignment `ii` where `num_asignments > index >= 0`
fn get_assignment(&self, max_width: WidthInt, mut index: u64) -> Assignment {
debug_assert!(self.num_assignments(max_width) > index);
let width_values = max_width as u64;
let mut out = Vec::with_capacity(1 + 2 * self.symbols.len());
for &width_var in self.widths.iter() {
let value = (index % width_values) as WidthInt + 1;
index /= width_values;
out.push((width_var, value))
}
for &sign_var in self.signs.iter() {
let value = (index % 2) as WidthInt;
index /= 2;
out.push((sign_var, value))
}
out
}

fn assignment_vars(&self) -> impl Iterator<Item = Var> + '_ {
self.widths
.iter()
Expand All @@ -231,6 +288,7 @@ impl RuleInfo {
}

/// An iterator over all possivle assignments in a rule.
#[allow(dead_code)]
struct AssignmentIter<'a> {
rule: &'a RuleInfo,
index: u64,
Expand All @@ -241,23 +299,11 @@ impl<'a> Iterator for AssignmentIter<'a> {
type Item = Assignment;

fn next(&mut self) -> Option<Self::Item> {
let width_values = self.max_width as u64;
let max = self.rule.num_assignments(self.max_width);
if self.index == max {
None
} else {
let mut out = Vec::with_capacity(1 + 2 * self.rule.symbols.len());
let mut index = self.index;
for &width_var in self.rule.widths.iter() {
let value = (index % width_values) as WidthInt + 1;
index /= width_values;
out.push((width_var, value))
}
for &sign_var in self.rule.signs.iter() {
let value = (index % 2) as WidthInt;
index /= 2;
out.push((sign_var, value))
}
let out = self.rule.get_assignment(self.max_width, self.index);
self.index += 1;
Some(out)
}
Expand Down

0 comments on commit bda9bfa

Please sign in to comment.