Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix non-exhaustive warnings with the Ast.expr.Call node #227

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ class ExpressionsSpec extends AnyFunSpec {
)
}

// Attribute / method call
// Attributes
it("parses 123.to_s") {
Expressions.parse("123.to_s") should be (Attribute(IntNum(123),identifier("to_s")))
}
Expand All @@ -404,6 +404,63 @@ class ExpressionsSpec extends AnyFunSpec {
Expressions.parse("foo.bar") should be (Attribute(Name(identifier("foo")),identifier("bar")))
}

// Method calls
describe("parses method") {
it("without parameters") {
Expressions.parse("foo.bar()") should be (Call(Name(identifier("foo")),identifier("bar"), Seq()))
}

it("with parameters") {
Expressions.parse("foo.bar(42)") should be (Call(Name(identifier("foo")),identifier("bar"), Seq(IntNum(42))))
}

it("on strings") {
Expressions.parse("\"foo\".bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
Expressions.parse("'foo'.bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
}

it("on booleans") {
Expressions.parse("true.bar(42)") should be (Call(Bool(true), identifier("bar"), Seq(IntNum(42))))
Expressions.parse("false.bar(42)") should be (Call(Bool(false),identifier("bar"), Seq(IntNum(42))))
}

it("on integer") {
Expressions.parse("42.bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on float") {
Expressions.parse("42.0.bar(42)") should be (Call(FloatNum(42.0), identifier("bar"), Seq(IntNum(42))))
}

it("on array") {
Expressions.parse("[].bar(42)") should be (Call(List(Nil), identifier("bar"), Seq(IntNum(42))))
}

it("on slice") {
Expressions.parse("foo[1].bar(42)") should be (
Call(
Subscript(Name(identifier("foo")), IntNum(1)),
identifier("bar"),
Seq(IntNum(42))
)
)
}

it("on group") {
Expressions.parse("(42).bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on expression") {
Expressions.parse("(1+2).bar(42)") should be (
Call(
BinOp(IntNum(1), Add, IntNum(2)),
identifier("bar"),
Seq(IntNum(42))
)
)
}
}

describe("strings") {
it("single-quoted") {
// \" -> \"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,8 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends
case _ =>
affectedVars(value)
}
case Ast.expr.Call(func, args) =>
val fromFunc = func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => affectedVars(obj)
}
fromFunc ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Call(value, _, args) =>
affectedVars(value) ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Subscript(value, idx) =>
affectedVars(value) ++ affectedVars(idx)
case Ast.expr.Name(id) =>
Expand Down
12 changes: 11 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ object Ast {
// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr
/** Represents `X < Y`, `X > Y` and so on. */
case class Compare(left: expr, ops: cmpop, right: expr) extends expr
case class Call(func: expr, args: Seq[expr]) extends expr
/**
* Represents function call on some expression:
* ```
* <obj>.<methodName>(<args>)
* ```
*
* @param obj expression on which method is called
* @param methodName method to call
* @param args method arguments
*/
case class Call(obj: expr, methodName: identifier, args: Seq[expr]) extends expr
case class IntNum(n: BigInt) extends expr
case class FloatNum(n: BigDecimal) extends expr
case class Str(s: String) extends expr
Expand Down
13 changes: 10 additions & 3 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,20 @@ object Expressions {
def list_contents[$: P] = P( test.rep(1, ",") ~ ",".? )
def list[$: P] = P( list_contents ).map(Ast.expr.List(_))

def call[$: P] = P("(" ~ arglist ~ ")").map { case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)}
def call[$: P] = P("(" ~ arglist ~ ")")
def slice[$: P] = P("[" ~ test ~ "]").map { case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)}
def cast[$: P] = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map(
typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName)
)
def attr[$: P] = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id))
def trailer[$: P]: P[Ast.expr => Ast.expr] = P( call | slice | cast | attr )
// Returns function that accept lsh expression and returns Attribute or Call
// node depending on existence of parameters
def attr[$: P] = P("." ~ NAME ~ call.?).map {
case (id, args) => (lhs: Ast.expr) => args match {
case Some(args) => Ast.expr.Call(lhs, id, args)
case None => Ast.expr.Attribute(lhs, id)
}
}
def trailer[$: P]: P[Ast.expr => Ast.expr] = P( slice | cast | attr )

def exprlist[$: P]: P[Seq[Ast.expr]] = P( expr.rep(1, sep = ",") ~ ",".? )
def testlist[$: P]: P[Seq[Ast.expr]] = P( test.rep(1, sep = ",") ~ ",".? )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,15 @@ abstract trait CommonMethods[T] extends TypeDetector {
* @return result of translation as [[T]]
*/
def translateCall(call: Ast.expr.Call): T = {
val func = call.func
val obj = call.obj
val args = call.args

func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
MethodArgType.byDataType(objType) match {
case Some(argType) =>
invokeMethod(argType, methodName.name, obj, args)
case None =>
throw new MethodNotFoundError(methodName.name, objType)
}
val objType = detectType(obj)
MethodArgType.byDataType(objType) match {
case Some(argType) =>
invokeMethod(argType, call.methodName.name, obj, args)
case None =>
throw new MethodNotFoundError(call.methodName.name, objType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,23 +238,20 @@ class TypeDetector(provider: TypeProvider) {
/**
* Detects resulting data type of a given function call expression. Typical function
* call expression in KSY is `foo.bar(arg1, arg2)`, which is represented in AST as
* `Call(Attribute(foo, bar), Seq(arg1, arg2))`.
* `Call(foo, bar, Seq(arg1, arg2))`.
* @note Must be kept in sync with [[CommonMethods.translateCall]]
* @param call function call expression
* @return data type
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
(objType, methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ =>
throw new MethodNotFoundError(methodName.name, objType)
}
val objType = detectType(call.obj)
// TODO: check number and type of arguments in `call.args`
(objType, call.methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ =>
throw new MethodNotFoundError(call.methodName.name, objType)
}
}

Expand Down
Loading