Skip to content

Commit

Permalink
model checking: implement enough to get to witness parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Nov 16, 2023
1 parent 284f128 commit 9ec4cd3
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/ir/smt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub trait ExprNodeConstruction:
}

fn greater(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
self.add_node(Expr::BVGreaterSigned(a, b))
self.add_node(Expr::BVGreater(a, b))
}
fn greater_or_equal_signed(&mut self, a: ExprRef, b: ExprRef) -> ExprRef {
self.add_node(Expr::BVGreaterEqualSigned(a, b))
Expand Down
196 changes: 152 additions & 44 deletions src/mc/smt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// author: Kevin Laeufer <[email protected]>

use crate::ir;
use crate::ir::{Context, Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck};
use crate::ir::{Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck};
use std::borrow::Cow;

use crate::ir::SignalKind::Input;
Expand Down Expand Up @@ -76,7 +76,8 @@ impl SmtModelChecker {
for k in 0..=k_max {
// assume all constraints hold in this step
for (expr_ref, _) in constraints.iter() {
smt_ctx.assert(enc.get_constraint(*expr_ref))?;
let expr = enc.get_constraint(&mut smt_ctx, *expr_ref);
smt_ctx.assert(expr)?;
}

// make sure the constraints are not contradictory
Expand All @@ -92,19 +93,20 @@ impl SmtModelChecker {

if self.opts.check_bad_states_individually {
for (_bs_id, (expr_ref, _)) in bad_states.iter().enumerate() {
let res = smt_ctx.check_assuming([enc.get_bad_state(*expr_ref)])?;
let expr = enc.get_bad_state(&mut smt_ctx, *expr_ref);
let res = smt_ctx.check_assuming([expr])?;

if res == smt::Response::Sat {
let wit = self.get_witness(sys, &mut smt_ctx, &enc, k)?;
return Ok(ModelCheckResult::Fail(wit));
}
}
} else {
let any_bad = smt_ctx.or_many(
bad_states
.iter()
.map(|(expr_ref, _)| enc.get_bad_state(*expr_ref)),
);
let all_bads = bad_states
.iter()
.map(|(expr_ref, _)| enc.get_bad_state(&mut smt_ctx, *expr_ref))
.collect::<Vec<_>>();
let any_bad = smt_ctx.or_many(all_bads);
let res = smt_ctx.check_assuming([any_bad])?;

if res == smt::Response::Sat {
Expand All @@ -131,7 +133,7 @@ impl SmtModelChecker {
let mut wit = Witness::default();
// collect initial values
for (state_idx, state) in sys.states().enumerate() {
let sym_at = enc.get_signal_at(state.symbol, 0);
let sym_at = enc.get_symbol_at(smt_ctx, state.symbol, 0);
let value = get_smt_value(smt_ctx, sym_at)?;
wit.init.insert(state_idx, value);
}
Expand All @@ -141,7 +143,7 @@ impl SmtModelChecker {
for k in 0..=k_max {
let mut input_values = State::default();
for (input_idx, (input, _)) in inputs.iter().enumerate() {
let sym_at = enc.get_signal_at(*input, k);
let sym_at = enc.get_symbol_at(smt_ctx, *input, k);
let value = get_smt_value(smt_ctx, sym_at)?;
input_values.insert(input_idx, value);
}
Expand All @@ -154,7 +156,7 @@ impl SmtModelChecker {

fn get_smt_value(smt_ctx: &mut smt::Context, expr: smt::SExpr) -> Result<Value> {
let smt_value = smt_ctx.get_value(vec![expr])?[0].1;
todo!("Convert: {:?}", smt_value)
todo!("Convert: {:?}", smt_ctx.display(smt_value).to_string())
}

pub enum ModelCheckResult {
Expand All @@ -165,28 +167,40 @@ pub enum ModelCheckResult {
pub trait TransitionSystemEncoding {
fn define_header(&self, smt_ctx: &mut smt::Context) -> Result<()>;
fn init(&mut self, smt_ctx: &mut smt::Context) -> Result<()>;
fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()>;
fn get_constraint(&self, e: ExprRef) -> smt::SExpr;
fn get_bad_state(&self, e: ExprRef) -> smt::SExpr;
fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr;
fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()>;
fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr;
fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr;
fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr;
}

pub struct UnrollSmtEncoding<'a> {
ctx: &'a ir::Context,
sys: &'a ir::TransitionSystem,
current_step: Option<u64>,
inputs: Vec<(ExprRef, ir::SignalInfo)>,
/// constraint and bad state signals (for now)
signals: Vec<(ExprRef, String)>,
}

impl<'a> UnrollSmtEncoding<'a> {
pub fn new(ctx: &'a ir::Context, sys: &'a ir::TransitionSystem) -> Self {
// remember inputs instead of constantly re-filtering them
let inputs = sys.get_signals(|s| s.kind == Input);
// name all constraints and bad states
let mut signals = Vec::new();
for (ii, (expr, _)) in sys.constraints().iter().enumerate() {
signals.push((*expr, format!("__constraint_{ii}")));
}
for (ii, (expr, _)) in sys.bad_states().iter().enumerate() {
signals.push((*expr, format!("__bad_{ii}")));
}

Self {
ctx,
sys,
current_step: None,
inputs,
signals,
}
}

Expand All @@ -199,25 +213,52 @@ impl<'a> UnrollSmtEncoding<'a> {
}

// define signals
todo!();
for (expr, name) in self.signals.iter() {
let name = format!("{}@{}", name, step);
let tpe = convert_tpe(smt_ctx, expr.get_type(self.ctx));
let value = self.expr_in_step(smt_ctx, *expr, step);
smt_ctx.define_const(escape_smt_identifier(&name), tpe, value)?;
}

Ok(())
}

fn expr_in_step(
fn get_local_expr_symbol_at(
&self,
smt_ctx: &mut smt::Context,
ctx: &Context,
expr: ExprRef,
step: u64,
e: ExprRef,
k: u64,
) -> smt::SExpr {
// find our local name
let base_name: &str = self
.signals
.iter()
.find(|(id, _)| *id == e)
.map(|(_, n)| n)
.unwrap();
let name = format!("{}@{}", base_name, k);
smt_ctx.atom(escape_smt_identifier(&name))
}

fn expr_in_step(&self, smt_ctx: &mut smt::Context, expr: ExprRef, step: u64) -> smt::SExpr {
let rename = |name: &str| -> Cow<'_, str> { Cow::Owned(format!("{}@{}", name, step)) };
convert_expr(smt_ctx, ctx, expr, rename)
let general_rename = generalize_lifetime(rename);
convert_expr(smt_ctx, self.ctx, expr, &general_rename)
}

fn name_at(&self, sym: ExprRef, step: u64) -> String {
format!("{}@{}", sym.get_symbol_name(self.ctx).unwrap(), step)
}
}

// stack overflow hack: https://stackoverflow.com/questions/70597152/creating-an-alias-for-a-fn-trait-results-in-one-type-is-more-general-than-the-o
fn generalize_lifetime<'a, F>(f: F) -> F
where
F: Fn(&'a str) -> Cow<'a, str>,
{
f
}

fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr {
match tpe {
Type::BV(1) => smt_ctx.bool_sort(),
Expand All @@ -230,18 +271,20 @@ fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr {
}
}

fn convert_expr<'a, 'f>(
smt_ctx: &'a mut smt::Context,
fn convert_expr<'a, 'f, F>(
smt_ctx: &smt::Context,
ctx: &'a ir::Context,
expr: ExprRef,
rename_sym: impl Fn(&'f str) -> Cow<'f, str>,
rename_sym: &F,
) -> smt::SExpr
where
F: Fn(&'f str) -> Cow<'f, str>,
'a: 'f,
{
match ctx.get(expr) {
Expr::BVSymbol { name, .. } => {
let renamed = (rename_sym)(ctx.get(name));
let name_str = ctx.get(name);
let renamed = (rename_sym)(name_str);
smt_ctx.atom(escape_smt_identifier(&renamed))
}
Expr::BVLiteral { value, width } if *width == 1 => {
Expand All @@ -252,19 +295,48 @@ where
}
}
Expr::BVLiteral { value, width } => smt_ctx.binary(*width as usize, *value),
Expr::BVZeroExt { .. } => todo!(),
Expr::BVZeroExt { e, by, .. } => {
let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym);
// TODO: add support to easy_smt
smt_ctx.list(vec![
smt_ctx.list(vec![
smt_ctx.atoms().und,
smt_ctx.atom("zero_extend"),
smt_ctx.numeral(*by as usize),
]),
e_expr,
])
}
Expr::BVSignExt { .. } => todo!(),
Expr::BVSlice { .. } => todo!(),
Expr::BVNot(_, _) => todo!(),
Expr::BVSlice { e, hi, lo } => {
let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym);
smt_ctx.extract(*hi as i32, *lo as i32, e_expr)
}
Expr::BVNot(e_ref, _) => {
let e = convert_expr(smt_ctx, ctx, *e_ref, rename_sym);
smt_ctx.not(e)
}
Expr::BVNegate(_, _) => todo!(),
Expr::BVReduceOr(_) => todo!(),
Expr::BVReduceAnd(_) => todo!(),
Expr::BVReduceXor(_) => todo!(),
Expr::BVEqual(_, _) => todo!(),
Expr::BVImplies(_, _) => todo!(),
Expr::BVGreater(_, _) => todo!(),
Expr::BVImplies(a_ref, b_ref) => {
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym);
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym);
smt_ctx.imp(a, b)
}
Expr::BVGreater(a_ref, b_ref) => {
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym);
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym);
smt_ctx.bvugt(a, b)
}
Expr::BVGreaterSigned(_, _) => todo!(),
Expr::BVGreaterEqual(_, _) => todo!(),
Expr::BVGreaterEqual(a_ref, b_ref) => {
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym);
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym);
smt_ctx.bvuge(a, b)
}
Expr::BVGreaterEqualSigned(_, _) => todo!(),
Expr::BVConcat(_, _, _) => todo!(),
Expr::BVAnd(_, _, _) => todo!(),
Expand All @@ -273,7 +345,11 @@ where
Expr::BVShiftLeft(_, _, _) => todo!(),
Expr::BVArithmeticShiftRight(_, _, _) => todo!(),
Expr::BVShiftRight(_, _, _) => todo!(),
Expr::BVAdd(_, _, _) => todo!(),
Expr::BVAdd(a_ref, b_ref, _) => {
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym);
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym);
smt_ctx.bvadd(a, b)
}
Expr::BVMul(_, _, _) => todo!(),
Expr::BVSignedDiv(_, _, _) => todo!(),
Expr::BVUnsignedDiv(_, _, _) => todo!(),
Expand All @@ -282,7 +358,12 @@ where
Expr::BVUnsignedRem(_, _, _) => todo!(),
Expr::BVSub(_, _, _) => todo!(),
Expr::BVArrayRead { .. } => todo!(),
Expr::BVIte { .. } => todo!(),
Expr::BVIte { cond, tru, fals } => {
let c = convert_expr(smt_ctx, ctx, *cond, rename_sym);
let t = convert_expr(smt_ctx, ctx, *tru, rename_sym);
let f = convert_expr(smt_ctx, ctx, *fals, rename_sym);
smt_ctx.ite(c, t, f)
}
Expr::ArraySymbol { name, .. } => {
let renamed = (rename_sym)(ctx.get(name));
smt_ctx.atom(escape_smt_identifier(&renamed))
Expand All @@ -309,33 +390,60 @@ impl<'a> TransitionSystemEncoding for UnrollSmtEncoding<'a> {
let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx));
match state.init {
Some(value) => {
let body = self.expr_in_step(smt_ctx, self.ctx, value, 0);
let body = self.expr_in_step(smt_ctx, value, 0);
smt_ctx.define_const(escape_smt_identifier(&name), out, body)?;
}
None => {
smt_ctx.declare_const(escape_smt_identifier(&name), out)?;
}
}
}

// define the inputs for the initial state and all signals derived from it
self.define_inputs_and_signals(smt_ctx, 0)?;
Ok(())
}

fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()> {
todo!()
fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()> {
let prev_step = self.current_step.unwrap();
let next_step = prev_step + 1;

// define next state
for state in self.sys.states() {
let name = self.name_at(state.symbol, next_step);
let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx));
match state.next {
Some(value) => {
let body = self.expr_in_step(smt_ctx, value, prev_step);
smt_ctx.define_const(escape_smt_identifier(&name), out, body)?;
}
None => {
smt_ctx.declare_const(escape_smt_identifier(&name), out)?;
}
}
}

// declare the inputs and all signals derived from the new state and inputs
self.define_inputs_and_signals(smt_ctx, next_step)?;

// update step count
self.current_step = Some(next_step);
Ok(())
}

fn get_constraint(&self, e: ExprRef) -> smt::SExpr {
todo!()
fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr {
self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap())
}

fn get_bad_state(&self, e: ExprRef) -> smt::SExpr {
todo!()
fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr {
self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap())
}

fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr {
todo!()
fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr {
assert!(sym.is_symbol(self.ctx));
assert!(k <= self.current_step.unwrap_or(0));
let name = self.name_at(sym, k);
smt_ctx.atom(escape_smt_identifier(&name))
}
}

Expand Down Expand Up @@ -384,10 +492,10 @@ fn is_simple_smt_identifier(id: &str) -> bool {

fn escape_smt_identifier(id: &str) -> std::borrow::Cow<'_, str> {
if is_simple_smt_identifier(id) {
std::borrow::Cow::Borrowed(id)
Cow::Borrowed(id)
} else {
let escaped = format!("|{}|", id);
std::borrow::Cow::Owned(escaped)
Cow::Owned(escaped)
}
}

Expand Down

0 comments on commit 9ec4cd3

Please sign in to comment.