Skip to content

Commit

Permalink
arc feature to enable using Engine and other data structures from m…
Browse files Browse the repository at this point in the history
…ultiple threads (microsoft#142)

* `arc` feature to make engine usable from multiple threads.

`arc` is turned on by default. When enabled, std::sync::Arc
will be used instead of std::rc::Rc. The former makes regorus
types like Engine, Value, ast nodes etc Send, allowing for
usability from multiple threads.
Arc would add a performance overhead though since the reference
counting will now become atomic.

Signed-off-by: Anand Krishnamoorthi <[email protected]>

* Make engine and related types Debug

Signed-off-by: Anand Krishnamoorthi <[email protected]>

* Input, Data as json. Evaluate bool queries.

Signed-off-by: Anand Krishnamoorthi <[email protected]>

---------

Signed-off-by: Anand Krishnamoorthi <[email protected]>
  • Loading branch information
anakrish authored Feb 13, 2024
1 parent 13eb06e commit 53b990f
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 24 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ keywords = ["interpreter", "opa", "policy-as-code", "rego"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = ["full-opa"]
default = ["full-opa", "arc"]

arc = ["scientific/arc"]
base64 = ["dep:data-encoding"]
base64url = ["dep:data-encoding"]
crypto = ["dep:constant_time_eq", "dep:hmac", "dep:hex", "dep:md-5", "dep:sha1", "dep:sha2"]
Expand Down
11 changes: 5 additions & 6 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

use crate::lexer::*;
use crate::Rc;

use std::ops::Deref;

Expand Down Expand Up @@ -37,7 +38,7 @@ pub enum AssignOp {
}

pub struct NodeRef<T> {
r: std::rc::Rc<T>,
r: Rc<T>,
}

impl<T> Clone for NodeRef<T> {
Expand All @@ -54,15 +55,15 @@ impl<T: std::fmt::Debug> std::fmt::Debug for NodeRef<T> {

impl<T> std::cmp::PartialEq for NodeRef<T> {
fn eq(&self, other: &Self) -> bool {
std::rc::Rc::as_ptr(&self.r).eq(&std::rc::Rc::as_ptr(&other.r))
Rc::as_ptr(&self.r).eq(&Rc::as_ptr(&other.r))
}
}

impl<T> std::cmp::Eq for NodeRef<T> {}

impl<T> std::cmp::Ord for NodeRef<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
std::rc::Rc::as_ptr(&self.r).cmp(&std::rc::Rc::as_ptr(&other.r))
Rc::as_ptr(&self.r).cmp(&Rc::as_ptr(&other.r))
}
}

Expand All @@ -88,9 +89,7 @@ impl<T> AsRef<T> for NodeRef<T> {

impl<T> NodeRef<T> {
pub fn new(t: T) -> Self {
Self {
r: std::rc::Rc::new(t),
}
Self { r: Rc::new(t) }
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/builtins/arrays.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::ast::{Expr, Ref};
use crate::builtins;
use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_numeric};
use crate::lexer::Span;
use crate::value::{Rc, Value};
use crate::Rc;
use crate::Value;

use std::collections::HashMap;

Expand Down
3 changes: 2 additions & 1 deletion src/builtins/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::ast::{Expr, Ref};
use crate::builtins;
use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_object};
use crate::lexer::Span;
use crate::value::{Rc, Value};
use crate::Rc;
use crate::Value;

use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::iter::Iterator;
Expand Down
3 changes: 2 additions & 1 deletion src/builtins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
use crate::ast::{Expr, Ref};
use crate::lexer::Span;
use crate::number::Number;
use crate::value::{Rc, Value};
use crate::Rc;
use crate::Value;

use std::collections::{BTreeMap, BTreeSet};

Expand Down
19 changes: 18 additions & 1 deletion src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use anyhow::{bail, Result};

/// The Rego evaluation engine.
///
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Engine {
modules: Vec<Ref<Module>>,
interpreter: Interpreter,
Expand Down Expand Up @@ -120,6 +120,11 @@ impl Engine {
self.interpreter.set_input(input);
}

pub fn set_input_json(&mut self, input_json: &str) -> Result<()> {
self.set_input(Value::from_json_str(input_json)?);
Ok(())
}

/// Clear the data document.
///
/// The data document will be reset to an empty object.
Expand Down Expand Up @@ -182,6 +187,10 @@ impl Engine {
self.interpreter.get_data_mut().merge(data)
}

pub fn add_data_json(&mut self, data_json: &str) -> Result<()> {
self.add_data(Value::from_json_str(data_json)?)
}

/// Set whether builtins should raise errors strictly or not.
///
/// Regorus differs from OPA in that by default builtins will
Expand Down Expand Up @@ -256,6 +265,14 @@ impl Engine {
)
}

pub fn eval_bool_query(&mut self, query: String, enable_tracing: bool) -> Result<bool> {
let results = self.eval_query(query, enable_tracing)?;
if results.result.len() != 1 || results.result[0].expressions.len() != 1 {
bail!("query did not produce exactly one value");
}
results.result[0].expressions[0].value.as_bool().copied()
}

#[doc(hidden)]
fn prepare_for_eval(&mut self, enable_tracing: bool) -> Result<()> {
self.interpreter.set_traces(enable_tracing);
Expand Down
9 changes: 5 additions & 4 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::parser::Parser;
use crate::scheduler::*;
use crate::utils::*;
use crate::value::*;
use crate::Rc;
use crate::{Expression, Extension, Location, QueryResult, QueryResults};

use anyhow::{anyhow, bail, Result};
Expand Down Expand Up @@ -37,7 +38,7 @@ enum FunctionModifier {
Value(Value),
}

#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Interpreter {
modules: Vec<Ref<Module>>,
module: Option<Ref<Module>>,
Expand All @@ -64,7 +65,7 @@ pub struct Interpreter {
allow_deprecated: bool,
strict_builtin_errors: bool,
imports: BTreeMap<String, Ref<Expr>>,
extensions: HashMap<String, (u8, Box<dyn Extension>)>,
extensions: HashMap<String, (u8, Rc<Box<dyn Extension>>)>,
}

impl Default for Interpreter {
Expand Down Expand Up @@ -2165,7 +2166,7 @@ impl Interpreter {
if param_values.len() != *nargs as usize {
bail!(span.error("incorrect number of parameters supplied to extension"));
}
let r = ext(param_values);
let r = Rc::make_mut(ext)(param_values);
// Restore with_functions.
if let Some(with_functions) = with_functions_saved {
self.with_functions = with_functions;
Expand Down Expand Up @@ -3458,7 +3459,7 @@ impl Interpreter {
extension: Box<dyn Extension>,
) -> Result<()> {
if let std::collections::hash_map::Entry::Vacant(v) = self.extensions.entry(path) {
v.insert((nargs, extension));
v.insert((nargs, Rc::new(extension)));
Ok(())
} else {
bail!("extension already added");
Expand Down
8 changes: 5 additions & 3 deletions src/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use core::str::CharIndices;
use std::convert::AsRef;
use std::path::Path;

use crate::value::Value;
use crate::Rc;
use crate::Value;

use anyhow::{anyhow, bail, Result};

#[derive(Clone)]
Expand All @@ -20,7 +22,7 @@ struct SourceInternal {

#[derive(Clone)]
pub struct Source {
src: std::rc::Rc<SourceInternal>,
src: Rc<SourceInternal>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -108,7 +110,7 @@ impl Source {
lines.push((s, s));
}
Self {
src: std::rc::Rc::new(SourceInternal {
src: Rc::new(SourceInternal {
file,
contents,
lines,
Expand Down
18 changes: 15 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ mod value;
pub use engine::Engine;
pub use value::Value;

#[cfg(feature = "arc")]
use std::sync::Arc as Rc;

#[cfg(not(feature = "arc"))]
use std::rc::Rc;

/// Location of an [`Expression`] in a Rego query.
///
/// ```
Expand Down Expand Up @@ -68,7 +74,7 @@ pub struct Expression {
pub value: Value,

/// The Rego expression.
pub text: std::rc::Rc<str>,
pub text: Rc<str>,

/// Location of the expression in the query string.
pub location: Location,
Expand Down Expand Up @@ -263,7 +269,7 @@ pub struct QueryResults {
/// A user defined builtin function implementation.
///
/// It is not necessary to implement this trait directly.
pub trait Extension: FnMut(Vec<Value>) -> anyhow::Result<Value> {
pub trait Extension: FnMut(Vec<Value>) -> anyhow::Result<Value> + Send + Sync {
/// Fn, FnMut etc are not sized and cannot be cloned in their boxed form.
/// clone_box exists to overcome that.
fn clone_box<'a>(&self) -> Box<dyn 'a + Extension>
Expand All @@ -274,7 +280,7 @@ pub trait Extension: FnMut(Vec<Value>) -> anyhow::Result<Value> {
/// Automatically make matching closures a valid [`Extension`].
impl<F> Extension for F
where
F: FnMut(Vec<Value>) -> anyhow::Result<Value> + Clone,
F: FnMut(Vec<Value>) -> anyhow::Result<Value> + Clone + Send + Sync,
{
fn clone_box<'a>(&self) -> Box<dyn 'a + Extension>
where
Expand All @@ -291,6 +297,12 @@ impl<'a> Clone for Box<dyn 'a + Extension> {
}
}

impl std::fmt::Debug for dyn Extension {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
f.write_fmt(format_args!("<extension>"))
}
}

/// Items in `unstable` are likely to change.
#[doc(hidden)]
pub mod unstable {
Expand Down
3 changes: 2 additions & 1 deletion src/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

use core::fmt::{Debug, Formatter};
use std::cmp::{Ord, Ordering};
use std::rc::Rc;
use std::str::FromStr;

use anyhow::{anyhow, bail, Result};

use serde::ser::Serializer;
use serde::Serialize;

use crate::Rc;

pub type BigInt = i128;

type BigFloat = scientific::Scientific;
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ pub struct Analyzer {
current_module_path: String,
}

#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Schedule {
pub scopes: BTreeMap<Ref<Query>, Scope>,
pub order: BTreeMap<Ref<Query>, Vec<u16>>,
Expand Down
2 changes: 1 addition & 1 deletion src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::ser::{SerializeMap, Serializer};
use serde::{Deserialize, Serialize};

pub type Rc<T> = compact_rc::Rc16<T>;
use crate::Rc;

/// A value in a Rego document.
///
Expand Down
33 changes: 33 additions & 0 deletions tests/arc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use lazy_static::lazy_static;
use std::sync::Mutex;

use regorus::*;

// Ensure that types can be s
lazy_static! {
static ref VALUE: Value = Value::Null;
static ref ENGINE: Mutex<Engine> = Mutex::new(Engine::new());
// static ref ENGINE: Engine = Engine::new();
}

#[test]
fn shared_engine() -> anyhow::Result<()> {
let e_guard = ENGINE.lock();
let mut engine = e_guard.expect("failed to lock engine");

engine.add_policy(
"hello.rego".to_string(),
r#"
package test
allow = true
"#
.to_string(),
)?;

let results = engine.eval_query("data.test.allow".to_string(), false)?;
assert_eq!(results.result[0].expressions[0].value, Value::from(true));
Ok(())
}
3 changes: 3 additions & 0 deletions tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ mod engine;
mod lexer;
mod parser;
mod value;

#[cfg(feature = "arc")]
mod arc;

0 comments on commit 53b990f

Please sign in to comment.