-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
model checking: implement enough to get to witness parsing
- Loading branch information
Showing
2 changed files
with
153 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
// author: Kevin Laeufer <[email protected]> | ||
|
||
use crate::ir; | ||
use crate::ir::{Context, Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck}; | ||
use crate::ir::{Expr, ExprRef, GetNode, SignalKind, Type, TypeCheck}; | ||
use std::borrow::Cow; | ||
|
||
use crate::ir::SignalKind::Input; | ||
|
@@ -76,7 +76,8 @@ impl SmtModelChecker { | |
for k in 0..=k_max { | ||
// assume all constraints hold in this step | ||
for (expr_ref, _) in constraints.iter() { | ||
smt_ctx.assert(enc.get_constraint(*expr_ref))?; | ||
let expr = enc.get_constraint(&mut smt_ctx, *expr_ref); | ||
smt_ctx.assert(expr)?; | ||
} | ||
|
||
// make sure the constraints are not contradictory | ||
|
@@ -92,19 +93,20 @@ impl SmtModelChecker { | |
|
||
if self.opts.check_bad_states_individually { | ||
for (_bs_id, (expr_ref, _)) in bad_states.iter().enumerate() { | ||
let res = smt_ctx.check_assuming([enc.get_bad_state(*expr_ref)])?; | ||
let expr = enc.get_bad_state(&mut smt_ctx, *expr_ref); | ||
let res = smt_ctx.check_assuming([expr])?; | ||
|
||
if res == smt::Response::Sat { | ||
let wit = self.get_witness(sys, &mut smt_ctx, &enc, k)?; | ||
return Ok(ModelCheckResult::Fail(wit)); | ||
} | ||
} | ||
} else { | ||
let any_bad = smt_ctx.or_many( | ||
bad_states | ||
.iter() | ||
.map(|(expr_ref, _)| enc.get_bad_state(*expr_ref)), | ||
); | ||
let all_bads = bad_states | ||
.iter() | ||
.map(|(expr_ref, _)| enc.get_bad_state(&mut smt_ctx, *expr_ref)) | ||
.collect::<Vec<_>>(); | ||
let any_bad = smt_ctx.or_many(all_bads); | ||
let res = smt_ctx.check_assuming([any_bad])?; | ||
|
||
if res == smt::Response::Sat { | ||
|
@@ -131,7 +133,7 @@ impl SmtModelChecker { | |
let mut wit = Witness::default(); | ||
// collect initial values | ||
for (state_idx, state) in sys.states().enumerate() { | ||
let sym_at = enc.get_signal_at(state.symbol, 0); | ||
let sym_at = enc.get_symbol_at(smt_ctx, state.symbol, 0); | ||
let value = get_smt_value(smt_ctx, sym_at)?; | ||
wit.init.insert(state_idx, value); | ||
} | ||
|
@@ -141,7 +143,7 @@ impl SmtModelChecker { | |
for k in 0..=k_max { | ||
let mut input_values = State::default(); | ||
for (input_idx, (input, _)) in inputs.iter().enumerate() { | ||
let sym_at = enc.get_signal_at(*input, k); | ||
let sym_at = enc.get_symbol_at(smt_ctx, *input, k); | ||
let value = get_smt_value(smt_ctx, sym_at)?; | ||
input_values.insert(input_idx, value); | ||
} | ||
|
@@ -154,7 +156,7 @@ impl SmtModelChecker { | |
|
||
fn get_smt_value(smt_ctx: &mut smt::Context, expr: smt::SExpr) -> Result<Value> { | ||
let smt_value = smt_ctx.get_value(vec![expr])?[0].1; | ||
todo!("Convert: {:?}", smt_value) | ||
todo!("Convert: {:?}", smt_ctx.display(smt_value).to_string()) | ||
} | ||
|
||
pub enum ModelCheckResult { | ||
|
@@ -165,28 +167,40 @@ pub enum ModelCheckResult { | |
pub trait TransitionSystemEncoding { | ||
fn define_header(&self, smt_ctx: &mut smt::Context) -> Result<()>; | ||
fn init(&mut self, smt_ctx: &mut smt::Context) -> Result<()>; | ||
fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()>; | ||
fn get_constraint(&self, e: ExprRef) -> smt::SExpr; | ||
fn get_bad_state(&self, e: ExprRef) -> smt::SExpr; | ||
fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr; | ||
fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()>; | ||
fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr; | ||
fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr; | ||
fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr; | ||
} | ||
|
||
pub struct UnrollSmtEncoding<'a> { | ||
ctx: &'a ir::Context, | ||
sys: &'a ir::TransitionSystem, | ||
current_step: Option<u64>, | ||
inputs: Vec<(ExprRef, ir::SignalInfo)>, | ||
/// constraint and bad state signals (for now) | ||
signals: Vec<(ExprRef, String)>, | ||
} | ||
|
||
impl<'a> UnrollSmtEncoding<'a> { | ||
pub fn new(ctx: &'a ir::Context, sys: &'a ir::TransitionSystem) -> Self { | ||
// remember inputs instead of constantly re-filtering them | ||
let inputs = sys.get_signals(|s| s.kind == Input); | ||
// name all constraints and bad states | ||
let mut signals = Vec::new(); | ||
for (ii, (expr, _)) in sys.constraints().iter().enumerate() { | ||
signals.push((*expr, format!("__constraint_{ii}"))); | ||
} | ||
for (ii, (expr, _)) in sys.bad_states().iter().enumerate() { | ||
signals.push((*expr, format!("__bad_{ii}"))); | ||
} | ||
|
||
Self { | ||
ctx, | ||
sys, | ||
current_step: None, | ||
inputs, | ||
signals, | ||
} | ||
} | ||
|
||
|
@@ -199,25 +213,52 @@ impl<'a> UnrollSmtEncoding<'a> { | |
} | ||
|
||
// define signals | ||
todo!(); | ||
for (expr, name) in self.signals.iter() { | ||
let name = format!("{}@{}", name, step); | ||
let tpe = convert_tpe(smt_ctx, expr.get_type(self.ctx)); | ||
let value = self.expr_in_step(smt_ctx, *expr, step); | ||
smt_ctx.define_const(escape_smt_identifier(&name), tpe, value)?; | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
fn expr_in_step( | ||
fn get_local_expr_symbol_at( | ||
&self, | ||
smt_ctx: &mut smt::Context, | ||
ctx: &Context, | ||
expr: ExprRef, | ||
step: u64, | ||
e: ExprRef, | ||
k: u64, | ||
) -> smt::SExpr { | ||
// find our local name | ||
let base_name: &str = self | ||
.signals | ||
.iter() | ||
.find(|(id, _)| *id == e) | ||
.map(|(_, n)| n) | ||
.unwrap(); | ||
let name = format!("{}@{}", base_name, k); | ||
smt_ctx.atom(escape_smt_identifier(&name)) | ||
} | ||
|
||
fn expr_in_step(&self, smt_ctx: &mut smt::Context, expr: ExprRef, step: u64) -> smt::SExpr { | ||
let rename = |name: &str| -> Cow<'_, str> { Cow::Owned(format!("{}@{}", name, step)) }; | ||
convert_expr(smt_ctx, ctx, expr, rename) | ||
let general_rename = generalize_lifetime(rename); | ||
convert_expr(smt_ctx, self.ctx, expr, &general_rename) | ||
} | ||
|
||
fn name_at(&self, sym: ExprRef, step: u64) -> String { | ||
format!("{}@{}", sym.get_symbol_name(self.ctx).unwrap(), step) | ||
} | ||
} | ||
|
||
// stack overflow hack: https://stackoverflow.com/questions/70597152/creating-an-alias-for-a-fn-trait-results-in-one-type-is-more-general-than-the-o | ||
fn generalize_lifetime<'a, F>(f: F) -> F | ||
where | ||
F: Fn(&'a str) -> Cow<'a, str>, | ||
{ | ||
f | ||
} | ||
|
||
fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr { | ||
match tpe { | ||
Type::BV(1) => smt_ctx.bool_sort(), | ||
|
@@ -230,18 +271,20 @@ fn convert_tpe(smt_ctx: &mut smt::Context, tpe: Type) -> smt::SExpr { | |
} | ||
} | ||
|
||
fn convert_expr<'a, 'f>( | ||
smt_ctx: &'a mut smt::Context, | ||
fn convert_expr<'a, 'f, F>( | ||
smt_ctx: &smt::Context, | ||
ctx: &'a ir::Context, | ||
expr: ExprRef, | ||
rename_sym: impl Fn(&'f str) -> Cow<'f, str>, | ||
rename_sym: &F, | ||
) -> smt::SExpr | ||
where | ||
F: Fn(&'f str) -> Cow<'f, str>, | ||
'a: 'f, | ||
{ | ||
match ctx.get(expr) { | ||
Expr::BVSymbol { name, .. } => { | ||
let renamed = (rename_sym)(ctx.get(name)); | ||
let name_str = ctx.get(name); | ||
let renamed = (rename_sym)(name_str); | ||
smt_ctx.atom(escape_smt_identifier(&renamed)) | ||
} | ||
Expr::BVLiteral { value, width } if *width == 1 => { | ||
|
@@ -252,19 +295,48 @@ where | |
} | ||
} | ||
Expr::BVLiteral { value, width } => smt_ctx.binary(*width as usize, *value), | ||
Expr::BVZeroExt { .. } => todo!(), | ||
Expr::BVZeroExt { e, by, .. } => { | ||
let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym); | ||
// TODO: add support to easy_smt | ||
smt_ctx.list(vec![ | ||
smt_ctx.list(vec![ | ||
smt_ctx.atoms().und, | ||
smt_ctx.atom("zero_extend"), | ||
smt_ctx.numeral(*by as usize), | ||
]), | ||
e_expr, | ||
]) | ||
} | ||
Expr::BVSignExt { .. } => todo!(), | ||
Expr::BVSlice { .. } => todo!(), | ||
Expr::BVNot(_, _) => todo!(), | ||
Expr::BVSlice { e, hi, lo } => { | ||
let e_expr = convert_expr(smt_ctx, ctx, *e, rename_sym); | ||
smt_ctx.extract(*hi as i32, *lo as i32, e_expr) | ||
} | ||
Expr::BVNot(e_ref, _) => { | ||
let e = convert_expr(smt_ctx, ctx, *e_ref, rename_sym); | ||
smt_ctx.not(e) | ||
} | ||
Expr::BVNegate(_, _) => todo!(), | ||
Expr::BVReduceOr(_) => todo!(), | ||
Expr::BVReduceAnd(_) => todo!(), | ||
Expr::BVReduceXor(_) => todo!(), | ||
Expr::BVEqual(_, _) => todo!(), | ||
Expr::BVImplies(_, _) => todo!(), | ||
Expr::BVGreater(_, _) => todo!(), | ||
Expr::BVImplies(a_ref, b_ref) => { | ||
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); | ||
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); | ||
smt_ctx.imp(a, b) | ||
} | ||
Expr::BVGreater(a_ref, b_ref) => { | ||
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); | ||
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); | ||
smt_ctx.bvugt(a, b) | ||
} | ||
Expr::BVGreaterSigned(_, _) => todo!(), | ||
Expr::BVGreaterEqual(_, _) => todo!(), | ||
Expr::BVGreaterEqual(a_ref, b_ref) => { | ||
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); | ||
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); | ||
smt_ctx.bvuge(a, b) | ||
} | ||
Expr::BVGreaterEqualSigned(_, _) => todo!(), | ||
Expr::BVConcat(_, _, _) => todo!(), | ||
Expr::BVAnd(_, _, _) => todo!(), | ||
|
@@ -273,7 +345,11 @@ where | |
Expr::BVShiftLeft(_, _, _) => todo!(), | ||
Expr::BVArithmeticShiftRight(_, _, _) => todo!(), | ||
Expr::BVShiftRight(_, _, _) => todo!(), | ||
Expr::BVAdd(_, _, _) => todo!(), | ||
Expr::BVAdd(a_ref, b_ref, _) => { | ||
let a = convert_expr(smt_ctx, ctx, *a_ref, rename_sym); | ||
let b = convert_expr(smt_ctx, ctx, *b_ref, rename_sym); | ||
smt_ctx.bvadd(a, b) | ||
} | ||
Expr::BVMul(_, _, _) => todo!(), | ||
Expr::BVSignedDiv(_, _, _) => todo!(), | ||
Expr::BVUnsignedDiv(_, _, _) => todo!(), | ||
|
@@ -282,7 +358,12 @@ where | |
Expr::BVUnsignedRem(_, _, _) => todo!(), | ||
Expr::BVSub(_, _, _) => todo!(), | ||
Expr::BVArrayRead { .. } => todo!(), | ||
Expr::BVIte { .. } => todo!(), | ||
Expr::BVIte { cond, tru, fals } => { | ||
let c = convert_expr(smt_ctx, ctx, *cond, rename_sym); | ||
let t = convert_expr(smt_ctx, ctx, *tru, rename_sym); | ||
let f = convert_expr(smt_ctx, ctx, *fals, rename_sym); | ||
smt_ctx.ite(c, t, f) | ||
} | ||
Expr::ArraySymbol { name, .. } => { | ||
let renamed = (rename_sym)(ctx.get(name)); | ||
smt_ctx.atom(escape_smt_identifier(&renamed)) | ||
|
@@ -309,33 +390,60 @@ impl<'a> TransitionSystemEncoding for UnrollSmtEncoding<'a> { | |
let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx)); | ||
match state.init { | ||
Some(value) => { | ||
let body = self.expr_in_step(smt_ctx, self.ctx, value, 0); | ||
let body = self.expr_in_step(smt_ctx, value, 0); | ||
smt_ctx.define_const(escape_smt_identifier(&name), out, body)?; | ||
} | ||
None => { | ||
smt_ctx.declare_const(escape_smt_identifier(&name), out)?; | ||
} | ||
} | ||
} | ||
|
||
// define the inputs for the initial state and all signals derived from it | ||
self.define_inputs_and_signals(smt_ctx, 0)?; | ||
Ok(()) | ||
} | ||
|
||
fn unroll(&self, smt_ctx: &mut smt::Context) -> Result<()> { | ||
todo!() | ||
fn unroll(&mut self, smt_ctx: &mut smt::Context) -> Result<()> { | ||
let prev_step = self.current_step.unwrap(); | ||
let next_step = prev_step + 1; | ||
|
||
// define next state | ||
for state in self.sys.states() { | ||
let name = self.name_at(state.symbol, next_step); | ||
let out = convert_tpe(smt_ctx, state.symbol.get_type(self.ctx)); | ||
match state.next { | ||
Some(value) => { | ||
let body = self.expr_in_step(smt_ctx, value, prev_step); | ||
smt_ctx.define_const(escape_smt_identifier(&name), out, body)?; | ||
} | ||
None => { | ||
smt_ctx.declare_const(escape_smt_identifier(&name), out)?; | ||
} | ||
} | ||
} | ||
|
||
// declare the inputs and all signals derived from the new state and inputs | ||
self.define_inputs_and_signals(smt_ctx, next_step)?; | ||
|
||
// update step count | ||
self.current_step = Some(next_step); | ||
Ok(()) | ||
} | ||
|
||
fn get_constraint(&self, e: ExprRef) -> smt::SExpr { | ||
todo!() | ||
fn get_constraint(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr { | ||
self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap()) | ||
} | ||
|
||
fn get_bad_state(&self, e: ExprRef) -> smt::SExpr { | ||
todo!() | ||
fn get_bad_state(&self, smt_ctx: &mut smt::Context, e: ExprRef) -> smt::SExpr { | ||
self.get_local_expr_symbol_at(smt_ctx, e, self.current_step.unwrap()) | ||
} | ||
|
||
fn get_signal_at(&self, sym: ExprRef, k: u64) -> smt::SExpr { | ||
todo!() | ||
fn get_symbol_at(&self, smt_ctx: &mut smt::Context, sym: ExprRef, k: u64) -> smt::SExpr { | ||
assert!(sym.is_symbol(self.ctx)); | ||
assert!(k <= self.current_step.unwrap_or(0)); | ||
let name = self.name_at(sym, k); | ||
smt_ctx.atom(escape_smt_identifier(&name)) | ||
} | ||
} | ||
|
||
|
@@ -384,10 +492,10 @@ fn is_simple_smt_identifier(id: &str) -> bool { | |
|
||
fn escape_smt_identifier(id: &str) -> std::borrow::Cow<'_, str> { | ||
if is_simple_smt_identifier(id) { | ||
std::borrow::Cow::Borrowed(id) | ||
Cow::Borrowed(id) | ||
} else { | ||
let escaped = format!("|{}|", id); | ||
std::borrow::Cow::Owned(escaped) | ||
Cow::Owned(escaped) | ||
} | ||
} | ||
|
||
|