Skip to content

Commit

Permalink
cond synth: wip thread local version which makes better use of all th…
Browse files Browse the repository at this point in the history
…reads
  • Loading branch information
ekiwi committed Dec 9, 2024
1 parent bda9bfa commit 9410d78
Showing 1 changed file with 49 additions and 39 deletions.
88 changes: 49 additions & 39 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ use patronus::expr::{Context, ExprRef, TypeCheck, WidthInt};
use patronus_egraphs::*;
use rayon::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use std::cell::RefCell;
use std::sync::Mutex;
use thread_local::ThreadLocal;

pub fn generate_samples(
rule_name: &str,
Expand Down Expand Up @@ -37,51 +40,58 @@ pub fn generate_samples(
};

// split up work across threads
// let num_threads = rayon::current_num_threads();
// let assignment_range = 0..rule_info.num_assignments(max_width);
// let assignments_per_thread = assignment_range.end as usize / num_threads;

// thread local storage in order to re-use smt solver process
let num_threads = rayon::current_num_threads();
let solvers = (0..num_threads)
.into_iter()
.map(|_| Mutex::new(start_solver(dump_smt)))
.collect::<Vec<_>>();

let assignment_range = 0..rule_info.num_assignments(max_width);
let assignments_per_thread = assignment_range.end as usize / num_threads;

// check all rewrites in parallel
let samples = assignment_range
.collect::<Vec<_>>()
.par_chunks(assignments_per_thread)
.map(|assignment_indices| {
// create context and samples
assignment_range
.into_par_iter()
.map(|index| {
// create / acquire context
let mut ctx = Context::default();
let mut smt_ctx = start_solver(dump_smt);
let mut samples = Samples::new(&rule_info);
let idx = rayon::current_thread_index().unwrap();
let mut smt_ctx = solvers[idx].lock().unwrap();

for &assignment_index in assignment_indices.iter() {
if let Some(p) = &prog {
p.inc(1);
}
let assignment = rule_info.get_assignment(max_width, assignment_index);
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);

smt_ctx.push_many(1).unwrap();
declare_vars(&mut smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
let resp = smt_ctx.check().unwrap();
smt_ctx.pop_many(1).unwrap();

match resp {
easy_smt::Response::Sat => samples.add(assignment, false),
easy_smt::Response::Unsat => samples.add(assignment, true),
easy_smt::Response::Unknown => println!("{assignment:?} => Unknown!"),
}
// progress and get assignment
if let Some(p) = &prog {
p.inc(1);
}
let assignment = rule_info.get_assignment(max_width, index);

// check assignment
let lhs_expr = to_smt(&mut ctx, lhs, &lhs_info, &assignment);
let rhs_expr = to_smt(&mut ctx, rhs, &rhs_info, &assignment);
let is_eq = ctx.equal(lhs_expr, rhs_expr);
let is_not_eq = ctx.not(is_eq);
let smt_expr = patronus::smt::convert_expr(&smt_ctx, &ctx, is_not_eq, &|_| None);

smt_ctx.push_many(1).unwrap();
declare_vars(&mut smt_ctx, &ctx, is_not_eq);
smt_ctx.assert(smt_expr).unwrap();
let resp = smt_ctx.check().unwrap();
smt_ctx.pop_many(1).unwrap();

match resp {
easy_smt::Response::Sat => (assignment, false),
easy_smt::Response::Unsat => (assignment, true),
easy_smt::Response::Unknown => panic!("{assignment:?} => Unknown!"),
}

samples
})
.collect::<Vec<_>>();

// merge results from different threads
samples
.into_par_iter()
.fold(
|| Samples::new(&rule_info),
|mut samples, (a, e)| {
samples.add(a, e);
samples
},
)
.reduce(|| Samples::new(&rule_info), Samples::merge)
}

Expand Down

0 comments on commit 9410d78

Please sign in to comment.