Skip to content

Commit

Permalink
cond synth: change how we analyze patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Dec 5, 2024
1 parent aa92b68 commit 231bd31
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 71 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rustc-hash = "2.x"
baa = "0.14.6"
egg = "0.9.5"
easy-smt = "0.2.3"
regex = "1.11.1"
clap = { version = "4.x", features = ["derive"] }
patronus = {path = "patronus"}

Expand Down
2 changes: 2 additions & 0 deletions patronus-egraphs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ rust-version.workspace = true
patronus = { path = "../patronus" }
egg.workspace = true
baa.workspace = true
lazy_static = "1.5.0"
regex = "1.11.1"
15 changes: 14 additions & 1 deletion patronus-egraphs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

use baa::BitVecOps;
use egg::{define_language, rewrite, Id, Language, RecExpr, Var};
use lazy_static::lazy_static;
use patronus::expr::*;
use regex::Regex;
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
Expand Down Expand Up @@ -32,11 +34,22 @@ pub struct ArithSymbol {
pub width: WidthInt,
}

lazy_static! {
static ref ARITH_SYMBOL_REGEX: Regex =
Regex::new(r"^StringRef\(([[:digit:]]+)\)\s*:\s*bv<\s*([[:digit:]]+)\s*>\s*$").unwrap();
}

impl FromStr for ArithSymbol {
type Err = ();

fn from_str(s: &str) -> Result<Self, Self::Err> {
todo!()
if let Some(c) = ARITH_SYMBOL_REGEX.captures(s) {
let name_index: usize = c.get(1).unwrap().as_str().parse().unwrap();
let width: WidthInt = c.get(2).unwrap().as_str().parse().unwrap();
todo!("{s} ==> {name_index} {width}")
} else {
Err(())
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion patronus/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ lazy_static = "1.4.0"
easy-smt.workspace = true
smallvec = { version = "1.x", features = ["union"] }
boolean_expression = "0.4.4"
regex = "1.11.1"
regex.workspace = true
baa.workspace = true
rustc-hash.workspace = true

Expand Down
6 changes: 6 additions & 0 deletions tools/egraphs-cond-synth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ struct Args {
fn create_rewrites() -> Vec<Rewrite<Arith, ()>> {
vec![
rewrite!("commute-add"; "(+ ?wo ?wa ?sa ?a ?wb ?sb ?b)" => "(+ ?wo ?wb ?sb ?b ?wa ?sa ?a)"),
rewrite!("merge-left-shift";
// we require that b, c and (b + c) are all unsigned
"(<< ?wo ?wab ?sab (<< ?wab ?wa ?sa ?a ?wb 0 ?b) ?wc 0 ?c)" =>
// note: in this version we set the width of (b + c) on the RHS to be the width of the
// result (w_o)
"(<< ?wo ?wa ?sa ?a ?wo 0 (+ ?wo ?wb 0 ?b ?wc 0 ?c))"),
]
}

Expand Down
170 changes: 101 additions & 69 deletions tools/egraphs-cond-synth/src/samples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub struct Samples {

impl Samples {
fn new(rule: &RuleInfo) -> Self {
let vars = rule.assignment_vars();
let vars = rule.assignment_vars().collect();
let assignments = vec![];
let is_equivalent = vec![];
Self {
Expand Down Expand Up @@ -162,26 +162,40 @@ fn extract_patterns<L: Language>(

#[derive(Debug, Clone, Eq, PartialEq)]
struct RuleInfo {
width: Var,
children: Vec<RuleChild>,
/// width parameters
widths: Vec<Var>,
/// sign parameters
signs: Vec<Var>,
/// all actual expression symbols in the rule which we need to plug in
/// if we want to check for equivalence
symbols: Vec<RuleSymbol>,
}

#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
struct RuleChild {
enum VarOrConst {
C(WidthInt),
V(Var),
}

/// a unique symbol in a rule, needs to be replaced with an SMT bit-vector symbol for equivalence checks
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
struct RuleSymbol {
var: Var,
width: Var,
sign: Var,
width: VarOrConst,
sign: VarOrConst,
}

pub type Assignment = Vec<(Var, WidthInt)>;

impl RuleInfo {
fn merge(&self, other: &Self) -> Self {
assert_eq!(self.width, other.width);
let children = union_vecs(&self.children, &other.children);
let widths = union_vecs(&self.widths, &other.widths);
let signs = union_vecs(&self.signs, &other.signs);
let symbols = union_vecs(&self.symbols, &other.symbols);
Self {
width: self.width,
children,
widths,
signs,
symbols,
}
}

Expand All @@ -194,20 +208,15 @@ impl RuleInfo {
}

fn num_assignments(&self, max_width: WidthInt) -> u64 {
let cl = self.children.len() as u32;
let width_values = max_width as u64; // we do not use 0-bit
2u64.pow(cl) * width_values.pow(1 + cl)
2u64.pow(self.signs.len() as u32) * width_values.pow(self.widths.len() as u32)
}

fn assignment_vars(&self) -> Vec<Var> {
let mut out = vec![self.width];
for child in self.children.iter() {
out.push(child.width);
}
for child in self.children.iter() {
out.push(child.sign);
}
out
fn assignment_vars(&self) -> impl Iterator<Item = Var> + '_ {
self.widths
.iter()
.cloned()
.chain(self.signs.iter().cloned())
}
}

Expand All @@ -227,17 +236,14 @@ impl<'a> Iterator for AssignmentIter<'a> {
if self.index == max {
None
} else {
let mut out = Vec::with_capacity(1 + 2 * self.rule.children.len());
let mut out = Vec::with_capacity(1 + 2 * self.rule.symbols.len());
let mut index = self.index;
for width_var in [self.rule.width]
.into_iter()
.chain(self.rule.children.iter().map(|c| c.width))
{
for &width_var in self.rule.widths.iter() {
let value = (index % width_values) as WidthInt + 1;
index /= width_values;
out.push((width_var, value))
}
for sign_var in self.rule.children.iter().map(|c| c.sign) {
for &sign_var in self.rule.signs.iter() {
let value = (index % 2) as WidthInt;
index /= 2;
out.push((sign_var, value))
Expand All @@ -260,49 +266,69 @@ fn union_vecs<T: Clone + PartialEq + Ord>(a: &[T], b: &[T]) -> Vec<T> {
/// Extracts the output width and all children including width and sign from an [[`egg::PatternAst`]].
/// Requires that the output width is name `?wo` and that the child width and sign are named like:
/// `?w{name}` and `?s{name}`.
fn analyze_pattern<L: Language>(pat: &PatternAst<L>) -> RuleInfo {
let mut widths = FxHashMap::default();
let mut signs = FxHashMap::default();
let mut children = Vec::new();
fn analyze_pattern(pat: &PatternAst<Arith>) -> RuleInfo {
let mut widths = vec![];
let mut signs = vec![];
let mut symbols = Vec::new();
for element in pat.as_ref().iter() {
if let &ENodeOrVar::Var(v) = element {
// check name to determine the category
let name = v.to_string();
assert!(name.starts_with("?"), "expect all vars to start with `?`");
let second_char = name.chars().nth(1).unwrap();
match second_char {
'w' => {
widths.insert(name, v);
match &element {
ENodeOrVar::Var(v) => {
// collect information on variables
let name = v.to_string();
assert!(name.starts_with("?"), "expect all vars to start with `?`");
let second_char = name.chars().nth(1).unwrap();
match second_char {
'w' => widths.push(*v),
's' => signs.push(*v),
_ => {} // ignore
}
's' => {
signs.insert(name, v);
}
ENodeOrVar::ENode(n) => {
// bin op pattern
if let [_w, w_a, s_a, a, w_b, s_b, b] = n.children() {
if let Some(s) = symbol_from_pattern(pat, *a, *w_a, *s_a) {
symbols.push(s);
}
if let Some(s) = symbol_from_pattern(pat, *b, *w_b, *s_b) {
symbols.push(s);
}
}
_ => children.push(v),
}
}
}
let width = *widths
.get("?wo")
.expect("pattern is missing result width: `?wo`");
let mut children = children
.into_iter()
.map(|c| {
let name = c.to_string().chars().skip(1).collect::<String>();
let width = *widths
.get(&format!("?w{name}"))
.unwrap_or_else(|| panic!("pattern is missing a width for `{name}`: `?w{name}`"));
let sign = *signs
.get(&format!("?s{name}"))
.unwrap_or_else(|| panic!("pattern is missing a sign for `{name}`: `?s{name}`"));
RuleChild {
var: c,
width,
sign,
}
})
.collect::<Vec<_>>();
children.sort();
RuleInfo { width, children }

widths.sort();
widths.dedup();
signs.sort();
signs.dedup();
symbols.sort();
symbols.dedup();
RuleInfo {
widths,
signs,
symbols,
}
}

fn symbol_from_pattern(pat: &PatternAst<Arith>, a: Id, w: Id, s: Id) -> Option<RuleSymbol> {
if let ENodeOrVar::Var(var) = pat[a] {
let width = width_or_sign_from_pattern(pat, w);
let sign = width_or_sign_from_pattern(pat, s);
Some(RuleSymbol { var, width, sign })
} else {
None
}
}

fn width_or_sign_from_pattern(pat: &PatternAst<Arith>, id: Id) -> VarOrConst {
match &pat[id] {
ENodeOrVar::ENode(node) => match node {
&Arith::Width(w) => VarOrConst::C(w),
&Arith::Signed(s) => VarOrConst::C(s as WidthInt),
_ => unreachable!("not a widht!"),
},
ENodeOrVar::Var(var) => VarOrConst::V(*var),
}
}

/// Generates a patronus SMT expression from a pattern, rule info and assignment.
Expand All @@ -325,11 +351,17 @@ fn gen_substitution(
) -> FxHashMap<Var, Arith> {
let assignment = FxHashMap::from_iter(assignment.clone());
let mut out = FxHashMap::default();
out.insert(rule.width, Arith::Width(assignment[&rule.width]));
for child in rule.children.iter() {
let width = assignment[&child.width];
out.insert(child.width, Arith::Width(width));
out.insert(child.sign, Arith::Signed(assignment[&child.sign] != 0));
for &width_var in rule.widths.iter() {
out.insert(width_var, Arith::Width(assignment[&width_var]));
}
for &sign_var in rule.signs.iter() {
out.insert(sign_var, Arith::Signed(assignment[&sign_var] != 0));
}
for child in rule.symbols.iter() {
let width = match child.width {
VarOrConst::C(w) => w,
VarOrConst::V(v) => assignment[&v],
};
let name = child.var.to_string().chars().skip(1).collect::<String>();
let symbol = ctx.bv_symbol(&name, width);
let name_ref = ctx[symbol].get_symbol_name_ref().unwrap();
Expand Down

0 comments on commit 231bd31

Please sign in to comment.