diff --git a/query-grammar/src/query_grammar.rs b/query-grammar/src/query_grammar.rs index 9d8c411342..92d7738b20 100644 --- a/query-grammar/src/query_grammar.rs +++ b/query-grammar/src/query_grammar.rs @@ -7,7 +7,7 @@ use nom::character::complete::{ }; use nom::combinator::{eof, map, map_res, opt, peek, recognize, value, verify}; use nom::error::{Error, ErrorKind}; -use nom::multi::{many0, many1, separated_list0, separated_list1}; +use nom::multi::{many0, many1, separated_list0}; use nom::sequence::{delimited, preceded, separated_pair, terminated, tuple}; use nom::IResult; @@ -786,27 +786,23 @@ fn binary_operand(inp: &str) -> IResult<&str, BinaryOperand> { } fn aggregate_binary_expressions( - left: UserInputAst, - others: Vec<(BinaryOperand, UserInputAst)>, -) -> UserInputAst { - let mut dnf: Vec> = vec![vec![left]]; - for (operator, operand_ast) in others { - match operator { - BinaryOperand::And => { - if let Some(last) = dnf.last_mut() { - last.push(operand_ast); - } - } - BinaryOperand::Or => { - dnf.push(vec![operand_ast]); - } - } - } - if dnf.len() == 1 { - UserInputAst::and(dnf.into_iter().next().unwrap()) //< safe + left: (Option, UserInputAst), + others: Vec<(Option, Option, UserInputAst)>, +) -> Result { + let mut leafs = Vec::with_capacity(others.len() + 1); + leafs.push((None, left.0, Some(left.1))); + leafs.extend( + others + .into_iter() + .map(|(operand, occur, ast)| (operand, occur, Some(ast))), + ); + // the parameters we pass should statically guarantee we can't get errors + // (no prefix BinaryOperand is provided) + let (res, mut errors) = aggregate_infallible_expressions(leafs); + if errors.is_empty() { + Ok(res) } else { - let conjunctions = dnf.into_iter().map(UserInputAst::and).collect(); - UserInputAst::or(conjunctions) + Err(errors.swap_remove(0)) } } @@ -822,30 +818,10 @@ fn aggregate_infallible_expressions( return (UserInputAst::empty_query(), err); } - let use_operand = leafs.iter().any(|(operand, _, _)| operand.is_some()); - let all_operand = leafs - .iter() - .skip(1) - .all(|(operand, _, _)| operand.is_some()); let early_operand = leafs .iter() .take(1) .all(|(operand, _, _)| operand.is_some()); - let use_occur = leafs.iter().any(|(_, occur, _)| occur.is_some()); - - if use_operand && use_occur { - err.push(LenientErrorInternal { - pos: 0, - message: "Use of mixed occur and boolean operator".to_string(), - }); - } - - if use_operand && !all_operand { - err.push(LenientErrorInternal { - pos: 0, - message: "Missing boolean operator".to_string(), - }); - } if early_operand { err.push(LenientErrorInternal { @@ -872,7 +848,15 @@ fn aggregate_infallible_expressions( Some(BinaryOperand::And) => Some(Occur::Must), _ => Some(Occur::Should), }; - clauses.push(vec![(occur.or(default_op), ast.clone())]); + if occur == &Some(Occur::MustNot) && default_op == Some(Occur::Should) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![( + Some(Occur::Should), + ast.clone().unary(Occur::MustNot), + )]) + } else { + clauses.push(vec![(occur.or(default_op), ast.clone())]); + } } None => { let default_op = match next_operator { @@ -880,7 +864,15 @@ fn aggregate_infallible_expressions( Some(BinaryOperand::Or) => Some(Occur::Should), None => None, }; - clauses.push(vec![(occur.or(default_op), ast.clone())]) + if occur == &Some(Occur::MustNot) && default_op == Some(Occur::Should) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![( + Some(Occur::Should), + ast.clone().unary(Occur::MustNot), + )]) + } else { + clauses.push(vec![(occur.or(default_op), ast.clone())]) + } } } } @@ -897,7 +889,12 @@ fn aggregate_infallible_expressions( } } Some(BinaryOperand::Or) => { - clauses.push(vec![(last_occur.or(Some(Occur::Should)), last_ast)]); + if last_occur == Some(Occur::MustNot) { + // if occur is MustNot *and* operation is OR, we synthetize a ShouldNot + clauses.push(vec![(Some(Occur::Should), last_ast.unary(Occur::MustNot))]); + } else { + clauses.push(vec![(last_occur.or(Some(Occur::Should)), last_ast)]); + } } None => clauses.push(vec![(last_occur, last_ast)]), } @@ -923,35 +920,29 @@ fn aggregate_infallible_expressions( } } -fn operand_leaf(inp: &str) -> IResult<&str, (BinaryOperand, UserInputAst)> { - tuple(( - terminated(binary_operand, multispace0), - terminated(boosted_leaf, multispace0), - ))(inp) +fn operand_leaf(inp: &str) -> IResult<&str, (Option, Option, UserInputAst)> { + map( + tuple(( + terminated(opt(binary_operand), multispace0), + terminated(occur_leaf, multispace0), + )), + |(operand, (occur, ast))| (operand, occur, ast), + )(inp) } fn ast(inp: &str) -> IResult<&str, UserInputAst> { - let boolean_expr = map( - separated_pair(boosted_leaf, multispace1, many1(operand_leaf)), + let boolean_expr = map_res( + separated_pair(occur_leaf, multispace1, many1(operand_leaf)), |(left, right)| aggregate_binary_expressions(left, right), ); - let whitespace_separated_leaves = map(separated_list1(multispace1, occur_leaf), |subqueries| { - if subqueries.len() == 1 { - let (occur_opt, ast) = subqueries.into_iter().next().unwrap(); - match occur_opt.unwrap_or(Occur::Should) { - Occur::Must | Occur::Should => ast, - Occur::MustNot => UserInputAst::Clause(vec![(Some(Occur::MustNot), ast)]), - } + let single_leaf = map(occur_leaf, |(occur, ast)| { + if occur == Some(Occur::MustNot) { + ast.unary(Occur::MustNot) } else { - UserInputAst::Clause(subqueries.into_iter().collect()) + ast } }); - - delimited( - multispace0, - alt((boolean_expr, whitespace_separated_leaves)), - multispace0, - )(inp) + delimited(multispace0, alt((boolean_expr, single_leaf)), multispace0)(inp) } fn ast_infallible(inp: &str) -> JResult<&str, UserInputAst> { @@ -1155,21 +1146,39 @@ mod test { test_parse_query_to_ast_helper("a OR b", "(?a ?b)"); test_parse_query_to_ast_helper("a OR b AND c", "(?a ?(+b +c))"); test_parse_query_to_ast_helper("a AND b AND c", "(+a +b +c)"); - test_is_parse_err("a OR b aaa", "(?a ?b *aaa)"); - test_is_parse_err("a AND b aaa", "(?(+a +b) *aaa)"); - test_is_parse_err("aaa a OR b ", "(*aaa ?a ?b)"); - test_is_parse_err("aaa ccc a OR b ", "(*aaa *ccc ?a ?b)"); - test_is_parse_err("aaa a AND b ", "(*aaa ?(+a +b))"); - test_is_parse_err("aaa ccc a AND b ", "(*aaa *ccc ?(+a +b))"); + test_parse_query_to_ast_helper("a OR b aaa", "(?a ?b *aaa)"); + test_parse_query_to_ast_helper("a AND b aaa", "(?(+a +b) *aaa)"); + test_parse_query_to_ast_helper("aaa a OR b ", "(*aaa ?a ?b)"); + test_parse_query_to_ast_helper("aaa ccc a OR b ", "(*aaa *ccc ?a ?b)"); + test_parse_query_to_ast_helper("aaa a AND b ", "(*aaa ?(+a +b))"); + test_parse_query_to_ast_helper("aaa ccc a AND b ", "(*aaa *ccc ?(+a +b))"); } #[test] fn test_parse_mixed_bool_occur() { - test_is_parse_err("a OR b +aaa", "(?a ?b +aaa)"); - test_is_parse_err("a AND b -aaa", "(?(+a +b) -aaa)"); - test_is_parse_err("+a OR +b aaa", "(+a +b *aaa)"); - test_is_parse_err("-a AND -b aaa", "(?(-a -b) *aaa)"); - test_is_parse_err("-aaa +ccc -a OR b ", "(-aaa +ccc -a ?b)"); + test_parse_query_to_ast_helper("+a OR +b", "(+a +b)"); + + test_parse_query_to_ast_helper("a AND -b", "(+a -b)"); + test_parse_query_to_ast_helper("-a AND b", "(-a +b)"); + test_parse_query_to_ast_helper("a AND NOT b", "(+a +(-b))"); + test_parse_query_to_ast_helper("NOT a AND b", "(+(-a) +b)"); + + test_parse_query_to_ast_helper("a AND NOT b AND c", "(+a +(-b) +c)"); + test_parse_query_to_ast_helper("a AND -b AND c", "(+a -b +c)"); + + test_parse_query_to_ast_helper("a OR -b", "(?a ?(-b))"); + test_parse_query_to_ast_helper("-a OR b", "(?(-a) ?b)"); + test_parse_query_to_ast_helper("a OR NOT b", "(?a ?(-b))"); + test_parse_query_to_ast_helper("NOT a OR b", "(?(-a) ?b)"); + + test_parse_query_to_ast_helper("a OR NOT b OR c", "(?a ?(-b) ?c)"); + test_parse_query_to_ast_helper("a OR -b OR c", "(?a ?(-b) ?c)"); + + test_parse_query_to_ast_helper("a OR b +aaa", "(?a ?b +aaa)"); + test_parse_query_to_ast_helper("a AND b -aaa", "(?(+a +b) -aaa)"); + test_parse_query_to_ast_helper("+a OR +b aaa", "(+a +b *aaa)"); + test_parse_query_to_ast_helper("-a AND -b aaa", "(?(-a -b) *aaa)"); + test_parse_query_to_ast_helper("-aaa +ccc -a OR b ", "(-aaa +ccc ?(-a) ?b)"); } #[test]