diff --git a/src/ir/smt.rs b/src/ir/smt.rs index da3ba6d..cc6bd82 100644 --- a/src/ir/smt.rs +++ b/src/ir/smt.rs @@ -62,7 +62,7 @@ pub trait ExprNodeConstruction: } fn greater(&mut self, a: ExprRef, b: ExprRef) -> ExprRef { - self.add_node(Expr::BVGreaterSigned(a, b)) + self.add_node(Expr::BVGreater(a, b)) } fn greater_or_equal_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef { self.add_node(Expr::BVGreaterEqualSigned(a, b)) diff --git a/src/mc/smt.rs b/src/mc/smt.rs index 49bd551..439a3bd 100644 --- a/src/mc/smt.rs +++ b/src/mc/smt.rs @@ -3,7 +3,7 @@ // author: Kevin Laeufer use crate::ir; -use crate::ir::{Context, Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck}; +use crate::ir::{Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck}; use std::borrow::Cow; use crate::ir::SignalKind::Input; @@ -76,7 +76,8 @@ impl SmtModelChecker { for k in 0..=k_max { // assume all constraints hold in this step for (expr_ref, _) in constraints.iter() { - smt_ctx.assert(enc.get_constraint(*expr_ref))?; + let expr = enc.get_constraint(&mut smt_ctx, *expr_ref); + smt_ctx.assert(expr)?; } // make sure the constraints are not contradictory @@ -92,7 +93,8 @@ impl SmtModelChecker { if self.opts.check_bad_states_individually { for (_bs_id, (expr_ref, _)) in bad_states.iter().enumerate() { - let res = smt_ctx.check_assuming([enc.get_bad_state(*expr_ref)])?; + let expr = enc.get_bad_state(&mut smt_ctx, *expr_ref); + let res = smt_ctx.check_assuming([expr])?; if res == smt::Response::Sat { let wit = self.get_witness(sys, &mut smt_ctx, &enc, k)?; @@ -100,11 +102,11 @@ impl SmtModelChecker { } } } else { - let any_bad = smt_ctx.or_many( - bad_states - .iter() - .map(|(expr_ref, _)| enc.get_bad_state(*expr_ref)), - ); + let all_bads = bad_states + .iter() + .map(|(expr_ref, _)| enc.get_bad_state(&mut smt_ctx, *expr_ref)) + .collect::>(); + let any_bad = smt_ctx.or_many(all_bads); let res = smt_ctx.check_assuming([any_bad])?; if res == smt::Response::Sat { @@ -131,7 +133,7 @@ impl SmtModelChecker { let mut wit = Witness::default(); // collect initial values for (state_idx, state) in sys.states().enumerate() { - let sym_at = enc.get_signal_at(state.symbol, 0); + let sym_at = enc.get_symbol_at(smt_ctx, state.symbol, 0); let value = get_smt_value(smt_ctx, sym_at)?; wit.init.insert(state_idx, value); } @@ -141,7 +143,7 @@ impl SmtModelChecker { for k in 0..=k_max { let mut input_values = State::default(); for (input_idx, (input, _)) in inputs.iter().enumerate() { - let sym_at = enc.get_signal_at(*input, k); + let sym_at = enc.get_symbol_at(smt_ctx, *input, k); let value = get_smt_value(smt_ctx, sym_at)?; input_values.insert(input_idx, value); } @@ -154,7 +156,7 @@ impl SmtModelChecker { fn get_smt_value(smt_ctx: &mut smt::Context, expr: smt::SExpr) -> Result { let smt_value = smt_ctx.get_value(vec![expr])?[0].1; - todo!("Convert: {:?}", smt_value) + todo!("Convert: {:?}", smt_ctx.display(smt_value).to_string()) } pub enum ModelCheckResult { @@ -165,10 +167,10 @@ pub enum ModelCheckResult { pub trait TransitionSystemEncoding { fn define_header(&self, smt_ctx: &mut smt::Context) -> Result<()>; fn init(&mut self, smt_ctx: &mut smt::Context) -> Result<()>; - fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()>; - fn get_constraint(&self, e: ExprRef) -> smt::SExpr; - fn get_bad_state(&self, e: ExprRef) -> smt::SExpr; - fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr; + fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()>; + fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr; + fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr; + fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr; } pub struct UnrollSmtEncoding<'a> { @@ -176,17 +178,29 @@ pub struct UnrollSmtEncoding<'a> { sys: &'a ir::TransitionSystem, current_step: Option, inputs: Vec<(ExprRef, ir::SignalInfo)>, + /// constraint and bad state signals (for now) + signals: Vec<(ExprRef, String)>, } impl<'a> UnrollSmtEncoding<'a> { pub fn new(ctx: &'a ir::Context, sys: &'a ir::TransitionSystem) -> Self { // remember inputs instead of constantly re-filtering them let inputs = sys.get_signals(|s| s.kind == Input); + // name all constraints and bad states + let mut signals = Vec::new(); + for (ii, (expr, _)) in sys.constraints().iter().enumerate() { + signals.push((*expr, format!("__constraint_{ii}"))); + } + for (ii, (expr, _)) in sys.bad_states().iter().enumerate() { + signals.push((*expr, format!("__bad_{ii}"))); + } + Self { ctx, sys, current_step: None, inputs, + signals, } } @@ -199,18 +213,37 @@ impl<'a> UnrollSmtEncoding<'a> { } // define signals - todo!(); + for (expr, name) in self.signals.iter() { + let name = format!("{}@{}", name, step); + let tpe = convert_tpe(smt_ctx, expr.get_type(self.ctx)); + let value = self.expr_in_step(smt_ctx, *expr, step); + smt_ctx.define_const(escape_smt_identifier(&name), tpe, value)?; + } + + Ok(()) } - fn expr_in_step( + fn get_local_expr_symbol_at( &self, smt_ctx: &mut smt::Context, - ctx: &Context, - expr: ExprRef, - step: u64, + e: ExprRef, + k: u64, ) -> smt::SExpr { + // find our local name + let base_name: &str = self + .signals + .iter() + .find(|(id, _)| *id == e) + .map(|(_, n)| n) + .unwrap(); + let name = format!("{}@{}", base_name, k); + smt_ctx.atom(escape_smt_identifier(&name)) + } + + fn expr_in_step(&self, smt_ctx: &mut smt::Context, expr: ExprRef, step: u64) -> smt::SExpr { let rename = |name: &str| -> Cow<'_, str> { Cow::Owned(format!("{}@{}", name, step)) }; - convert_expr(smt_ctx, ctx, expr, rename) + let general_rename = generalize_lifetime(rename); + convert_expr(smt_ctx, self.ctx, expr, &general_rename) } fn name_at(&self, sym: ExprRef, step: u64) -> String { @@ -218,6 +251,14 @@ impl<'a> UnrollSmtEncoding<'a> { } } +// stack overflow hack: https://stackoverflow.com/questions/70597152/creating-an-alias-for-a-fn-trait-results-in-one-type-is-more-general-than-the-o +fn generalize_lifetime<'a, F>(f: F) -> F +where + F: Fn(&'a str) -> Cow<'a, str>, +{ + f +} + fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr { match tpe { Type::BV(1) => smt_ctx.bool_sort(), @@ -230,18 +271,20 @@ fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr { } } -fn convert_expr<'a, 'f>( - smt_ctx: &'a mut smt::Context, +fn convert_expr<'a, 'f, F>( + smt_ctx: &smt::Context, ctx: &'a ir::Context, expr: ExprRef, - rename_sym: impl Fn(&'f str) -> Cow<'f, str>, + rename_sym: &F, ) -> smt::SExpr where + F: Fn(&'f str) -> Cow<'f, str>, 'a: 'f, { match ctx.get(expr) { Expr::BVSymbol { name, .. } => { - let renamed = (rename_sym)(ctx.get(name)); + let name_str = ctx.get(name); + let renamed = (rename_sym)(name_str); smt_ctx.atom(escape_smt_identifier(&renamed)) } Expr::BVLiteral { value, width } if *width == 1 => { @@ -252,19 +295,48 @@ where } } Expr::BVLiteral { value, width } => smt_ctx.binary(*width as usize, *value), - Expr::BVZeroExt { .. } => todo!(), + Expr::BVZeroExt { e, by, .. } => { + let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym); + // TODO: add support to easy_smt + smt_ctx.list(vec![ + smt_ctx.list(vec![ + smt_ctx.atoms().und, + smt_ctx.atom("zero_extend"), + smt_ctx.numeral(*by as usize), + ]), + e_expr, + ]) + } Expr::BVSignExt { .. } => todo!(), - Expr::BVSlice { .. } => todo!(), - Expr::BVNot(_, _) => todo!(), + Expr::BVSlice { e, hi, lo } => { + let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym); + smt_ctx.extract(*hi as i32, *lo as i32, e_expr) + } + Expr::BVNot(e_ref, _) => { + let e = convert_expr(smt_ctx, ctx, *e_ref, rename_sym); + smt_ctx.not(e) + } Expr::BVNegate(_, _) => todo!(), Expr::BVReduceOr(_) => todo!(), Expr::BVReduceAnd(_) => todo!(), Expr::BVReduceXor(_) => todo!(), Expr::BVEqual(_, _) => todo!(), - Expr::BVImplies(_, _) => todo!(), - Expr::BVGreater(_, _) => todo!(), + Expr::BVImplies(a_ref, b_ref) => { + let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); + let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); + smt_ctx.imp(a, b) + } + Expr::BVGreater(a_ref, b_ref) => { + let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); + let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); + smt_ctx.bvugt(a, b) + } Expr::BVGreaterSigned(_, _) => todo!(), - Expr::BVGreaterEqual(_, _) => todo!(), + Expr::BVGreaterEqual(a_ref, b_ref) => { + let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); + let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); + smt_ctx.bvuge(a, b) + } Expr::BVGreaterEqualSigned(_, _) => todo!(), Expr::BVConcat(_, _, _) => todo!(), Expr::BVAnd(_, _, _) => todo!(), @@ -273,7 +345,11 @@ where Expr::BVShiftLeft(_, _, _) => todo!(), Expr::BVArithmeticShiftRight(_, _, _) => todo!(), Expr::BVShiftRight(_, _, _) => todo!(), - Expr::BVAdd(_, _, _) => todo!(), + Expr::BVAdd(a_ref, b_ref, _) => { + let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); + let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); + smt_ctx.bvadd(a, b) + } Expr::BVMul(_, _, _) => todo!(), Expr::BVSignedDiv(_, _, _) => todo!(), Expr::BVUnsignedDiv(_, _, _) => todo!(), @@ -282,7 +358,12 @@ where Expr::BVUnsignedRem(_, _, _) => todo!(), Expr::BVSub(_, _, _) => todo!(), Expr::BVArrayRead { .. } => todo!(), - Expr::BVIte { .. } => todo!(), + Expr::BVIte { cond, tru, fals } => { + let c = convert_expr(smt_ctx, ctx, *cond, rename_sym); + let t = convert_expr(smt_ctx, ctx, *tru, rename_sym); + let f = convert_expr(smt_ctx, ctx, *fals, rename_sym); + smt_ctx.ite(c, t, f) + } Expr::ArraySymbol { name, .. } => { let renamed = (rename_sym)(ctx.get(name)); smt_ctx.atom(escape_smt_identifier(&renamed)) @@ -309,7 +390,7 @@ impl<'a> TransitionSystemEncoding for UnrollSmtEncoding<'a> { let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx)); match state.init { Some(value) => { - let body = self.expr_in_step(smt_ctx, self.ctx, value, 0); + let body = self.expr_in_step(smt_ctx, value, 0); smt_ctx.define_const(escape_smt_identifier(&name), out, body)?; } None => { @@ -317,25 +398,52 @@ impl<'a> TransitionSystemEncoding for UnrollSmtEncoding<'a> { } } } + // define the inputs for the initial state and all signals derived from it self.define_inputs_and_signals(smt_ctx, 0)?; Ok(()) } - fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()> { - todo!() + fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()> { + let prev_step = self.current_step.unwrap(); + let next_step = prev_step + 1; + + // define next state + for state in self.sys.states() { + let name = self.name_at(state.symbol, next_step); + let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx)); + match state.next { + Some(value) => { + let body = self.expr_in_step(smt_ctx, value, prev_step); + smt_ctx.define_const(escape_smt_identifier(&name), out, body)?; + } + None => { + smt_ctx.declare_const(escape_smt_identifier(&name), out)?; + } + } + } + + // declare the inputs and all signals derived from the new state and inputs + self.define_inputs_and_signals(smt_ctx, next_step)?; + + // update step count + self.current_step = Some(next_step); + Ok(()) } - fn get_constraint(&self, e: ExprRef) -> smt::SExpr { - todo!() + fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr { + self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap()) } - fn get_bad_state(&self, e: ExprRef) -> smt::SExpr { - todo!() + fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr { + self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap()) } - fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr { - todo!() + fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr { + assert!(sym.is_symbol(self.ctx)); + assert!(k <= self.current_step.unwrap_or(0)); + let name = self.name_at(sym, k); + smt_ctx.atom(escape_smt_identifier(&name)) } } @@ -384,10 +492,10 @@ fn is_simple_smt_identifier(id: &str) -> bool { fn escape_smt_identifier(id: &str) -> std::borrow::Cow<'_, str> { if is_simple_smt_identifier(id) { - std::borrow::Cow::Borrowed(id) + Cow::Borrowed(id) } else { let escaped = format!("|{}|", id); - std::borrow::Cow::Owned(escaped) + Cow::Owned(escaped) } }