diff --git a/core/builtin_funcs.py b/core/builtin_funcs.py index 2215014..27385cb 100755 --- a/core/builtin_funcs.py +++ b/core/builtin_funcs.py @@ -45,6 +45,7 @@ class BuiltInFunction(BaseFunction): def __init__(self, name: str, func: Optional[RadonCompatibleFunction] = None): super().__init__(name, None) self.func = func + self.va_name = None def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]: res = RTResult[Value]() diff --git a/core/datatypes.py b/core/datatypes.py index c145ab1..d266f0e 100755 --- a/core/datatypes.py +++ b/core/datatypes.py @@ -1038,6 +1038,7 @@ class BaseFunction(Value): symbol_table: Optional[SymbolTable] desc: str arg_names: list[str] + va_name: Optional[str] def __init__(self, name: Optional[str], symbol_table: Optional[SymbolTable]) -> None: super().__init__() @@ -1055,7 +1056,7 @@ def check_args( res = RTResult[None]() args_count = len(args) + len(kwargs) - if args_count > len(arg_names): + if self.va_name is None and args_count > len(arg_names): return res.failure( RTError( self.pos_start, @@ -1098,6 +1099,14 @@ def populate_args(self, arg_names, args, kwargs, defaults, exec_ctx): arg_value.set_context(exec_ctx) exec_ctx.symbol_table.set(arg_name, arg_value) + if self.va_name is not None: + va_list = [] + for i in range(len(arg_names), len(args)): + arg = args[i] + arg.set_context(exec_ctx) + va_list.append(arg) + exec_ctx.symbol_table.set(self.va_name, Array(va_list)) + for kw, kwarg in kwargs.items(): kwarg.set_context(exec_ctx) exec_ctx.symbol_table.set(kw, kwarg) @@ -1372,6 +1381,7 @@ def __init__( defaults: list[Optional[Value]], should_auto_return: bool, desc: str, + va_name: Optional[str], ) -> None: super().__init__(name, symbol_table) self.body_node = body_node @@ -1379,6 +1389,7 @@ def __init__( self.defaults = defaults self.should_auto_return = should_auto_return self.desc = desc + self.va_name = va_name def execute(self, args: list[Value], kwargs: dict[str, Value]) -> RTResult[Value]: from core.interpreter import Interpreter # Lazy import @@ -1412,6 +1423,7 @@ def copy(self) -> Function: self.defaults, self.should_auto_return, self.desc, + self.va_name, ) copy.set_context(self.context) copy.set_pos(self.pos_start, self.pos_end) diff --git a/core/interpreter.py b/core/interpreter.py index 48e4b84..50c79f4 100755 --- a/core/interpreter.py +++ b/core/interpreter.py @@ -485,7 +485,14 @@ def visit_FuncDefNode(self, node: FuncDefNode, context: Context) -> RTResult[Val func_value = ( Function( - func_name, context.symbol_table, body_node, arg_names, defaults, node.should_auto_return, func_desc + func_name, + context.symbol_table, + body_node, + arg_names, + defaults, + node.should_auto_return, + func_desc, + va_name=node.va_name, ) .set_context(context) .set_pos(node.pos_start, node.pos_end) diff --git a/core/lexer.py b/core/lexer.py index 212718b..43bcac5 100755 --- a/core/lexer.py +++ b/core/lexer.py @@ -78,8 +78,7 @@ def make_tokens(self) -> tuple[list[Token], Optional[Error]]: tokens.append(Token(TT_COLON, pos_start=self.pos)) self.advance() elif self.current_char == ".": - tokens.append(Token(TT_DOT, pos_start=self.pos)) - self.advance() + tokens.append(self.make_dot()) elif self.current_char == "!": token, error = self.make_not_equals() if error is not None: @@ -273,6 +272,18 @@ def make_power_equals(self) -> Token: return Token(tok_type, pos_start=pos_start, pos_end=self.pos) + def make_dot(self) -> Token: + tok_type = TT_DOT + pos_start = self.pos.copy() + self.advance() + + if self.text[self.pos.idx :].startswith(".."): + self.advance() + self.advance() + tok_type = TT_SPREAD + + return Token(tok_type, pos_start=pos_start, pos_end=self.pos.copy()) + def skip_comment(self) -> None: multi_line = False self.advance() diff --git a/core/nodes.py b/core/nodes.py index eae9d8b..7bf6959 100755 --- a/core/nodes.py +++ b/core/nodes.py @@ -229,6 +229,7 @@ def __init__(self, condition_node: Node, body_node: Node, should_return_null: bo self.pos_end = self.body_node.pos_end +@dataclass class FuncDefNode: var_name_tok: Optional[Token] arg_name_toks: list[Token] @@ -237,37 +238,11 @@ class FuncDefNode: should_auto_return: bool static: bool desc: str + va_name: Optional[str] pos_start: Position pos_end: Position - def __init__( - self, - var_name_tok: Optional[Token], - arg_name_toks: list[Token], - defaults: list[Optional[Node]], - body_node: Node, - should_auto_return: bool, - static: bool = False, - desc: str = "", - ) -> None: - self.var_name_tok = var_name_tok - self.arg_name_toks = arg_name_toks - self.defaults = defaults - self.body_node = body_node - self.should_auto_return = should_auto_return - self.static = static - self.desc = desc - - if self.var_name_tok: - self.pos_start = self.var_name_tok.pos_start - elif len(self.arg_name_toks) > 0: - self.pos_start = self.arg_name_toks[0].pos_start - else: - self.pos_start = self.body_node.pos_start - - self.pos_end = self.body_node.pos_end - class CallNode: node_to_call: Node diff --git a/core/parser.py b/core/parser.py index 4f1aae1..8599ace 100755 --- a/core/parser.py +++ b/core/parser.py @@ -1197,6 +1197,8 @@ def class_node(self) -> ParseResult[Node]: def func_def(self) -> ParseResult[Node]: res = ParseResult[Node]() + node_pos_start = self.current_tok.pos_start + static = False if self.current_tok.matches(TT_KEYWORD, "static"): self.advance(res) @@ -1229,24 +1231,36 @@ def func_def(self) -> ParseResult[Node]: self.advance(res) arg_name_toks = [] defaults: list[Optional[Node]] = [] - hasOptionals = False + has_optionals = False + is_va = False + va_name: Optional[str] = None + + if self.current_tok.type == TT_SPREAD: + is_va = True + self.advance(res) if self.current_tok.type == TT_IDENTIFIER: pos_start = self.current_tok.pos_start.copy() pos_end = self.current_tok.pos_end.copy() - arg_name_toks.append(self.current_tok) + arg_name_tok = self.current_tok + assert isinstance(arg_name_tok.value, str) self.advance(res) + if not is_va: + arg_name_toks.append(arg_name_tok) - if self.current_tok.type == TT_EQ: + if is_va: + va_name = arg_name_tok.value + is_va = False + elif self.current_tok.type == TT_EQ: self.advance(res) default = res.register(self.expr()) if res.error: return res assert default is not None defaults.append(default) - hasOptionals = True - elif hasOptionals: + has_optionals = True + elif has_optionals: return res.failure(InvalidSyntaxError(pos_start, pos_end, "Expected optional parameter.")) else: defaults.append(None) @@ -1254,6 +1268,10 @@ def func_def(self) -> ParseResult[Node]: while self.current_tok.type == TT_COMMA: self.advance(res) + if self.current_tok.type == TT_SPREAD: + is_va = True + self.advance(res) + if self.current_tok.type != TT_IDENTIFIER: return res.failure( InvalidSyntaxError(self.current_tok.pos_start, self.current_tok.pos_end, "Expected identifier") @@ -1261,18 +1279,25 @@ def func_def(self) -> ParseResult[Node]: pos_start = self.current_tok.pos_start.copy() pos_end = self.current_tok.pos_end.copy() - arg_name_toks.append(self.current_tok) + + arg_name_tok = self.current_tok + assert isinstance(arg_name_tok.value, str) + if not is_va: + arg_name_toks.append(arg_name_tok) self.advance(res) - if self.current_tok.type == TT_EQ: + if is_va: + va_name = arg_name_tok.value + is_va = False + elif self.current_tok.type == TT_EQ: self.advance(res) default = res.register(self.expr()) if res.error: return res assert default is not None defaults.append(default) - hasOptionals = True - elif hasOptionals: + has_optionals = True + elif has_optionals: return res.failure(InvalidSyntaxError(pos_start, pos_end, "Expected optional parameter.")) else: defaults.append(None) @@ -1301,7 +1326,18 @@ def func_def(self) -> ParseResult[Node]: assert body is not None return res.success( - FuncDefNode(var_name_tok, arg_name_toks, defaults, body, True, static=static, desc="[No Description]") + FuncDefNode( + var_name_tok, + arg_name_toks, + defaults, + body, + True, + static=static, + desc="[No Description]", + va_name=va_name, + pos_start=node_pos_start, + pos_end=self.current_tok.pos_end, + ) ) self.skip_newlines() @@ -1329,7 +1365,20 @@ def func_def(self) -> ParseResult[Node]: self.advance(res) - return res.success(FuncDefNode(var_name_tok, arg_name_toks, defaults, body, False, static=static, desc=desc)) + return res.success( + FuncDefNode( + var_name_tok, + arg_name_toks, + defaults, + body, + False, + static=static, + desc=desc, + va_name=va_name, + pos_start=node_pos_start, + pos_end=self.current_tok.pos_end, + ) + ) def switch_statement(self) -> ParseResult[Node]: res = ParseResult[Node]() diff --git a/core/tokens.py b/core/tokens.py index 2e65512..2208a2e 100755 --- a/core/tokens.py +++ b/core/tokens.py @@ -101,6 +101,7 @@ def copy(self) -> Position: TT_SLICE = TokenType("SLICE") # x[1:2:3] TT_PLUS_PLUS = TokenType("PLUS_PLUS") # ++ TT_MINUS_MINUS = TokenType("MINUS_MINUS") # -- +TT_SPREAD = TokenType("SPREAD") # ... KEYWORDS = [ "and", diff --git a/tests/varargs.rn b/tests/varargs.rn new file mode 100644 index 0000000..a036869 --- /dev/null +++ b/tests/varargs.rn @@ -0,0 +1,20 @@ + +fun f(...args) { + print(args) +} + +fun g(a, ...args) { + print(a) + print(args) +} + +f() +f(1, 2, 3) +f("hello", "world", "!") +f("a", "b", "c") + +g(1) +g(1, 2, 3) +g("hello", "world", "!") +g("a", "b", "c") + diff --git a/tests/varargs.rn.json b/tests/varargs.rn.json new file mode 100644 index 0000000..0f0a369 --- /dev/null +++ b/tests/varargs.rn.json @@ -0,0 +1 @@ +{"code": 0, "stdout": "[]\n[1, 2, 3]\n[\"hello\", \"world\", \"!\"]\n[\"a\", \"b\", \"c\"]\n1\n[]\n1\n[2, 3]\nhello\n[\"world\", \"!\"]\na\n[\"b\", \"c\"]\n", "stderr": ""} \ No newline at end of file