diff --git a/src/builtins/objects.rs b/src/builtins/objects.rs index 830150d6..be14206e 100644 --- a/src/builtins/objects.rs +++ b/src/builtins/objects.rs @@ -3,7 +3,7 @@ use crate::ast::{Expr, Ref}; use crate::builtins; -use crate::builtins::utils::{ensure_args_count, ensure_object}; +use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_object}; use crate::lexer::Span; use crate::value::Value; @@ -21,6 +21,8 @@ pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { m.insert("object.keys", (keys, 1)); m.insert("object.remove", (remove, 2)); m.insert("object.subset", (subset, 2)); + m.insert("object.union", (object_union, 2)); + m.insert("object.union_n", (object_union_n, 1)); } fn json_filter_impl(v: &Value, filter: &Value) -> Value { @@ -324,3 +326,59 @@ fn subset(span: &Span, params: &[Ref], args: &[Value], _strict: bool) -> R Ok(Value::Bool(is_subset(&args[0], &args[1]))) } + +fn union(obj1: &Value, obj2: &Value) -> Result { + match (obj1, obj2) { + (Value::Object(m1), Value::Object(m2)) => { + let mut u = obj1.clone(); + let um = u.as_object_mut()?; + + for (key2, value2) in m2.iter() { + let vm = match m1.get(key2) { + Some(value1) => union(value1, value2)?, + _ => value2.clone(), + }; + um.insert(key2.clone(), vm); + } + Ok(u) + } + _ => Ok(obj2.clone()), + } +} + +fn object_union(span: &Span, params: &[Ref], args: &[Value], _strict: bool) -> Result { + let name = "object.union"; + ensure_args_count(span, name, params, args, 2)?; + + let _ = ensure_object(name, ¶ms[0], args[0].clone())?; + let _ = ensure_object(name, ¶ms[1], args[1].clone())?; + + union(&args[0], &args[1]) +} + +fn object_union_n( + span: &Span, + params: &[Ref], + args: &[Value], + strict: bool, +) -> Result { + let name = "object.union_n"; + ensure_args_count(span, name, params, args, 1)?; + + let arr = ensure_array(name, ¶ms[0], args[0].clone())?; + + let mut u = Value::new_object(); + for (idx, a) in arr.iter().enumerate() { + if a.as_object().is_err() { + if strict { + bail!(params[0] + .span() + .error(&format!("item at index {idx} is not an object"))); + } + return Ok(Value::Undefined); + } + u = union(&u, a)?; + } + + Ok(u) +} diff --git a/src/engine.rs b/src/engine.rs index 7183373b..3d979a0b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -164,23 +164,26 @@ impl Engine { pub fn eval_query(&mut self, query: String, enable_tracing: bool) -> Result { self.eval_modules(false)?; + let query_module = { + let source = Source::new( + "".to_owned(), + "package __internal_query_module".to_owned(), + ); + Ref::new(Parser::new(&source)?.parse()?) + }; + // Parse the query. - let query_len = query.len(); let query_source = Source::new("".to_string(), query); - let query_span = Span { - source: query_source.clone(), - line: 1, - col: 1, - start: 0, - end: query_len as u16, - }; let mut parser = Parser::new(&query_source)?; - let query_node = Ref::new(parser.parse_query(query_span, "")?); + let query_node = parser.parse_user_query()?; let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?; - let results = - self.interpreter - .eval_user_query(&query_node, &query_schedule, enable_tracing)?; + let results = self.interpreter.eval_user_query( + &query_module, + &query_node, + &query_schedule, + enable_tracing, + )?; Ok(results) } } diff --git a/src/interpreter.rs b/src/interpreter.rs index bcd67610..e67cbc4c 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -2727,6 +2727,7 @@ impl Interpreter { pub fn eval_user_query( &mut self, + module: &Ref, query: &Ref, schedule: &Schedule, enable_tracing: bool, @@ -2754,7 +2755,7 @@ impl Interpreter { is_compr: false, }); - let prev_module = self.set_current_module(self.modules.last().cloned())?; + let prev_module = self.set_current_module(Some(module.clone()))?; // Eval the query. let query_r = self.eval_query(query); @@ -2795,6 +2796,14 @@ impl Interpreter { self.set_current_module(prev_module)?; + if let Some(r) = results.result.last() { + if r.bindings.is_empty_object() + && r.expressions.iter().any(|e| e.value == Value::Bool(false)) + { + results = QueryResults::default(); + } + } + match query_r { Ok(_) => Ok(results), Err(e) => Err(e), diff --git a/src/parser.rs b/src/parser.rs index 20a3cafa..0cdce70a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -620,10 +620,21 @@ impl<'source> Parser<'source> { let op = match self.token_text() { "+" => ArithOp::Add, "-" => ArithOp::Sub, + n if n.starts_with('-') && self.tok.0 == TokenKind::Number => ArithOp::Sub, _ => return Ok(expr), }; - self.next_token()?; - let right = self.parse_mul_div_mod_expr()?; + let right = if self.token_text().len() > 1 { + // Treat the - as a separate token + let mut rhs_span = self.tok.1.clone(); + rhs_span.start += 1; + rhs_span.col += 1; + + self.next_token()?; + Expr::Number(rhs_span) + } else { + self.next_token()?; + self.parse_mul_div_mod_expr()? + }; span.end = self.end; expr = Expr::ArithExpr { span, @@ -953,7 +964,7 @@ impl<'source> Parser<'source> { }) } - pub fn parse_query(&mut self, mut span: Span, end_delim: &str) -> Result { + fn parse_query(&mut self, mut span: Span, end_delim: &str) -> Result { let state = self.clone(); let is_definite_query = matches!(self.token_text(), "some" | "every"); @@ -1485,7 +1496,7 @@ impl<'source> Parser<'source> { Ok(Rule::Spec { span, head, bodies }) } - fn parse_package(&mut self) -> Result { + pub fn parse_package(&mut self) -> Result { let mut span = self.tok.1.clone(); self.expect("package", "Missing package declaration.")?; let name = self.parse_path_ref()?; @@ -1609,4 +1620,13 @@ impl<'source> Parser<'source> { policy, }) } + + pub fn parse_user_query(&mut self) -> Result> { + let span = self.tok.1.clone(); + let query = Ref::new(self.parse_query(span, "")?); + if self.tok.0 != TokenKind::Eof { + bail!(self.tok.1.error("expecting EOF")); + } + Ok(query) + } } diff --git a/tests/interpreter/cases/arithmetic/tests.yaml b/tests/interpreter/cases/arithmetic/tests.yaml new file mode 100644 index 00000000..2eb706b0 --- /dev/null +++ b/tests/interpreter/cases/arithmetic/tests.yaml @@ -0,0 +1,21 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +cases: + - note: negative-integer-literal-in-arithmetic-expressions + data: {} + modules: + - | + # In the following, the negative integer must be broken into a - and an integer tokens + # when in arighmetic expression contexts + package test + a = 1+1-1 + b = 1 +1 -1 + c = 1 + 1 - 1 + d = -1 -1 + query: data.test + want_result: + a: 1 + b: 1 + c: 1 + d: -2 + diff --git a/tests/interpreter/cases/engine/tests.yaml b/tests/interpreter/cases/engine/tests.yaml new file mode 100644 index 00000000..c25184db --- /dev/null +++ b/tests/interpreter/cases/engine/tests.yaml @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +cases: + - note: trailing whitespace in query + data: {} + modules: ["package test"] + query: "1 + 1 -2 " + want_result: 0 + + - note: trailing chars in query + data: {} + modules: ["package test"] + query: "[1]]" + error: expecting EOF + + - note: trailing expressions in query + data: {} + modules: ["package test"] + query: "1 2" + error: expecting EOF + + - note: multiple statements in query + data: {} + modules: ["package test"] + query: | + a = [1, 2, 3] + true + y = 1 + 1 + want_result: + a: [1, 2, 3] + y: 2 + + - note: comprehensions in query + data: {} + modules: ["package test"] + query: | + true + [1, 2, 3][_] + want_result: + many!: + - [true, 1] + - [true, 2] + - [true, 3] + diff --git a/tests/interpreter/mod.rs b/tests/interpreter/mod.rs index 49ab413f..2047810e 100644 --- a/tests/interpreter/mod.rs +++ b/tests/interpreter/mod.rs @@ -83,6 +83,7 @@ fn match_values(computed: &Value, expected: &Value) -> Result<()> { pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> Result<()> { if computed_results.len() != expected_results.len() { + dbg!((&computed_results, &expected_results)); bail!( "the number of computed results ({}) and expected results ({}) is not equal", computed_results.len(), @@ -108,11 +109,25 @@ pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> R } fn push_query_results(query_results: QueryResults, results: &mut Vec) { - if let Some(query_result) = query_results.result.last() { - if !query_result.bindings.is_empty_object() { - results.push(query_result.bindings.clone()); - } else if let Some(v) = query_result.expressions.last() { - results.push(v.value.clone()); + if query_results.result.len() == 1 { + if let Some(query_result) = query_results.result.last() { + if !query_result.bindings.is_empty_object() { + results.push(query_result.bindings.clone()); + } else { + for e in query_result.expressions.iter() { + results.push(e.value.clone()); + } + } + } + } else { + for r in query_results.result.iter() { + if !r.bindings.is_empty_object() { + results.push(r.bindings.clone()); + } else { + results.push(Value::from_array( + r.expressions.iter().map(|e| e.value.clone()).collect(), + )); + } } } } diff --git a/tests/opa.passing b/tests/opa.passing index 56116b95..03e19305 100644 --- a/tests/opa.passing +++ b/tests/opa.passing @@ -48,6 +48,7 @@ indirectreferences inputvalues intersection invalidkeyerror +jsonbuiltins jsonfilter jsonfilteridempotent jsonremove @@ -67,7 +68,11 @@ objectkeys objectremove objectremoveidempotent objectremovenonstringkey +objectunion +objectunionn partialdocconstants +partialiter +partialobjectdoc partialsetdoc planner-ir rand diff --git a/tests/opa.rs b/tests/opa.rs index 491c03fe..d68a1355 100644 --- a/tests/opa.rs +++ b/tests/opa.rs @@ -170,7 +170,11 @@ fn run_opa_tests(opa_tests_dir: String, folders: &[String]) -> Result<()> { entry.0 += 1; } // TODO: Handle tests that specify both want_result and strict_error - (Err(_), _) if case.want_error.is_some() => { + (Err(_), _) + if case.want_error.is_some() + || case.strict_error == Some(true) + || case.want_error_code.is_some() => + { // Expected failure. entry.0 += 1; }