Skip to content

Commit

Permalink
Order query expression results (microsoft#32)
Browse files Browse the repository at this point in the history
Expressions are scheduled based on dependencies and thusthe gathered
expression values may not be in the same order as in source.

Reorder to match the source.

Signed-off-by: Anand Krishnamoorthi <[email protected]>
  • Loading branch information
anakrish authored Oct 31, 2023
1 parent 519cce5 commit 6228eaa
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 35 deletions.
21 changes: 5 additions & 16 deletions examples/regorus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use clap::{Parser, Subcommand};
fn rego_eval(
files: &[String],
input: Option<String>,
query: Option<String>,
query: String,
enable_tracing: bool,
) -> Result<()> {
// User specified data.
Expand Down Expand Up @@ -83,16 +83,10 @@ fn rego_eval(
// Evaluate all the modules.
interpreter.eval(&Some(data), &input, false, Some(schedule))?;

// Fetch query string. If none specified, use "data".
let query = match &query {
Some(query) => query,
_ => "data",
};

// Parse the query.
let query_source = regorus::Source {
file: "<query.rego>",
contents: query,
contents: &query,
lines: query.split('\n').collect(),
};
let query_span = regorus::Span {
Expand All @@ -107,7 +101,7 @@ fn rego_eval(
let query_schedule = regorus::Analyzer::new().analyze_query_snippet(&modules, &query_node)?;

let results = interpreter.eval_user_query(&query_node, &query_schedule, enable_tracing)?;
println!("eval results:\n{}", serde_json::to_string_pretty(&results)?);
println!("{}", serde_json::to_string_pretty(&results)?);

Ok(())
}
Expand Down Expand Up @@ -168,20 +162,15 @@ enum RegorusCommand {
/// Evaluate a Rego Query.
Eval {
/// Policy or data files. Rego, json or yaml.
#[arg(
required(true),
long,
short,
value_name = "policy.rego|data.json|data.yaml"
)]
#[arg(long, short, value_name = "policy.rego|data.json|data.yaml")]
data: Vec<String>,

/// Input file. json or yaml.
#[arg(long, short, value_name = "input.rego")]
input: Option<String>,

/// Query. Rego query block.
query: Option<String>,
query: String,

/// Enable tracing.
#[arg(long, short)]
Expand Down
76 changes: 59 additions & 17 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl Default for QueryResult {

#[derive(Debug, Clone, Default, Serialize)]
pub struct QueryResults {
pub results: Vec<QueryResult>,
pub result: Vec<QueryResult>,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -782,6 +782,27 @@ impl<'source> Interpreter<'source> {
Ok(count > 0)
}

fn make_expression_result(span: &Span, v: &Value) -> Value {
let mut loc = BTreeMap::new();
loc.insert(
Value::String("row".to_string()),
Value::from_float(span.line as f64),
);
loc.insert(
Value::String("col".to_string()),
Value::from_float(span.col as f64),
);

let mut expr = BTreeMap::new();
expr.insert(Value::String("value".to_string()), v.clone());
expr.insert(Value::String("location".to_string()), Value::from_map(loc));
expr.insert(
Value::String("text".to_string()),
Value::String(span.text().to_string()),
);
Value::from_map(expr)
}

fn eval_stmt(
&mut self,
stmt: &'source LiteralStmt<'source>,
Expand Down Expand Up @@ -820,7 +841,7 @@ impl<'source> Interpreter<'source> {
}

let r = Ok(match &stmt.literal {
Literal::Expr { expr, .. } => {
Literal::Expr { span, expr, .. } => {
let value = match expr {
Expr::Call { span, fcn, params } => self.eval_call(
span,
Expand All @@ -834,7 +855,9 @@ impl<'source> Interpreter<'source> {

if let Some(ctx) = self.contexts.last_mut() {
if let Some(result) = &mut ctx.result {
result.expressions.push(value.clone());
result
.expressions
.push(Self::make_expression_result(span, &value))
}
}

Expand All @@ -847,7 +870,7 @@ impl<'source> Interpreter<'source> {
value != Value::Undefined
}
}
Literal::NotExpr { expr, .. } => {
Literal::NotExpr { span, expr, .. } => {
let value = match expr {
// Extra parameter is allowed; but a return argument is not allowed.
Expr::Call { span, fcn, params } => self.eval_call(
Expand All @@ -861,13 +884,15 @@ impl<'source> Interpreter<'source> {
};
if let Some(ctx) = self.contexts.last_mut() {
if let Some(result) = &mut ctx.result {
result.expressions.push(Value::Bool(true));
result
.expressions
.push(Self::make_expression_result(span, &Value::Bool(true)))
}
}
// https://github.com/open-policy-agent/opa/issues/1622#issuecomment-520547385
matches!(value, Value::Bool(false) | Value::Undefined)
}
Literal::SomeVars { vars, .. } => {
Literal::SomeVars { span, vars, .. } => {
for var in vars {
let name = var.text();
if let Ok(variable) = self.add_variable_or(name) {
Expand All @@ -881,7 +906,9 @@ impl<'source> Interpreter<'source> {
}
if let Some(ctx) = self.contexts.last_mut() {
if let Some(result) = &mut ctx.result {
result.expressions.push(Value::Bool(true));
result
.expressions
.push(Self::make_expression_result(span, &Value::Bool(true)))
}
}
true
Expand All @@ -894,7 +921,9 @@ impl<'source> Interpreter<'source> {
} => {
if let Some(ctx) = self.contexts.last_mut() {
if let Some(result) = &mut ctx.result {
result.expressions.push(Value::Bool(true));
result
.expressions
.push(Self::make_expression_result(span, &Value::Bool(true)))
}
}
self.eval_some_in(span, key, value, collection, stmts)?
Expand All @@ -908,7 +937,9 @@ impl<'source> Interpreter<'source> {
} => {
if let Some(ctx) = self.contexts.last_mut() {
if let Some(result) = &mut ctx.result {
result.expressions.push(Value::Bool(true));
result
.expressions
.push(Self::make_expression_result(span, &Value::Bool(true)))
}
}
self.eval_every(span, key, value, domain, query)?
Expand Down Expand Up @@ -1094,7 +1125,7 @@ impl<'source> Interpreter<'source> {
.insert(Value::String(name.to_string()), value.clone());
}
}
ctx.results.results.push(result);
ctx.results.result.push(result);
}

return Ok(true);
Expand Down Expand Up @@ -2300,22 +2331,33 @@ impl<'source> Interpreter<'source> {
// Eval the query.
let query_r = self.eval_query(query);

let mut results = match self.contexts.pop() {
Some(ctx) => ctx.results,
_ => bail!("internal error: no context"),
};

// Restore schedules.
if let Some(self_schedule) = &mut self.schedule {
for (k, _) in schedule.order.iter() {
for (k, ord) in schedule.order.iter() {
if k == &query {
for idx in 0..results.result.len() {
let mut ordered_expressions = vec![Value::Undefined; ord.len()];
for (expr_idx, value) in results.result[idx].expressions.iter().enumerate()
{
let orig_idx = ord[expr_idx] as usize;
ordered_expressions[orig_idx] = value.clone();
}
results.result[idx].expressions = ordered_expressions;
}
}
self_schedule.order.remove(k);
}
}

self.set_current_module(prev_module)?;

let r = match self.contexts.pop() {
Some(ctx) => Ok(ctx.results),
_ => bail!("internal error: no context"),
};

match query_r {
Ok(_) => r,
Ok(_) => Ok(results),
Err(e) => Err(e),
}
}
Expand Down
4 changes: 2 additions & 2 deletions tests/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,12 @@ pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> R
}

fn query_results_to_value(query_results: QueryResults) -> Result<Value> {
if let Some(query_result) = query_results.results.last() {
if let Some(query_result) = query_results.result.last() {
if !query_result.bindings.is_empty_object() {
return Ok(query_result.bindings.clone());
} else {
return match query_result.expressions.last() {
Some(v) => Ok(v.clone()),
Some(v) => Ok(v["value"].clone()),
_ => bail!("no expressions in query results"),
};
}
Expand Down

0 comments on commit 6228eaa

Please sign in to comment.