Skip to content

Commit

Permalink
Scheduling of statements in user queries (microsoft#31)
Browse files Browse the repository at this point in the history
Nested queries (comprehensions) are handled correctly.

Signed-off-by: Anand Krishnamoorthi <[email protected]>
  • Loading branch information
anakrish authored Oct 31, 2023
1 parent 7a3d5e7 commit 519cce5
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 64 deletions.
8 changes: 4 additions & 4 deletions examples/regorus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ fn rego_eval(
let mut interpreter = regorus::Interpreter::new(modules_ref)?;

// Prepare for evalution.
interpreter.prepare_for_eval(Some(&schedule), &Some(data.clone()))?;
interpreter.prepare_for_eval(Some(schedule.clone()), &Some(data.clone()))?;

// Evaluate all the modules.
interpreter.eval(&Some(data), &input, false, Some(&schedule))?;
interpreter.eval(&Some(data), &input, false, Some(schedule))?;

// Fetch query string. If none specified, use "data".
let query = match &query {
Expand All @@ -104,9 +104,9 @@ fn rego_eval(
};
let mut parser = regorus::Parser::new(&query_source)?;
let query_node = parser.parse_query(query_span, "")?;
let stmt_order = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;
let query_schedule = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;

let results = interpreter.eval_user_query(&query_node, &stmt_order, enable_tracing)?;
let results = interpreter.eval_user_query(&query_node, &query_schedule, enable_tracing)?;
println!("eval results:\n{}", serde_json::to_string_pretty(&results)?);

Ok(())
Expand Down
46 changes: 32 additions & 14 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Scope = BTreeMap<String, Value>;
pub struct Interpreter<'source> {
modules: Vec<&'source Module<'source>>,
module: Option<&'source Module<'source>>,
schedule: Option<&'source Schedule<'source>>,
schedule: Option<Schedule<'source>>,
current_module_path: String,
prepared: bool,
input: Value,
Expand Down Expand Up @@ -1218,6 +1218,7 @@ impl<'source> Interpreter<'source> {
} else {
query.stmts.iter().collect()
};

let r = self.eval_stmts(&ordered_stmts);
self.scopes.pop();
r
Expand Down Expand Up @@ -1649,7 +1650,7 @@ impl<'source> Interpreter<'source> {
}
}
Ok(Self::get_value_chained(self.data.clone(), fields))
} else {
} else if !self.modules.is_empty() {
// Add module prefix and ensure that any matching rule is evaluated.
let module_path =
Self::get_path_string(&self.current_module()?.package.refr, Some("data"))?;
Expand All @@ -1666,6 +1667,8 @@ impl<'source> Interpreter<'source> {

let value = Self::get_value_chained(self.data.clone(), &path[..]);
Ok(Self::get_value_chained(value, fields))
} else {
Ok(Value::Undefined)
}
}

Expand Down Expand Up @@ -2178,7 +2181,7 @@ impl<'source> Interpreter<'source> {

pub fn prepare_for_eval(
&mut self,
schedule: Option<&'source Schedule<'source>>,
schedule: Option<Schedule<'source>>,
data: &Option<Value>,
) -> Result<()> {
self.schedule = schedule;
Expand Down Expand Up @@ -2258,7 +2261,7 @@ impl<'source> Interpreter<'source> {
data: &Option<Value>,
input: &Option<Value>,
enable_tracing: bool,
schedule: Option<&'source Schedule<'source>>,
schedule: Option<Schedule<'source>>,
) -> Result<Value> {
self.prepare_for_eval(schedule, data)?;
self.eval_modules(input, enable_tracing)
Expand All @@ -2267,17 +2270,20 @@ impl<'source> Interpreter<'source> {
pub fn eval_user_query(
&mut self,
query: &'source Query<'source>,
order: &[u16],
schedule: &Schedule<'source>,
enable_tracing: bool,
) -> Result<QueryResults> {
self.traces = match enable_tracing {
true => Some(vec![]),
false => None,
};

// Create a new scope for evaluating the expression.
self.scopes.push(Scope::new());
let prev_module = self.set_current_module(self.modules.last().copied())?;
// Add schedules for queries.
if let Some(self_schedule) = &mut self.schedule {
for (k, v) in schedule.order.iter() {
self_schedule.order.insert(k, v.clone());
}
}

// Push new context.
self.contexts.push(Context {
Expand All @@ -2289,16 +2295,28 @@ impl<'source> Interpreter<'source> {
results: QueryResults::default(),
});

let ordered_stmts: Vec<&'source LiteralStmt<'source>> =
order.iter().map(|i| &query.stmts[*i as usize]).collect();
let _value = self.eval_stmts(&ordered_stmts);
let prev_module = self.set_current_module(self.modules.last().copied())?;

// Eval the query.
let query_r = self.eval_query(query);

// Restore schedules.
if let Some(self_schedule) = &mut self.schedule {
for (k, _) in schedule.order.iter() {
self_schedule.order.remove(k);
}
}

// Pop the scope.
let _scope = self.scopes.pop();
self.set_current_module(prev_module)?;
match self.contexts.pop() {

let r = match self.contexts.pop() {
Some(ctx) => Ok(ctx.results),
_ => bail!("internal error: no context"),
};

match query_r {
Ok(_) => r,
Err(e) => Err(e),
}
}

Expand Down
13 changes: 7 additions & 6 deletions src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ pub struct Analyzer<'a> {
order: BTreeMap<&'a Query<'a>, Vec<u16>>,
}

#[derive(Clone)]
pub struct Schedule<'a> {
pub scopes: BTreeMap<&'a Query<'a>, Scope<'a>>,
pub order: BTreeMap<&'a Query<'a>, Vec<u16>>,
Expand Down Expand Up @@ -420,14 +421,14 @@ impl<'a> Analyzer<'a> {
mut self,
modules: &'a [Module<'a>],
query: &'a Query<'a>,
) -> Result<Vec<u16>> {
) -> Result<Schedule<'a>> {
self.add_rules(modules)?;
self.analyze_query(None, None, query, Scope::default())?;
Ok(self
.order
.get(query)
.expect("could not schedule user query")
.clone())

Ok(Schedule {
scopes: self.locals,
order: self.order,
})
}

fn add_rules(&mut self, modules: &'a [Module<'a>]) -> Result<()> {
Expand Down
82 changes: 42 additions & 40 deletions tests/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,21 +208,6 @@ pub fn eval_file_first_rule(
let mut modules = vec![];
let mut modules_ref = vec![];

let query_source = regorus::Source {
file: "<query.rego>",
contents: query,
lines: query.split('\n').collect(),
};
let query_span = regorus::Span {
source: &query_source,
line: 1,
col: 1,
start: 0,
end: query.len() as u16,
};
let mut parser = regorus::Parser::new(&query_source)?;
let query_node = parser.parse_query(query_span, "")?;
let query_stmt_order = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;
for (idx, _) in regos.iter().enumerate() {
files.push(format!("rego_{idx}"));
}
Expand All @@ -245,13 +230,28 @@ pub fn eval_file_first_rule(
modules_ref.push(m);
}

let query_source = regorus::Source {
file: "<query.rego>",
contents: query,
lines: query.split('\n').collect(),
};
let query_span = regorus::Span {
source: &query_source,
line: 1,
col: 1,
start: 0,
end: query.len() as u16,
};
let mut parser = regorus::Parser::new(&query_source)?;
let query_node = parser.parse_query(query_span, "")?;
let query_schedule = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;
let analyzer = Analyzer::new();
let schedule = analyzer.analyze(&modules)?;

let mut interpreter = interpreter::Interpreter::new(modules_ref)?;
if let Some(input) = input_opt {
// if inputs are defined then first the evaluation if prepared
interpreter.prepare_for_eval(Some(&schedule), &data_opt)?;
interpreter.prepare_for_eval(Some(schedule), &data_opt)?;

// then all modules are evaluated for each input
let mut inputs = vec![];
Expand All @@ -270,18 +270,18 @@ pub fn eval_file_first_rule(
// Now eval the query.
results.push(query_results_to_value(interpreter.eval_user_query(
&query_node,
&query_stmt_order,
&query_schedule,
enable_tracing,
)?)?);
}
} else {
// it no input is defined then one evaluation of all modules is performed
interpreter.eval(&data_opt, &None, enable_tracing, Some(&schedule))?;
interpreter.eval(&data_opt, &None, enable_tracing, Some(schedule))?;

// Now eval the query.
results.push(query_results_to_value(interpreter.eval_user_query(
&query_node,
&query_stmt_order,
&query_schedule,
enable_tracing,
)?)?);
}
Expand All @@ -302,22 +302,6 @@ pub fn eval_file(
let mut modules = vec![];
let mut modules_ref = vec![];

let query_source = regorus::Source {
file: "<query.rego>",
contents: query,
lines: query.split('\n').collect(),
};
let query_span = regorus::Span {
source: &query_source,
line: 1,
col: 1,
start: 0,
end: query.len() as u16,
};
let mut parser = regorus::Parser::new(&query_source)?;
let query_node = parser.parse_query(query_span, "")?;
let query_stmt_order = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;

for (idx, _) in regos.iter().enumerate() {
files.push(format!("rego_{idx}"));
}
Expand All @@ -340,13 +324,29 @@ pub fn eval_file(
modules_ref.push(m);
}

let query_source = regorus::Source {
file: "<query.rego>",
contents: query,
lines: query.split('\n').collect(),
};
let query_span = regorus::Span {
source: &query_source,
line: 1,
col: 1,
start: 0,
end: query.len() as u16,
};
let mut parser = regorus::Parser::new(&query_source)?;
let query_node = parser.parse_query(query_span, "")?;
let query_schedule = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;

let analyzer = Analyzer::new();
let schedule = analyzer.analyze(&modules)?;

let mut interpreter = interpreter::Interpreter::new(modules_ref)?;
if let Some(input) = input_opt {
// if inputs are defined then first the evaluation if prepared
interpreter.prepare_for_eval(Some(&schedule), &data_opt)?;
interpreter.prepare_for_eval(Some(schedule), &data_opt)?;

// then all modules are evaluated for each input
let mut inputs = vec![];
Expand All @@ -361,18 +361,18 @@ pub fn eval_file(
// Now eval the query.
results.push(query_results_to_value(interpreter.eval_user_query(
&query_node,
&query_stmt_order,
&query_schedule,
enable_tracing,
)?)?);
}
} else {
// it no input is defined then one evaluation of all modules is performed
interpreter.eval(&data_opt, &None, enable_tracing, Some(&schedule))?;
interpreter.eval(&data_opt, &None, enable_tracing, Some(schedule))?;

// Now eval the query.
results.push(query_results_to_value(interpreter.eval_user_query(
&query_node,
&query_stmt_order,
&query_schedule,
enable_tracing,
)?)?);
}
Expand Down Expand Up @@ -584,7 +584,9 @@ fn run_opa_tests() -> Result<()> {
}

if !failures.is_empty() {
dbg!(failures);
for (f, e) in failures {
println!("{f} failed.\n{e}");
}
panic!("failed");
}
Ok(())
Expand Down

0 comments on commit 519cce5

Please sign in to comment.