Skip to content

Commit

Permalink
Упрощение AST
Browse files Browse the repository at this point in the history
failures (15):
    formats::parse_test_data_formats_type_ternary_2nd_falsy_ksy
    model::expressions::evaluation::binary::add::float::int
    model::expressions::evaluation::binary::add::int::float
    model::expressions::evaluation::binary::eq::bool::bool
    model::expressions::evaluation::binary::eq::float::float
    model::expressions::evaluation::binary::eq::float::int
    model::expressions::evaluation::binary::eq::int::float
    model::expressions::evaluation::binary::eq::int::int
    model::expressions::evaluation::binary::eq::str::str
    model::expressions::evaluation::binary::ne::bool::bool
    model::expressions::evaluation::binary::ne::float::float
    model::expressions::evaluation::binary::ne::float::int
    model::expressions::evaluation::binary::ne::int::float
    model::expressions::evaluation::binary::ne::int::int
    model::expressions::evaluation::binary::ne::str::str
  • Loading branch information
Mingun committed Aug 9, 2024
1 parent dddca11 commit 3b25682
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 1 deletion.
334 changes: 334 additions & 0 deletions src/model/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,30 @@ impl OwningNode {
let left = Self::validate(*left)?;
let right = Self::validate(*right)?;

macro_rules! fold_constants {
(
$l1:ident, $r1:ident, $r2:ident;
$l:ident, $r:ident;
$op1:ident, $op2:ident;
$res1:ident($expr1:expr);
$res2:ident($expr2:expr);
) => {
match (&*$l1, &*$r1, &$r2) {
(_, Int($l), Int($r) ) => Binary { op: $res1, left: Box::new(Int ($expr1)), right: $l1 },
(_, Float($l), Float($r)) => Binary { op: $res1, left: Box::new(Float($expr1)), right: $l1 },

(Int($l), _, Int($r) ) => Binary { op: $res2, left: Box::new(Int ($expr2)), right: $r1 },
(Float($l), _, Float($r)) => Binary { op: $res2, left: Box::new(Float($expr2)), right: $r1 },

_ => Binary {
op: $op1,
left: Box::new(Binary { op: $op2, left: $l1, right: $r1 }),
right: Box::new($r2),
},
}
};
}

match (op, left, right) {
//TODO: Check types before simplification
(Add, Str(l), r) if l.is_empty() => r,
Expand Down Expand Up @@ -223,6 +247,205 @@ impl OwningNode {
(Div, l, Float(r)) if r.is_one() => l,// x / 1 = x

//=======================================================================================
// _ + L + R L + _ + R => SUM + _
// + + +
// / \ / \ / \
// + R OR + R SUM _
// / \ / \ (L+R)
// _ L L _
(Add, Binary { op: Add, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) {
(_, Str(l), Str(r) ) => Binary { op: Add, left: l1, right: Box::new(Str(l.to_owned() + r)) },

(_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (l + r)), right: l1 },
(_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(l + r)), right: l1 },

(Int(l), _, Int(r) ) => Binary { op: Add, left: Box::new(Int (l + r)), right: r1 },
(Float(l), _, Float(r)) => Binary { op: Add, left: Box::new(Float(l + r)), right: r1 },

_ => Binary {
op: Add,
left: Box::new(Binary { op: Add, left: l1, right: r1 }),
right: Box::new(r2),
},
},
(Mul, Binary { op: Mul, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Mul, Mul;
Mul(l * r);
Mul(l * r);
),
//---------------------------------------------------------------------------------------
// _ - L + R => SUB + _ L - _ + R => SUM - _
// + + + -
// / \ / \ / \ / \
// - R SUB _ OR - R SUM _
// / \ (R-L) / \ (L+R)
// _ L L _
(Add, Binary { op: Sub, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Add, Sub;
Add(r - l);
Sub(r + l);
),
(Mul, Binary { op: Div, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Mul, Div;
Mul(r / l);
Div(r * l);
),
/*(Add, Binary { op: Sub, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) {
(_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (r - l)), right: l1 },
(_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(r - l)), right: l1 },
(Int(l), _, Int(r) ) => Binary { op: Sub, left: Box::new(Int (l + r)), right: r1 },
(Float(l), _, Float(r)) => Binary { op: Sub, left: Box::new(Float(l + r)), right: r1 },
_ => Binary {
op: Add,
left: Box::new(Binary { op: Sub, left: l1, right: r1 }),
right: Box::new(r2),
},
},*/
//---------------------------------------------------------------------------------------
// _ - L - R => SUM + _ L - _ - R => SUB - _
// - + - -
// / \ / \ / \ / \
// - R SUM _ OR - R SUB _
// / \ (-L-R) / \ (L-R)
// _ L L _
(Sub, Binary { op: Sub, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Sub, Sub;
Add(-l - r);
Sub( l - r);
),
(Div, Binary { op: Div, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Div, Div;
Mul(1/(l * r));
Div(l / r);
),
/*(Sub, Binary { op: Sub, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) {
(_, Int(l), Int(r) ) => Binary { op: Sub, left: l1, right: Box::new(Int (l + r)) },
(_, Float(l), Float(r)) => Binary { op: Sub, left: l1, right: Box::new(Float(l + r)) },
(Int(l), _, Int(r) ) => Binary { op: Sub, left: Box::new(Int (l - r)), right: r1 },
(Float(l), _, Float(r)) => Binary { op: Sub, left: Box::new(Float(l - r)), right: r1 },
_ => Binary {
op: Sub,
left: Box::new(Binary { op: Sub, left: l1, right: r1 }),
right: Box::new(r2),
},
},*/
//---------------------------------------------------------------------------------------
// _ + L - R L + _ - R => SUB + _
// - - +
// / \ / \ / \
// + R OR + R SUB _
// / \ / \ (L-R)
// _ L L _
(Sub, Binary { op: Add, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Sub, Add;
Add(l - r);
Add(l - r);
),
(Div, Binary { op: Mul, left: l1, right: r1 }, r2) => fold_constants!(
l1, r1, r2;
l, r;
Div, Mul;
Mul(l / r);
Mul(l / r);
),
/*(Sub, Binary { op: Add, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) {
(_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (l - r)), right: l1 },
(_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(l - r)), right: l1 },
(Int(l), _, Int(r) ) => Binary { op: Add, left: Box::new(Int (l - r)), right: r1 },
(Float(l), _, Float(r)) => Binary { op: Add, left: Box::new(Float(l - r)), right: r1 },
_ => Binary {
op: Sub,
left: Box::new(Binary { op: Add, left: l1, right: r1 }),
right: Box::new(r2),
},
},*/
//=======================================================================================
(Add, Str(l), Str(r)) => Str(l + &r),

(Le, Str(l), Str(r)) => Bool(l <= r),
(Ge, Str(l), Str(r)) => Bool(l >= r),
(Lt, Str(l), Str(r)) => Bool(l < r),
(Gt, Str(l), Str(r)) => Bool(l > r),

//---------------------------------------------------------------------------------------
(Add, Int(l), r) if l.is_zero() => r,
(Add, l, Int(r)) if r.is_zero() => l,
(Sub, Int(l), r) if l.is_zero() => Unary { op: UnaryOp::Neg, expr: Box::new(r) },
(Sub, l, Int(r)) if r.is_zero() => l,

(Add, Float(l), r) if l.is_zero() => r,
(Add, l, Float(r)) if r.is_zero() => l,
(Sub, Float(l), r) if l.is_zero() => Unary { op: UnaryOp::Neg, expr: Box::new(r) },
(Sub, l, Float(r)) if r.is_zero() => l,

(Mul, Int(l), r) if l.is_one() => r,
(Mul, l, Int(r)) if r.is_one() => l,
(Div, l, Int(r)) if r.is_one() => l,

(Mul, Float(l), r) if l.is_one() => r,
(Mul, l, Float(r)) if r.is_one() => l,
(Div, l, Float(r)) if r.is_one() => l,
//---------------------------------------------------------------------------------------

(Add, Int(l), Int(r)) => Int(l + r),
(Sub, Int(l), Int(r)) => Int(l - r),
(Mul, Int(l), Int(r)) => Int(l * r),
(Div, Int(l), Int(r)) => Int(l / r),
// Rust `%` uses modulo operation (negative result for negative `l`), but Kaitai Struct
// uses remainder operation (always positive result): https://doc.kaitai.io/user_guide.html#_operators
// (Rem, Int(l), Int(r)) => Int(l.rem_euclid(r)),

// (Shl, Int(l), Int(r)) => Int(l << r),
// (Shr, Int(l), Int(r)) => Int(l >> r),

(Le, Int(l), Int(r)) => Bool(l <= r),
(Ge, Int(l), Int(r)) => Bool(l >= r),
(Lt, Int(l), Int(r)) => Bool(l < r),
(Gt, Int(l), Int(r)) => Bool(l > r),

(BitAnd, Int(l), Int(r)) => Int(l & r),
(BitOr, Int(l), Int(r)) => Int(l | r),
(BitXor, Int(l), Int(r)) => Int(l ^ r),
//---------------------------------------------------------------------------------------

(Add, Float(l), Float(r)) => Float(l + r),
(Sub, Float(l), Float(r)) => Float(l - r),
(Mul, Float(l), Float(r)) => Float(l * r),
(Div, Float(l), Float(r)) => Float(l / r),

(Le, Float(l), Float(r)) => Bool(l <= r),
(Ge, Float(l), Float(r)) => Bool(l >= r),
(Lt, Float(l), Float(r)) => Bool(l < r),
(Gt, Float(l), Float(r)) => Bool(l > r),
//---------------------------------------------------------------------------------------

(And, Bool(l), Bool(r)) => Bool(l && r),
(Or, Bool(l), Bool(r)) => Bool(l || r),

//---------------------------------------------------------------------------------------
// Symbolic calculations: if two subtrees are equal after normalization,
// then the result is known at compile-time
(Eq, l, r) if l == r => Bool(true),

//---------------------------------------------------------------------------------------
(_, l, r) => Binary {
op,
left: Box::new(l),
Expand Down Expand Up @@ -1280,6 +1503,99 @@ mod evaluation {
}
}
}

/// Checks that folding constants in triplets behaves correctly
mod triplets {
use super::*;
use pretty_assertions::assert_eq;
use BinaryOp::Add;

/// Tests folding of the integral numbers' constants
#[test]
fn int() {
assert_eq!(OwningNode::parse("x + 1 + 2").unwrap(), Binary {
op: Add,
left: Box::new(Int(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
assert_eq!(OwningNode::parse("1 + x + 2").unwrap(), Binary {
op: Add,
left: Box::new(Int(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
assert_eq!(OwningNode::parse("1 + 2 + x").unwrap(), Binary {
op: Add,
left: Box::new(Int(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});

assert_eq!(OwningNode::parse("1 + 2 + 3 + x + 4 + 5 + 6").unwrap(), Binary {
op: Add,
left: Box::new(Int(21.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
}

/// Tests folding of the floating-points' constants
#[test]
fn float() {
assert_eq!(OwningNode::parse("x + 1.0 + 2.0").unwrap(), Binary {
op: Add,
left: Box::new(Float(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
assert_eq!(OwningNode::parse("1.0 + x + 2.0").unwrap(), Binary {
op: Add,
left: Box::new(Float(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
assert_eq!(OwningNode::parse("1.0 + 2.0 + x").unwrap(), Binary {
op: Add,
left: Box::new(Float(3.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});

assert_eq!(OwningNode::parse("1.0 + 2.0 + 3.0 + x + 4.0 + 5.0 + 6.0").unwrap(), Binary {
op: Add,
left: Box::new(Float(21.into())),
right: Box::new(Attr(FieldName::valid("x"))),
});
}

/// Tests folding of the string constants
#[test]
fn str() {
assert_eq!(OwningNode::parse("x + 'a' + 'b'").unwrap(), Binary {
op: Add,
left: Box::new(Attr(FieldName::valid("x"))),
right: Box::new(Str("ab".into())),
});
assert_eq!(OwningNode::parse("'a' + x + 'b'").unwrap(), Binary {
op: Add,
left: Box::new(Binary {
op: Add,
left: Box::new(Str("a".into())),
right: Box::new(Attr(FieldName::valid("x"))),
}),
right: Box::new(Str("b".into())),
});
assert_eq!(OwningNode::parse("'a' + 'b' + x").unwrap(), Binary {
op: Add,
left: Box::new(Str("ab".into())),
right: Box::new(Attr(FieldName::valid("x"))),
});

assert_eq!(OwningNode::parse("'a' + 'b' + 'c' + x + 'd' + 'e' + 'f'").unwrap(), Binary {
op: Add,
left: Box::new(Binary {
op: Add,
left: Box::new(Str("abc".into())),
right: Box::new(Attr(FieldName::valid("x"))),
}),
right: Box::new(Str("def".into())),
});
}
}
}

#[test]
Expand All @@ -1292,4 +1608,22 @@ mod evaluation {
if_false: Box::new(Attr(FieldName::valid("b"))),
}));
}

#[test]
#[ignore]
fn index() {//TODO: Index validation
assert_eq!(OwningNode::parse("[3, 1, 4][-1]"), Ok(Int(4.into())));
assert_eq!(OwningNode::parse("[3, 1, 4][ 2]"), Ok(Int(4.into())));
assert_eq!(OwningNode::parse("[3, 1, x][ 2]"), Ok(Attr(FieldName::valid("x"))));
assert_eq!(OwningNode::parse("[3, 1, 4][ x]"), Ok(Index {
expr: Box::new(List(vec![
Int(3.into()),
Int(1.into()),
Int(4.into()),
])),
index: Box::new(Attr(FieldName::valid("x"))),
}));

assert_eq!(OwningNode::parse("[3, 1, 4][ 3]"), Err(Validation("".into())));
}
}
9 changes: 8 additions & 1 deletion src/parser/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,14 @@ pub enum BinaryOp {
Mul,
/// `/`: Division of two numeric arguments.
Div,
/// `%`: Remainder of division of two numeric arguments.
/// `%`: Nonnegative remainder of division of two numeric arguments.
///
/// This operation is different from Rust [`%`] operator: [`-5 % 3` is `1`, not `-2`][operators].
/// Analogous in Rust for this operation is [`rem_euclid`].
///
/// [`%`]: std::ops::Rem
/// [`rem_euclid`]: i64::rem_euclid
/// [operators]: https://doc.kaitai.io/user_guide.html#_operators
Rem,

/// `<<`: The left shift operator.
Expand Down

0 comments on commit 3b25682

Please sign in to comment.