Skip to content

Commit

Permalink
Fix non-exhaustive warnings with the Ast.expr.Call node
Browse files Browse the repository at this point in the history
Because function pointers is forbidden in the expression language from now parser
won't allow anymore following constructions:
- 42()
- "string"()
- true()
- (...)()
- []()
  • Loading branch information
Mingun committed Dec 21, 2020
1 parent f083c43 commit 3f66677
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class ExpressionsSpec extends FunSpec {
)
}

// Attribute / method call
// Attributes
it("parses 123.to_s") {
Expressions.parse("123.to_s") should be (Attribute(IntNum(123),identifier("to_s")))
}
Expand All @@ -388,5 +388,62 @@ class ExpressionsSpec extends FunSpec {
it("parses foo.bar") {
Expressions.parse("foo.bar") should be (Attribute(Name(identifier("foo")),identifier("bar")))
}

// Method calls
describe("parses method") {
it("without parameters") {
Expressions.parse("foo.bar()") should be (Call(Name(identifier("foo")),identifier("bar"), Seq()))
}

it("with parameters") {
Expressions.parse("foo.bar(42)") should be (Call(Name(identifier("foo")),identifier("bar"), Seq(IntNum(42))))
}

it("on strings") {
Expressions.parse("\"foo\".bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
Expressions.parse("'foo'.bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
}

it("on booleans") {
Expressions.parse("true.bar(42)") should be (Call(Bool(true), identifier("bar"), Seq(IntNum(42))))
Expressions.parse("false.bar(42)") should be (Call(Bool(false),identifier("bar"), Seq(IntNum(42))))
}

it("on integer") {
Expressions.parse("42.bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on float") {
Expressions.parse("42.0.bar(42)") should be (Call(FloatNum(42.0), identifier("bar"), Seq(IntNum(42))))
}

it("on array") {
Expressions.parse("[].bar(42)") should be (Call(List(Nil), identifier("bar"), Seq(IntNum(42))))
}

it("on slice") {
Expressions.parse("foo[1].bar(42)") should be (
Call(
Subscript(Name(identifier("foo")), IntNum(1)),
identifier("bar"),
Seq(IntNum(42))
)
)
}

it("on group") {
Expressions.parse("(42).bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on expression") {
Expressions.parse("(1+2).bar(42)") should be (
Call(
BinOp(IntNum(1), Add, IntNum(2)),
identifier("bar"),
Seq(IntNum(42))
)
)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,8 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends
case _ =>
affectedVars(value)
}
case Ast.expr.Call(func, args) =>
val fromFunc = func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => affectedVars(obj)
}
fromFunc ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Call(value, _, args) =>
affectedVars(value) ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Subscript(value, idx) =>
affectedVars(value) ++ affectedVars(idx)
case SwitchType.ELSE_CONST =>
Expand Down
12 changes: 11 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,17 @@ object Ast {
case class IfExp(condition: expr, ifTrue: expr, ifFalse: expr) extends expr
// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr
case class Compare(left: expr, ops: cmpop, right: expr) extends expr
case class Call(func: expr, args: Seq[expr]) extends expr
/**
* Represents function call on some expression:
* ```
* <obj>.<methodName>(<args>)
* ```
*
* @param obj expression on which method is called
* @param methodName method to call
* @param args method arguments
*/
case class Call(obj: expr, methodName: identifier, args: Seq[expr]) extends expr
case class IntNum(n: BigInt) extends expr
case class FloatNum(n: BigDecimal) extends expr
case class Str(s: String) extends expr
Expand Down
13 changes: 10 additions & 3 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,20 @@ object Expressions {
val list = P( list_contents ).map(Ast.expr.List(_))

val trailer: P[Ast.expr => Ast.expr] = {
val call = P("(" ~ arglist ~ ")").map{ case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)}
val call = P("(" ~ arglist ~ ")")
val slice = P("[" ~ test ~ "]").map{ case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)}
val cast = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map(
typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName)
)
val attr = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id))
P( call | slice | cast | attr )
// Returns function that accept дыр expression and returns Attribute or Call
// node depending on existence of parameters
val attr = P("." ~ NAME ~ call.?).map{
case (id, args) => (lhs: Ast.expr) => args match {
case Some(args) => Ast.expr.Call(lhs, id, args)
case None => Ast.expr.Attribute(lhs, id)
}
}
P( slice | cast | attr )
}

val exprlist: P[Seq[Ast.expr]] = P( expr.rep(1, sep = ",") ~ ",".? )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,15 @@ abstract trait CommonMethods[T] extends TypeDetector {
* @return result of translation as [[T]]
*/
def translateCall(call: Ast.expr.Call): T = {
val func = call.func
val obj = call.obj
val args = call.args

func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
(objType, methodName.name) match {
// TODO: check argument quantity
case (_: StrType, "substring") => strSubstring(obj, args(0), args(1))
case (_: StrType, "to_i") => strToInt(obj, args(0))
case (_: BytesType, "to_s") => bytesToStr(obj, args(0))
case _ => throw new TypeMismatchError(s"don't know how to call method '$methodName' of object type '$objType'")
}
val objType = detectType(call.obj)
(objType, call.methodName.name) match {
// TODO: check argument quantity
case (_: StrType, "substring") => strSubstring(obj, args(0), args(1))
case (_: StrType, "to_i") => strToInt(obj, args(0))
case (_: BytesType, "to_s") => bytesToStr(obj, args(0))
case _ => throw new TypeMismatchError(s"don't know how to call method '${call.methodName}' of object type '$objType'")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,23 +215,19 @@ class TypeDetector(provider: TypeProvider) {
/**
* Detects resulting data type of a given function call expression. Typical function
* call expression in KSY is `foo.bar(arg1, arg2)`, which is represented in AST as
* `Call(Attribute(foo, bar), Seq(arg1, arg2))`.
* `Call(foo, bar, Seq(arg1, arg2))`.
* @note Must be kept in sync with [[CommonMethods.translateCall]]
* @param call function call expression
* @return data type
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
(objType, methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ =>
throw new MethodNotFoundError(methodName.name, objType)
}
val objType = detectType(call.obj)
// TODO: check number and type of arguments in `call.args`
(objType, call.methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ => throw new MethodNotFoundError(call.methodName.name, objType)
}
}

Expand Down

0 comments on commit 3f66677

Please sign in to comment.