From 248bee49c55a63ddb5662f8b232c6a3c4890ac37 Mon Sep 17 00:00:00 2001 From: Oleksii-Klimov <133792808+Oleksii-Klimov@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:29:09 +0000 Subject: [PATCH] feat: read json lazily to tolerate incorrect closing brackets (#37) * Read json lazily to tolerate incorrect closing brackets that models occasionally generate. --- aidial_assistant/application/prompts.py | 2 +- aidial_assistant/chain/command_chain.py | 62 +++---- .../chain/model_response_reader.py | 6 +- ...racterstream.py => chunked_char_stream.py} | 29 ++-- aidial_assistant/json_stream/exceptions.py | 18 +- aidial_assistant/json_stream/json_array.py | 134 +++++++-------- aidial_assistant/json_stream/json_bool.py | 33 ++-- aidial_assistant/json_stream/json_node.py | 92 ++++++---- .../json_stream/json_normalizer.py | 30 ---- aidial_assistant/json_stream/json_null.py | 20 +-- aidial_assistant/json_stream/json_number.py | 29 ++-- aidial_assistant/json_stream/json_object.py | 157 +++++++++--------- aidial_assistant/json_stream/json_parser.py | 80 ++++----- aidial_assistant/json_stream/json_root.py | 87 ---------- aidial_assistant/json_stream/json_string.py | 108 ++++++------ .../json_stream/test_json_stream.py | 57 ++++++- 16 files changed, 462 insertions(+), 482 deletions(-) rename aidial_assistant/json_stream/{characterstream.py => chunked_char_stream.py} (77%) delete mode 100644 aidial_assistant/json_stream/json_normalizer.py delete mode 100644 aidial_assistant/json_stream/json_root.py diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index 9a49485..194a13f 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -35,7 +35,7 @@ def build(self, **kwargs) -> Template: { "command": "", "args": [ - // + "", "", ... ] } ] diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index c5e54f9..10041f8 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -29,7 +29,7 @@ skip_to_json_start, ) from aidial_assistant.commands.base import Command, FinalCommand -from aidial_assistant.json_stream.characterstream import CharacterStream +from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream from aidial_assistant.json_stream.exceptions import JsonParsingException from aidial_assistant.json_stream.json_node import JsonNode from aidial_assistant.json_stream.json_parser import JsonParser @@ -166,39 +166,39 @@ async def _run_with_protocol_failure_retries( async def _run_commands( self, chunk_stream: AsyncIterator[str], callback: ChainCallback ) -> Tuple[list[CommandInvocation], list[CommandResult]]: - char_stream = CharacterStream(chunk_stream) + char_stream = ChunkedCharStream(chunk_stream) await skip_to_json_start(char_stream) - async with JsonParser.parse(char_stream) as root_node: - commands: list[CommandInvocation] = [] - responses: list[CommandResult] = [] - request_reader = CommandsReader(root_node) - async for invocation in request_reader.parse_invocations(): - command_name = await invocation.parse_name() - command = self._create_command(command_name) - args = invocation.parse_args() - if isinstance(command, FinalCommand): - if len(responses) > 0: - continue - message = await anext(args) - await CommandChain._to_result( - message - if isinstance(message, JsonString) - else message.to_string_chunks(), - callback.result_callback(), - ) - break - else: - response = await CommandChain._execute_command( - command_name, command, args, callback - ) + root_node = await JsonParser().parse(char_stream) + commands: list[CommandInvocation] = [] + responses: list[CommandResult] = [] + request_reader = CommandsReader(root_node) + async for invocation in request_reader.parse_invocations(): + command_name = await invocation.parse_name() + command = self._create_command(command_name) + args = invocation.parse_args() + if isinstance(command, FinalCommand): + if len(responses) > 0: + continue + message = await anext(args) + await CommandChain._to_result( + message + if isinstance(message, JsonString) + else message.to_chunks(), + callback.result_callback(), + ) + break + else: + response = await CommandChain._execute_command( + command_name, command, args, callback + ) - commands.append( - cast(CommandInvocation, invocation.node.value()) - ) - responses.append(response) + commands.append( + cast(CommandInvocation, invocation.node.value()) + ) + responses.append(response) - return commands, responses + return commands, responses def _create_command(self, name: str) -> Command: if name not in self.command_dict: @@ -237,7 +237,7 @@ async def _to_args( arg_callback = args_callback.arg_callback() arg_callback.on_arg_start() result = "" - async for chunk in arg.to_string_chunks(): + async for chunk in arg.to_chunks(): arg_callback.on_arg(chunk) result += chunk arg_callback.on_arg_end() diff --git a/aidial_assistant/chain/model_response_reader.py b/aidial_assistant/chain/model_response_reader.py index 464a22e..dffe11f 100644 --- a/aidial_assistant/chain/model_response_reader.py +++ b/aidial_assistant/chain/model_response_reader.py @@ -1,6 +1,6 @@ from collections.abc import AsyncIterator -from aidial_assistant.json_stream.characterstream import AsyncPeekable +from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream from aidial_assistant.json_stream.json_array import JsonArray from aidial_assistant.json_stream.json_node import JsonNode from aidial_assistant.json_stream.json_object import JsonObject @@ -16,12 +16,12 @@ class AssistantProtocolException(Exception): pass -async def skip_to_json_start(stream: AsyncPeekable[str]): +async def skip_to_json_start(stream: ChunkedCharStream): # Some models tend to provide explanations for their replies regardless of what the prompt says. try: while True: char = await stream.apeek() - if char == JsonObject.token(): + if JsonObject.starts_with(char): break await anext(stream) diff --git a/aidial_assistant/json_stream/characterstream.py b/aidial_assistant/json_stream/chunked_char_stream.py similarity index 77% rename from aidial_assistant/json_stream/characterstream.py rename to aidial_assistant/json_stream/chunked_char_stream.py index 9b0ffc8..0553436 100644 --- a/aidial_assistant/json_stream/characterstream.py +++ b/aidial_assistant/json_stream/chunked_char_stream.py @@ -1,22 +1,10 @@ -from abc import ABC, abstractmethod +from abc import ABC from collections.abc import AsyncIterator -from typing import Generic, TypeVar from typing_extensions import override -T = TypeVar("T") - -class AsyncPeekable(ABC, Generic[T], AsyncIterator[T]): - @abstractmethod - async def apeek(self) -> T: - pass - - async def askip(self) -> None: - await anext(self) - - -class CharacterStream(AsyncPeekable[str]): +class ChunkedCharStream(ABC, AsyncIterator[str]): def __init__(self, source: AsyncIterator[str]): self._source = source self._chunk: str = "" @@ -33,7 +21,6 @@ async def __anext__(self) -> str: self._next_char_offset += 1 return result - @override async def apeek(self) -> str: while self._next_char_offset == len(self._chunk): self._chunk_position += len(self._chunk) @@ -41,6 +28,9 @@ async def apeek(self) -> str: self._next_char_offset = 0 return self._chunk[self._next_char_offset] + async def askip(self): + await anext(self) + @property def chunk_position(self) -> int: return self._chunk_position @@ -48,3 +38,12 @@ def chunk_position(self) -> int: @property def char_position(self) -> int: return self._chunk_position + self._next_char_offset + + +async def skip_whitespaces(stream: ChunkedCharStream): + while True: + char = await stream.apeek() + if not str.isspace(char): + break + + await stream.askip() diff --git a/aidial_assistant/json_stream/exceptions.py b/aidial_assistant/json_stream/exceptions.py index e4675fc..7e0aeb6 100644 --- a/aidial_assistant/json_stream/exceptions.py +++ b/aidial_assistant/json_stream/exceptions.py @@ -1,10 +1,24 @@ class JsonParsingException(Exception): - pass + def __init__(self, message: str, char_position: int): + super().__init__( + f"Failed to parse json string at position {char_position}: {message}" + ) def unexpected_symbol_error( char: str, char_position: int +) -> JsonParsingException: + return JsonParsingException(f"Unexpected symbol '{char}'.", char_position) + + +def unexpected_end_of_stream_error(char_position: int) -> JsonParsingException: + return JsonParsingException("Unexpected end of stream.", char_position) + + +def invalid_sequence_error( + expected_type: str, string: str, char_position: int ) -> JsonParsingException: return JsonParsingException( - f"Failed to parse json string: unexpected symbol {char} at position {char_position}" + f"Can't parse {expected_type} from the string '{string}'.", + char_position, ) diff --git a/aidial_assistant/json_stream/json_array.py b/aidial_assistant/json_stream/json_array.py index a206916..16a25a4 100644 --- a/aidial_assistant/json_stream/json_array.py +++ b/aidial_assistant/json_stream/json_array.py @@ -1,90 +1,94 @@ -from asyncio import Queue from collections.abc import AsyncIterator from typing import Any from typing_extensions import override -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.exceptions import unexpected_symbol_error +from aidial_assistant.json_stream.chunked_char_stream import ( + ChunkedCharStream, + skip_whitespaces, +) +from aidial_assistant.json_stream.exceptions import ( + unexpected_end_of_stream_error, + unexpected_symbol_error, +) from aidial_assistant.json_stream.json_node import ( - ComplexNode, + CompoundNode, JsonNode, - NodeResolver, + NodeParser, ) -from aidial_assistant.json_stream.json_normalizer import JsonNormalizer -class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]): - def __init__(self, char_position: int): - super().__init__(char_position) - self.listener = Queue[JsonNode | None]() - self.array: list[JsonNode] = [] +class JsonArray(CompoundNode[list[Any], JsonNode]): + def __init__(self, source: AsyncIterator[JsonNode], pos: int): + super().__init__(source, pos) + self._array: list[JsonNode] = [] @override def type(self) -> str: return "array" @staticmethod - def token() -> str: - return "[" - - @override - def __aiter__(self) -> AsyncIterator[JsonNode]: - return self - - @override - async def __anext__(self) -> JsonNode: - result = await self.listener.get() - if result is None: - raise StopAsyncIteration - - self.array.append(result) - return result - - @override - async def parse( - self, stream: CharacterStream, dependency_resolver: NodeResolver - ): - normalised_stream = JsonNormalizer(stream) - char = await anext(normalised_stream) - self._char_position = stream.char_position - if not char == JsonArray.token(): - raise unexpected_symbol_error(char, stream.char_position) - - separate = False - while True: - char = await normalised_stream.apeek() - if char == "]": - await anext(normalised_stream) - break - - if char == ",": - if not separate: - raise unexpected_symbol_error(char, stream.char_position) - - await anext(normalised_stream) - separate = False - else: - value = await dependency_resolver.resolve(stream) - await self.listener.put(value) - if isinstance(value, ComplexNode): - await value.parse(stream, dependency_resolver) - separate = True - - await self.listener.put(None) + async def read( + stream: ChunkedCharStream, node_parser: NodeParser + ) -> AsyncIterator[JsonNode]: + try: + await skip_whitespaces(stream) + char = await anext(stream) + if not JsonArray.starts_with(char): + raise unexpected_symbol_error(char, stream.char_position) + + is_comma_expected = False + while True: + await skip_whitespaces(stream) + char = await stream.apeek() + if char == "]": + await stream.askip() + break + + if char == ",": + if not is_comma_expected: + raise unexpected_symbol_error( + char, stream.char_position + ) + + await stream.askip() + is_comma_expected = False + else: + value = await node_parser.parse(stream) + yield value + + if isinstance(value, CompoundNode): + await value.read_to_end() + is_comma_expected = True + except StopAsyncIteration: + raise unexpected_end_of_stream_error(stream.char_position) @override - async def to_string_chunks(self) -> AsyncIterator[str]: - yield JsonArray.token() - separate = False + async def to_chunks(self) -> AsyncIterator[str]: + yield "[" + is_first_element = True async for value in self: - if separate: + if not is_first_element: yield ", " - async for chunk in value.to_string_chunks(): + async for chunk in value.to_chunks(): yield chunk - separate = True + is_first_element = False yield "]" @override def value(self) -> list[JsonNode]: - return [item.value() for item in self.array] + return [item.value() for item in self._array] + + @override + def _accumulate(self, element: JsonNode): + self._array.append(element) + + @classmethod + def parse( + cls, stream: ChunkedCharStream, node_parser: NodeParser + ) -> "JsonArray": + return cls(JsonArray.read(stream, node_parser), stream.char_position) + + @staticmethod + def starts_with(char: str) -> bool: + return char == "[" diff --git a/aidial_assistant/json_stream/json_bool.py b/aidial_assistant/json_stream/json_bool.py index 1b32f74..cff7adb 100644 --- a/aidial_assistant/json_stream/json_bool.py +++ b/aidial_assistant/json_stream/json_bool.py @@ -1,31 +1,36 @@ -import json - from typing_extensions import override -from aidial_assistant.json_stream.json_node import PrimitiveNode +from aidial_assistant.json_stream.exceptions import invalid_sequence_error +from aidial_assistant.json_stream.json_node import AtomicNode TRUE_STRING = "true" FALSE_STRING = "false" +TYPE_STRING = "boolean" -class JsonBoolean(PrimitiveNode[bool]): - def __init__(self, raw_data: str, char_position: int): - super().__init__(char_position) - self._raw_data = raw_data - self._value: bool = json.loads(raw_data) +class JsonBoolean(AtomicNode[bool]): + def __init__(self, raw_data: str, pos: int): + super().__init__(raw_data, pos) + self._value: bool = JsonBoolean._parse_boolean(raw_data, pos) @override def type(self) -> str: - return "boolean" - - @override - def raw_data(self) -> str: - return self._raw_data + return TYPE_STRING @override def value(self) -> bool: return self._value @staticmethod - def is_bool(char: str) -> bool: + def starts_with(char: str) -> bool: return char == "t" or char == "f" + + @staticmethod + def _parse_boolean(string: str, char_position: int) -> bool: + if string == TRUE_STRING: + return True + + if string == FALSE_STRING: + return False + + raise invalid_sequence_error(TYPE_STRING, string, char_position) diff --git a/aidial_assistant/json_stream/json_node.py b/aidial_assistant/json_stream/json_node.py index 922b60f..47ed66a 100644 --- a/aidial_assistant/json_stream/json_node.py +++ b/aidial_assistant/json_stream/json_node.py @@ -4,66 +4,94 @@ from typing_extensions import override -from aidial_assistant.json_stream.characterstream import CharacterStream +from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream +from aidial_assistant.json_stream.exceptions import ( + unexpected_end_of_stream_error, +) -class NodeResolver(ABC): +class NodeParser(ABC): @abstractmethod - async def resolve(self, stream: CharacterStream) -> "JsonNode": + async def parse(self, stream: ChunkedCharStream) -> "JsonNode": pass -T = TypeVar("T") +TValue = TypeVar("TValue") +TElement = TypeVar("TElement") -class JsonNode(ABC, Generic[T]): - def __init__(self, char_position: int): - self._char_position = char_position +class JsonNode(ABC, Generic[TValue]): + def __init__(self, pos: int): + self._pos = pos @abstractmethod def type(self) -> str: pass @abstractmethod - def to_string_chunks(self) -> AsyncIterator[str]: + def to_chunks(self) -> AsyncIterator[str]: pass @property - def char_position(self) -> int: - return self._char_position + def pos(self) -> int: + return self._pos @abstractmethod - def value(self) -> T: + def value(self) -> TValue: pass -class ComplexNode(JsonNode[T], ABC, Generic[T]): - def __init__(self, char_position: int): - super().__init__(char_position) +class CompoundNode( + JsonNode[TValue], AsyncIterator[TElement], ABC, Generic[TValue, TElement] +): + def __init__(self, source: AsyncIterator[TElement], pos: int): + super().__init__(pos) + self._source = source - @abstractmethod - async def parse( - self, stream: CharacterStream, dependency_resolver: NodeResolver - ): - pass + @override + def __aiter__(self) -> AsyncIterator[TElement]: + return self + + @override + async def __anext__(self) -> TElement: + result = await anext(self._source) + self._accumulate(result) + return result -class PrimitiveNode(JsonNode[T], ABC, Generic[T]): @abstractmethod - def raw_data(self) -> str: + def _accumulate(self, element: TElement): pass + async def read_to_end(self): + async for _ in self: + pass + + +class AtomicNode(JsonNode[TValue], ABC, Generic[TValue]): + def __init__(self, raw_data: str, pos: int): + super().__init__(pos) + self._raw_data = raw_data + @override - async def to_string_chunks(self) -> AsyncIterator[str]: - yield self.raw_data() + async def to_chunks(self) -> AsyncIterator[str]: + yield self._raw_data + + @classmethod + async def parse(cls, stream: ChunkedCharStream) -> "AtomicNode": + position = stream.char_position + return cls(await AtomicNode._read_all(stream), position) @staticmethod - async def collect(stream: CharacterStream) -> str: - raw_data = "" - while True: - char = await stream.apeek() - if char.isspace() or char in ",:[]{}": - return raw_data - else: - raw_data += char - await stream.askip() + async def _read_all(stream: ChunkedCharStream) -> str: + try: + raw_data = "" + while True: + char = await stream.apeek() + if char.isspace() or char in ",:[]{}": + return raw_data + else: + raw_data += char + await stream.askip() + except StopAsyncIteration: + raise unexpected_end_of_stream_error(stream.char_position) diff --git a/aidial_assistant/json_stream/json_normalizer.py b/aidial_assistant/json_stream/json_normalizer.py deleted file mode 100644 index 0faf70f..0000000 --- a/aidial_assistant/json_stream/json_normalizer.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing_extensions import override - -from aidial_assistant.json_stream.characterstream import ( - AsyncPeekable, - CharacterStream, -) - - -class JsonNormalizer(AsyncPeekable[str]): - def __init__(self, stream: CharacterStream): - self.stream = stream - - @override - def __aiter__(self): - return self - - @override - async def __anext__(self) -> str: - await self.apeek() - return await anext(self.stream) - - @override - async def apeek(self) -> str: - while True: - char = await self.stream.apeek() - if str.isspace(char): - await anext(self.stream) - continue - else: - return char diff --git a/aidial_assistant/json_stream/json_null.py b/aidial_assistant/json_stream/json_null.py index 94757ab..6d569eb 100644 --- a/aidial_assistant/json_stream/json_null.py +++ b/aidial_assistant/json_stream/json_null.py @@ -1,23 +1,19 @@ from typing_extensions import override -from aidial_assistant.json_stream.exceptions import unexpected_symbol_error -from aidial_assistant.json_stream.json_node import PrimitiveNode +from aidial_assistant.json_stream.exceptions import invalid_sequence_error +from aidial_assistant.json_stream.json_node import AtomicNode NULL_STRING = "null" -class JsonNull(PrimitiveNode[None]): - def __init__(self, raw_data: str, char_position: int): - super().__init__(char_position) - if raw_data != "null": - raise unexpected_symbol_error(raw_data, char_position) +class JsonNull(AtomicNode[None]): + def __init__(self, raw_data: str, pos: int): + super().__init__(raw_data, pos) + if raw_data != NULL_STRING: + raise invalid_sequence_error(NULL_STRING, raw_data, pos) @override def type(self) -> str: - return "null" - - @override - def raw_data(self) -> str: return NULL_STRING @override @@ -25,5 +21,5 @@ def value(self) -> None: return None @staticmethod - def is_null(char: str) -> bool: + def starts_with(char: str) -> bool: return char == "n" diff --git a/aidial_assistant/json_stream/json_number.py b/aidial_assistant/json_stream/json_number.py index 8404452..3ac126b 100644 --- a/aidial_assistant/json_stream/json_number.py +++ b/aidial_assistant/json_stream/json_number.py @@ -2,27 +2,32 @@ from typing_extensions import override -from aidial_assistant.json_stream.json_node import PrimitiveNode +from aidial_assistant.json_stream.exceptions import invalid_sequence_error +from aidial_assistant.json_stream.json_node import AtomicNode +TYPE_STRING = "number" -class JsonNumber(PrimitiveNode[float | int]): - def __init__(self, raw_data: str, char_position: int): - super().__init__(char_position) - self._raw_data = raw_data - self._value: float | int = json.loads(raw_data) - @override - def type(self) -> str: - return "number" +class JsonNumber(AtomicNode[float | int]): + def __init__(self, raw_data: str, pos: int): + super().__init__(raw_data, pos) + self._value: float | int = JsonNumber._parse_number(raw_data, pos) @override - def raw_data(self) -> str: - return self._raw_data + def type(self) -> str: + return TYPE_STRING @override def value(self) -> float | int: return self._value @staticmethod - def is_number(char: str) -> bool: + def starts_with(char: str) -> bool: return char.isdigit() or char == "-" + + @staticmethod + def _parse_number(string: str, char_position: int) -> float | int: + try: + return json.loads(string) + except json.JSONDecodeError: + raise invalid_sequence_error(TYPE_STRING, string, char_position) diff --git a/aidial_assistant/json_stream/json_object.py b/aidial_assistant/json_stream/json_object.py index 4ffff1c..ac84ef6 100644 --- a/aidial_assistant/json_stream/json_object.py +++ b/aidial_assistant/json_stream/json_object.py @@ -1,50 +1,35 @@ import json -from asyncio import Queue from collections.abc import AsyncIterator from typing import Any, Tuple from typing_extensions import override -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.exceptions import unexpected_symbol_error +from aidial_assistant.json_stream.chunked_char_stream import ( + ChunkedCharStream, + skip_whitespaces, +) +from aidial_assistant.json_stream.exceptions import ( + unexpected_end_of_stream_error, + unexpected_symbol_error, +) from aidial_assistant.json_stream.json_node import ( - ComplexNode, + CompoundNode, JsonNode, - NodeResolver, + NodeParser, ) -from aidial_assistant.json_stream.json_normalizer import JsonNormalizer from aidial_assistant.json_stream.json_string import JsonString from aidial_assistant.utils.text import join_string -class JsonObject( - ComplexNode[dict[str, Any]], AsyncIterator[Tuple[str, JsonNode]] -): - def __init__(self, char_position: int): - super().__init__(char_position) - self.listener = Queue[Tuple[str, JsonNode] | None]() - self._object: dict[str, JsonNode] = {} +class JsonObject(CompoundNode[dict[str, Any], Tuple[str, JsonNode]]): + def __init__(self, source: AsyncIterator[Tuple[str, JsonNode]], pos: int): + super().__init__(source, pos) + self._object = {} @override def type(self) -> str: return "object" - def __aiter__(self) -> AsyncIterator[Tuple[str, JsonNode]]: - return self - - @override - async def __anext__(self) -> Tuple[str, JsonNode]: - result = await self.listener.get() - if result is None: - raise StopAsyncIteration - - self._object[result[0]] = result[1] - return result - - @staticmethod - def token() -> str: - return "{" - async def get(self, key: str) -> JsonNode: if key in self._object.keys(): return self._object[key] @@ -55,62 +40,86 @@ async def get(self, key: str) -> JsonNode: raise KeyError(key) - @override - async def parse( - self, stream: CharacterStream, dependency_resolver: NodeResolver - ): - normalised_stream = JsonNormalizer(stream) - char = await anext(normalised_stream) - if not char == JsonObject.token(): - raise unexpected_symbol_error(char, stream.char_position) - - separate = False - while True: - char = await normalised_stream.apeek() - - if char == "}": - await normalised_stream.askip() - break - - if char == ",": - if not separate: - raise unexpected_symbol_error(char, stream.char_position) - - await normalised_stream.askip() - separate = False - elif char == '"': - if separate: - raise unexpected_symbol_error(char, stream.char_position) - - key = await join_string(JsonString.read(stream)) - colon = await anext(normalised_stream) - if not colon == ":": - raise unexpected_symbol_error(colon, stream.char_position) - - value = await dependency_resolver.resolve(stream) - await self.listener.put((key, value)) - if isinstance(value, ComplexNode): - await value.parse(stream, dependency_resolver) - separate = True - else: + @staticmethod + async def read( + stream: ChunkedCharStream, node_parser: NodeParser + ) -> AsyncIterator[Tuple[str, JsonNode]]: + try: + await skip_whitespaces(stream) + char = await anext(stream) + if not JsonObject.starts_with(char): raise unexpected_symbol_error(char, stream.char_position) - await self.listener.put(None) + is_comma_expected = False + while True: + await skip_whitespaces(stream) + char = await stream.apeek() + + if char == "}": + await stream.askip() + break + + if char == ",": + if not is_comma_expected: + raise unexpected_symbol_error( + char, stream.char_position + ) + + await stream.askip() + is_comma_expected = False + elif JsonString.starts_with(char): + if is_comma_expected: + raise unexpected_symbol_error( + char, stream.char_position + ) + + key = await join_string(JsonString.read(stream)) + await skip_whitespaces(stream) + colon = await anext(stream) + if not colon == ":": + raise unexpected_symbol_error( + colon, stream.char_position + ) + + value = await node_parser.parse(stream) + yield key, value + + if isinstance(value, CompoundNode): + await value.read_to_end() + is_comma_expected = True + else: + raise unexpected_symbol_error(char, stream.char_position) + except StopAsyncIteration: + raise unexpected_end_of_stream_error(stream.char_position) @override - async def to_string_chunks(self) -> AsyncIterator[str]: - yield JsonObject.token() - separate = False + async def to_chunks(self) -> AsyncIterator[str]: + yield "{" + is_first_entry = True async for key, value in self: - if separate: + if not is_first_entry: yield ", " yield json.dumps(key) yield ": " - async for chunk in value.to_string_chunks(): + async for chunk in value.to_chunks(): yield chunk - separate = True + is_first_entry = False yield "}" @override def value(self) -> dict[str, Any]: return {k: v.value() for k, v in self._object.items()} + + @override + def _accumulate(self, element: Tuple[str, JsonNode]): + self._object[element[0]] = element[1] + + @classmethod + def parse( + cls, stream: ChunkedCharStream, node_parser: NodeParser + ) -> "JsonObject": + return cls(JsonObject.read(stream, node_parser), stream.char_position) + + @staticmethod + def starts_with(char: str) -> bool: + return char == "{" diff --git a/aidial_assistant/json_stream/json_parser.py b/aidial_assistant/json_stream/json_parser.py index 6e33c6a..5d8dcb1 100644 --- a/aidial_assistant/json_stream/json_parser.py +++ b/aidial_assistant/json_stream/json_parser.py @@ -1,20 +1,26 @@ -from asyncio import TaskGroup -from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator +from typing_extensions import override -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.exceptions import JsonParsingException +from aidial_assistant.json_stream.chunked_char_stream import ( + ChunkedCharStream, + skip_whitespaces, +) +from aidial_assistant.json_stream.exceptions import ( + unexpected_end_of_stream_error, + unexpected_symbol_error, +) from aidial_assistant.json_stream.json_array import JsonArray -from aidial_assistant.json_stream.json_node import ComplexNode, JsonNode +from aidial_assistant.json_stream.json_bool import JsonBoolean +from aidial_assistant.json_stream.json_node import JsonNode, NodeParser +from aidial_assistant.json_stream.json_null import JsonNull +from aidial_assistant.json_stream.json_number import JsonNumber from aidial_assistant.json_stream.json_object import JsonObject -from aidial_assistant.json_stream.json_root import JsonRoot, RootNodeResolver from aidial_assistant.json_stream.json_string import JsonString def array_node(node: JsonNode) -> JsonArray: if not isinstance(node, JsonArray): raise TypeError( - f"Expected json array at position {node.char_position}, got {node.type}" + f"Expected json array at position {node.pos}, got {node.type()}" ) return node @@ -23,7 +29,7 @@ def array_node(node: JsonNode) -> JsonArray: def object_node(node: JsonNode) -> JsonObject: if not isinstance(node, JsonObject): raise TypeError( - f"Expected json object at position {node.char_position}, got {node.type}" + f"Expected json object at position {node.pos}, got {node.type()}" ) return node @@ -32,40 +38,36 @@ def object_node(node: JsonNode) -> JsonObject: def string_node(node: JsonNode) -> JsonString: if not isinstance(node, JsonString): raise TypeError( - f"Expected json string at position {node.char_position}, got {node.type}" + f"Expected json string at position {node.pos}, got {node.type()}" ) return node -class JsonParser: - @staticmethod - @asynccontextmanager - async def parse(stream: CharacterStream) -> AsyncGenerator[JsonNode, Any]: - root = JsonRoot() +class JsonParser(NodeParser): + @override + async def parse(self, stream: ChunkedCharStream) -> JsonNode: try: - async with TaskGroup() as tg: - task = tg.create_task(JsonParser._parse_root(root, stream)) - try: - yield await root.node() - finally: - await task - except ExceptionGroup as e: - raise e.exceptions[0] - - @staticmethod - async def _parse_root(root: JsonRoot, stream: CharacterStream): - try: - node_resolver = RootNodeResolver() - await root.parse(stream, node_resolver) - node = await root.node() - if isinstance(node, ComplexNode): - await node.parse(stream, node_resolver) + await skip_whitespaces(stream) + char = await stream.apeek() + if JsonObject.starts_with(char): + return JsonObject.parse(stream, self) + + if JsonArray.starts_with(char): + return JsonArray.parse(stream, self) + + if JsonString.starts_with(char): + return JsonString.parse(stream) + + if JsonNumber.starts_with(char): + return await JsonNumber.parse(stream) + + if JsonNull.starts_with(char): + return await JsonNull.parse(stream) + + if JsonBoolean.starts_with(char): + return await JsonBoolean.parse(stream) except StopAsyncIteration: - raise JsonParsingException( - "Failed to parse json: unexpected end of stream." - ) - finally: - # flush the stream - async for _ in stream: - pass + raise unexpected_end_of_stream_error(stream.char_position) + + raise unexpected_symbol_error(char, stream.char_position) diff --git a/aidial_assistant/json_stream/json_root.py b/aidial_assistant/json_stream/json_root.py deleted file mode 100644 index c155d14..0000000 --- a/aidial_assistant/json_stream/json_root.py +++ /dev/null @@ -1,87 +0,0 @@ -import asyncio -from typing import Any, AsyncIterator - -from typing_extensions import override - -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.exceptions import unexpected_symbol_error -from aidial_assistant.json_stream.json_array import JsonArray -from aidial_assistant.json_stream.json_bool import JsonBoolean -from aidial_assistant.json_stream.json_node import ( - ComplexNode, - JsonNode, - NodeResolver, - PrimitiveNode, -) -from aidial_assistant.json_stream.json_normalizer import JsonNormalizer -from aidial_assistant.json_stream.json_null import JsonNull -from aidial_assistant.json_stream.json_number import JsonNumber -from aidial_assistant.json_stream.json_object import JsonObject -from aidial_assistant.json_stream.json_string import JsonString - - -class RootNodeResolver(NodeResolver): - @override - async def resolve(self, stream: CharacterStream) -> JsonNode: - normalised_stream = JsonNormalizer(stream) - char = await normalised_stream.apeek() - if char == JsonObject.token(): - return JsonObject(stream.char_position) - - if char == JsonString.token(): - return JsonString(stream.char_position) - - if char == JsonArray.token(): - return JsonArray(stream.char_position) - - if JsonNumber.is_number(char): - position = stream.char_position - return JsonNumber(await PrimitiveNode.collect(stream), position) - - if JsonNull.is_null(char): - position = stream.char_position - return JsonNull(await PrimitiveNode.collect(stream), position) - - if JsonBoolean.is_bool(char): - position = stream.char_position - return JsonBoolean(await PrimitiveNode.collect(stream), position) - - raise unexpected_symbol_error(char, stream.char_position) - - -class JsonRoot(ComplexNode[Any]): - def __init__(self): - super().__init__(0) - self._node: JsonNode | None = None - self._event = asyncio.Event() - - async def node(self) -> JsonNode: - await self._event.wait() - if self._node is None: - # Should never happen - raise Exception("Node was not parsed") - - return self._node - - @override - def type(self) -> str: - return "root" - - @override - async def parse( - self, stream: CharacterStream, dependency_resolver: NodeResolver - ): - try: - self._node = await dependency_resolver.resolve(stream) - finally: - self._event.set() - - @override - async def to_string_chunks(self) -> AsyncIterator[str]: - node = await self.node() - async for token in node.to_string_chunks(): - yield token - - @override - def value(self) -> Any: - return self._node.value() if self._node else None diff --git a/aidial_assistant/json_stream/json_string.py b/aidial_assistant/json_stream/json_string.py index e847e7a..0ff5401 100644 --- a/aidial_assistant/json_stream/json_string.py +++ b/aidial_assistant/json_stream/json_string.py @@ -1,82 +1,77 @@ import json -from asyncio import Queue from collections.abc import AsyncIterator from typing_extensions import override -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.exceptions import unexpected_symbol_error -from aidial_assistant.json_stream.json_node import ComplexNode, NodeResolver +from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream +from aidial_assistant.json_stream.exceptions import ( + JsonParsingException, + unexpected_end_of_stream_error, + unexpected_symbol_error, +) +from aidial_assistant.json_stream.json_node import CompoundNode -class JsonString(ComplexNode[str], AsyncIterator[str]): - def __init__(self, char_position: int): - super().__init__(char_position) - self._listener = Queue[str | None]() +class JsonString(CompoundNode[str, str]): + def __init__(self, source: AsyncIterator[str], pos: int): + super().__init__(source, pos) self._buffer = "" @override def type(self) -> str: return "string" - @staticmethod - def token() -> str: - return '"' - - def __aiter__(self) -> AsyncIterator[str]: - return self - @override - async def __anext__(self) -> str: - result = await self._listener.get() - if result is None: - raise StopAsyncIteration - - self._buffer += result - return result + def _accumulate(self, element: str): + self._buffer += element @override - async def parse( - self, stream: CharacterStream, dependency_resolver: NodeResolver - ): - async for chunk in JsonString.read(stream): - await self._listener.put(chunk) - await self._listener.put(None) - - @override - async def to_string_chunks(self) -> AsyncIterator[str]: - yield JsonString.token() + async def to_chunks(self) -> AsyncIterator[str]: + yield '"' async for chunk in self: yield json.dumps(chunk)[1:-1] - yield JsonString.token() + yield '"' + + @override + def value(self) -> str: + return self._buffer + + @classmethod + def parse(cls, stream: ChunkedCharStream) -> "JsonString": + return cls(JsonString.read(stream), stream.char_position) @staticmethod - async def read(stream: CharacterStream) -> AsyncIterator[str]: - char = await anext(stream) - if not char == JsonString.token(): - raise unexpected_symbol_error(char, stream.char_position) - result = "" - chunk_position = stream.chunk_position - while True: + async def read(stream: ChunkedCharStream) -> AsyncIterator[str]: + try: char = await anext(stream) - if char == JsonString.token(): - break - - result += await JsonString.escape(stream) if char == "\\" else char - if chunk_position != stream.chunk_position: - yield result - result = "" - chunk_position = stream.chunk_position + if not JsonString.starts_with(char): + raise unexpected_symbol_error(char, stream.char_position) + result = "" + chunk_position = stream.chunk_position + while True: + char = await anext(stream) + if char == '"': + break + + result += ( + await JsonString._escape(stream) if char == "\\" else char + ) + if chunk_position != stream.chunk_position: + yield result + result = "" + chunk_position = stream.chunk_position + except StopAsyncIteration: + raise unexpected_end_of_stream_error(stream.char_position) if result: yield result @staticmethod - async def escape(stream: CharacterStream) -> str: + async def _escape(stream: ChunkedCharStream) -> str: char = await anext(stream) if char == "u": unicode_sequence = "".join([await anext(stream) for _ in range(4)]) # type: ignore - return str(int(unicode_sequence, 16)) + return chr(int(unicode_sequence, 16)) if char in '"\\/': return char if char == "b": @@ -90,10 +85,11 @@ async def escape(stream: CharacterStream) -> str: elif char == "t": return "\t" else: - # Ignore when model cannot escape text properly - return char - # raise ValueError(f"Unexpected escape sequence: \\{char}" + " at " + str(stream.char_position - 1)) + raise JsonParsingException( + f"Unexpected escape sequence: \\{char}.", + stream.char_position - 1, + ) - @override - def value(self) -> str: - return self._buffer + @staticmethod + def starts_with(char: str) -> bool: + return char == '"' diff --git a/tests/unit_tests/json_stream/test_json_stream.py b/tests/unit_tests/json_stream/test_json_stream.py index 0fdacd3..6d3cfa2 100644 --- a/tests/unit_tests/json_stream/test_json_stream.py +++ b/tests/unit_tests/json_stream/test_json_stream.py @@ -4,8 +4,13 @@ import pytest -from aidial_assistant.json_stream.characterstream import CharacterStream -from aidial_assistant.json_stream.json_parser import JsonParser +from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream +from aidial_assistant.json_stream.exceptions import JsonParsingException +from aidial_assistant.json_stream.json_parser import ( + JsonParser, + object_node, + string_node, +) from aidial_assistant.utils.text import join_string JSON_STRINGS = [ @@ -81,7 +86,7 @@ """, """ { - "text": "Hello, 世界" + "text": "Hello, \\u4e16\\u754c" } """, """ @@ -148,10 +153,44 @@ async def _split_into_chunks(json_string: str) -> AsyncIterator[str]: @pytest.mark.asyncio @pytest.mark.parametrize("json_string", JSON_STRINGS) async def test_json_parsing(json_string: str): - async with JsonParser.parse( - CharacterStream(_split_into_chunks(json_string)) - ) as node: - actual = await join_string(node.to_string_chunks()) - expected = json.dumps(json.loads(json_string)) + node = await JsonParser().parse( + ChunkedCharStream(_split_into_chunks(json_string)) + ) + actual = await join_string(node.to_chunks()) + expected = json.dumps(json.loads(json_string)) - assert actual == expected + assert actual == expected + + +@pytest.mark.asyncio +async def test_incomplete_json_parsing(): + incomplete_json_string = """ + { + "test": "field" + """ + node = object_node( + await JsonParser().parse( + ChunkedCharStream(_split_into_chunks(incomplete_json_string)) + ) + ) + _, value = await anext(node) + await string_node(value).read_to_end() + + assert node.value() == {"test": "field"} + + +@pytest.mark.asyncio +async def test_incorrect_escape_sequence(): + incomplete_json_string = '"\\k"' + node = string_node( + await JsonParser().parse( + ChunkedCharStream(_split_into_chunks(incomplete_json_string)) + ) + ) + + with pytest.raises(JsonParsingException) as exc_info: + await node.read_to_end() + + assert str(exc_info.value) == ( + "Failed to parse json string at position 2: Unexpected escape sequence: \\k." + )