Skip to content

Commit

Permalink
Special cases of refs to data (microsoft#41)
Browse files Browse the repository at this point in the history
1. Numeric indices will be converted to strings for refs beginning with `data`
   if there is no matching numeric key
2. Error out if the input data document already contains a value for a ref

Signed-off-by: Anand Krishnamoorthi <[email protected]>
  • Loading branch information
anakrish authored Nov 11, 2023
1 parent 02c6c6b commit 1c49214
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 87 deletions.
146 changes: 76 additions & 70 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,18 @@ impl<'source> Interpreter<'source> {
path.reverse();
let obj = self.eval_expr(refr)?;
let index = self.eval_expr(index)?;
return Ok(Self::get_value_chained(obj[&index].clone(), &path[..]));
let mut v = obj[&index].clone();
// Qualified references starting with data (e.g. data.p.q) can
// be indexed using numbers. If the numeric lookup is undefined, the
// number is converted to a string and the lookup is retried.
if v == Value::Undefined
&& matches!(index, Value::Number(_))
&& get_root_var(refr)? == "data"
{
let index = index.to_string();
v = obj[&index].clone();
}
return Ok(Self::get_value_chained(v, &path[..]));
}
},
_ => {
Expand Down Expand Up @@ -1798,7 +1809,6 @@ impl<'source> Interpreter<'source> {
}

fn make_rule_context(&self, head: &'source RuleHead) -> Result<(Context<'source>, Vec<Span>)> {
//TODO: include "data" ?
let mut path = Parser::get_path_ref_components(&self.module.unwrap().package.refr)?;

match head {
Expand Down Expand Up @@ -2129,6 +2139,25 @@ impl<'source> Interpreter<'source> {
Ok(())
}

fn update_data(
&mut self,
span: &Span,
_refr: &Expr,
path: &[&str],
value: Value,
) -> Result<()> {
if value == Value::Undefined {
return Ok(());
}
// Ensure that path is created.
let vref = Self::make_or_get_value_mut(&mut self.data, path)?;
if Self::get_value_chained(self.init_data.clone(), path) == Value::Undefined {
Self::merge_value(span, vref, value)
} else {
Err(span.error("value for rule has already been specified in data document"))
}
}

fn eval_rule(&mut self, module: &'source Module, rule: &'source Rule) -> Result<()> {
// Skip reprocessing rule
if self.processed.contains(&Ref::make(rule)) {
Expand Down Expand Up @@ -2175,41 +2204,42 @@ impl<'source> Interpreter<'source> {
head: rule_head,
bodies: rule_body,
} => {
if !matches!(rule_head, RuleHead::Func { .. }) {
let (ctx, mut path) = self.make_rule_context(rule_head)?;
let special_set =
matches!((ctx.output_expr, &ctx.value), (None, Value::Set(_)));
let value = match self.eval_rule_bodies(ctx, span, rule_body)? {
Value::Set(_) if special_set => {
let entry = path[path.len() - 1].text();
let mut s = BTreeSet::new();
s.insert(Value::String(entry.to_owned().to_string()));
path = path[0..path.len() - 1].to_vec();
Value::from_set(s)
}
v => v,
};
if value != Value::Undefined {
match rule_head {
RuleHead::Compr { refr, .. } | RuleHead::Set { refr, .. } => {
let (ctx, mut path) = self.make_rule_context(rule_head)?;
let special_set =
matches!((ctx.output_expr, &ctx.value), (None, Value::Set(_)));
let value = match self.eval_rule_bodies(ctx, span, rule_body)? {
Value::Set(_) if special_set => {
let entry = path[path.len() - 1].text();
let mut s = BTreeSet::new();
s.insert(Value::String(entry.to_owned().to_string()));
path = path[0..path.len() - 1].to_vec();
Value::from_set(s)
}
v => v,
};
let paths: Vec<&str> = path.iter().map(|s| *s.text()).collect();
let vref = Self::make_or_get_value_mut(&mut self.data, &paths[..])?;
Self::merge_value(span, vref, value)?;
}
self.update_data(span, refr, &paths[..], value)?;

self.processed.insert(Ref::make(rule));
} else if let RuleHead::Func { refr, .. } = rule_head {
let mut path =
Parser::get_path_ref_components(&self.current_module()?.package.refr)?;

Parser::get_path_ref_components_into(refr, &mut path)?;
let path: Vec<&str> = path.iter().map(|s| *s.text()).collect();

// Ensure that for functions with a nesting level (e.g: a.foo),
// `a` is created as an empty object.
if path.len() > 1 {
let value =
Self::make_or_get_value_mut(&mut self.data, &path[0..path.len() - 1])?;
if value == &Value::Undefined {
*value = Value::new_object();
self.processed.insert(Ref::make(rule));
}
RuleHead::Func { refr, .. } => {
let mut path =
Parser::get_path_ref_components(&self.current_module()?.package.refr)?;

Parser::get_path_ref_components_into(refr, &mut path)?;
let path: Vec<&str> = path.iter().map(|s| *s.text()).collect();

// Ensure that for functions with a nesting level (e.g: a.foo),
// `a` is created as an empty object.
if path.len() > 1 {
self.update_data(
span,
refr,
&path[0..path.len() - 1],
Value::new_object(),
)?;
}
}
}
Expand Down Expand Up @@ -2248,56 +2278,32 @@ impl<'source> Interpreter<'source> {

if let Some(data) = data {
self.data = data.clone();
self.init_data = data.clone();
}

// Ensure that each module has an empty object
for m in &self.modules {
let path = Parser::get_path_ref_components(&m.package.refr)?;
let path: Vec<&str> = path.iter().map(|s| *s.text()).collect();
let vref = Self::make_or_get_value_mut(&mut self.data, &path[..])?;
if *vref == Value::Undefined {
*vref = Value::new_object();
}
}

self.check_default_rules()?;
self.functions = gather_functions(&self.modules)?;

self.gather_rules()?;

self.init_data = self.data.clone();
self.prepared = true;

Ok(())
}

pub fn eval_module(
&mut self,
module: &'source Module,
input: &Option<Value>,
enable_tracing: bool,
) -> Result<Value> {
pub fn eval_modules(&mut self, input: &Option<Value>, enable_tracing: bool) -> Result<Value> {
self.checks_for_eval(input, enable_tracing)?;
self.clean_internal_evaluation_state();

for rule in &module.policy {
self.eval_rule(module, rule)?;
}

// Defer the evaluation of the default rules to here
let prev_module = self.set_current_module(Some(module))?;
for rule in &module.policy {
self.eval_default_rule(rule)?;
// Ensure that each module has an empty object
for m in &self.modules {
let path = Parser::get_path_ref_components(&m.package.refr)?;
let path: Vec<&str> = path.iter().map(|s| *s.text()).collect();
let vref = Self::make_or_get_value_mut(&mut self.data, &path[..])?;
if *vref == Value::Undefined {
*vref = Value::new_object();
}
}
self.set_current_module(prev_module)?;

Ok(self.data.clone())
}

pub fn eval_modules(&mut self, input: &Option<Value>, enable_tracing: bool) -> Result<Value> {
self.checks_for_eval(input, enable_tracing)?;
self.clean_internal_evaluation_state();

self.check_default_rules()?;
for module in self.modules.clone() {
for rule in &module.policy {
self.eval_rule(module, rule)?;
Expand Down
11 changes: 1 addition & 10 deletions src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,6 @@ fn gather_vars<'a>(
gather_loop_vars(expr, parent_scopes, scope)
}

/// Returns the text of the variable at the root of a reference expression.
///
/// `RefDot` and `RefBrack` nodes are followed through their `refr` operand
/// until an `Expr::Var` is reached; any other expression kind is reported as
/// an internal error.
fn get_rule_prefix(expr: &Expr) -> Result<&str> {
    let mut node = expr;
    loop {
        match node {
            Expr::Var(var) => return Ok(*var.text()),
            // Step to the referenced expression and keep walking.
            Expr::RefDot { refr, .. } | Expr::RefBrack { refr, .. } => node = refr,
            _ => bail!("internal error: analyzer: could not get rule prefix"),
        }
    }
}

pub struct Analyzer<'a> {
packages: BTreeMap<String, Scope<'a>>,
locals: BTreeMap<Ref<'a, Query>, Scope<'a>>,
Expand Down Expand Up @@ -447,7 +438,7 @@ impl<'a> Analyzer<'a> {
| RuleHead::Set { refr, .. }
| RuleHead::Func { refr, .. },
..
} => get_rule_prefix(refr)?,
} => get_root_var(refr)?,
};
scope.locals.insert(var);
}
Expand Down
18 changes: 13 additions & 5 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@ pub fn get_path_string(refr: &Expr, document: Option<&str>) -> Result<String> {
comps.push(&field.text());
expr = Some(refr);
}
Some(Expr::RefBrack { refr, index, .. })
if matches!(index.as_ref(), Expr::String(_)) =>
{
Some(Expr::RefBrack { refr, index, .. }) => {
if let Expr::String(s) = index.as_ref() {
comps.push(&s.text());
expr = Some(refr);
}
expr = Some(refr);
}
Some(Expr::Var(v)) => {
comps.push(&v.text());
expr = None;
}
_ => bail!("internal error: not a simple ref"),
_ => bail!("internal error: not a simple ref {expr:?}"),
}
}
if let Some(d) = document {
Expand Down Expand Up @@ -92,3 +90,13 @@ pub fn gather_functions<'a>(modules: &[&'a Module]) -> Result<FunctionTable<'a>>
}
Ok(table)
}

pub fn get_root_var(mut expr: &Expr) -> Result<&str> {
loop {
match expr {
Expr::Var(v) => return Ok(*v.text()),
Expr::RefDot { refr, .. } | Expr::RefBrack { refr, .. } => expr = refr,
_ => bail!("internal error: analyzer: could not get rule prefix"),
}
}
}
5 changes: 4 additions & 1 deletion src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ impl Value {
}

pub fn merge(&mut self, mut new: Value) -> Result<()> {
if self == &new {
return Ok(());
}
match (self, &mut new) {
(v @ Value::Undefined, _) => *v = new,
(Value::Set(ref mut set), Value::Set(new)) => {
Expand All @@ -351,7 +354,7 @@ impl Value {
};
}
}
_ => bail!("internal error: could not merge value"),
_ => bail!("error: could not merge value"),
};
Ok(())
}
Expand Down
69 changes: 69 additions & 0 deletions tests/interpreter/cases/data/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

cases:
- note: numbers are converted to string as needed when indexing data
data:
"1": "hello"
"1.2": "world"
test:
"2": 100
"2.2": 200
play:
"2": 100
modules:
- |
package test
p = v {
q = data.play
# 2 is not converted to "2" since refr doesn't begin with `data`
v = q[2]
}
a = [
data[1],
data[1.2],
data.test[2],
data.test[2.2]
]
query: data.test
want_result:
"2": 100
"2.2": 200
a:
- "hello"
- "world"
- 100
- 200

- note: overriding refs in data produces error
data:
test:
rule1: 0
modules:
- |
package test
rule1 = 6
query: data.test
error: value for rule has already been specified

- note: rule named data
data:
test:
rule1: 8
modules:
- |
package test
data.test.rule1 = 9
data.test.rule1 = 9
query: data.test
want_result:
rule1: 8
data:
test:
rule1: 9


2 changes: 1 addition & 1 deletion tests/opa/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn run_opa_tests() -> Result<()> {
} else {
for (i, m) in modules.iter().enumerate() {
std::fs::write(
path.join(format!("rego{n}_{i}.json")),
path.join(format!("rego{n}_{i}.rego")),
m.as_bytes(),
)?;
}
Expand Down

0 comments on commit 1c49214

Please sign in to comment.