feat: read json lazily to tolerate incorrect closing brackets (#37)
* Read json lazily to tolerate incorrect closing brackets that models occasionally generate.
1 parent: d42ee40 · commit: 248bee4
16 changed files with 462 additions and 482 deletions.
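As a rough, standard-library-only illustration of what "reading JSON lazily" buys here (a hypothetical sketch, not the code in this commit): array elements can be yielded one by one, and parsing simply stops at the first well-formed closing bracket, so a stray extra `]` at the end of a model response is never even looked at.

```python
# Hypothetical sketch of lazy, bracket-tolerant array reading (stdlib only).
import json
from typing import Iterator


def read_array_lazily(text: str) -> Iterator[object]:
    """Yield the elements of a JSON array, ignoring anything after its ']'."""
    pos = text.index("[") + 1
    while True:
        # Skip whitespace and element separators.
        while pos < len(text) and text[pos] in " \t\r\n,":
            pos += 1
        if pos >= len(text) or text[pos] == "]":
            return  # array finished; trailing junk such as an extra "]" is ignored
        # raw_decode parses one JSON value and reports where it ended.
        value, pos = json.JSONDecoder().raw_decode(text, pos)
        yield value


# An extra closing bracket does not prevent reading the elements.
print(list(read_array_lazily('[1, "two", 3]]')))  # [1, 'two', 3]
```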
@@ -1,10 +1,24 @@
 class JsonParsingException(Exception):
-    pass
+    def __init__(self, message: str, char_position: int):
+        super().__init__(
+            f"Failed to parse json string at position {char_position}: {message}"
+        )
 
 
 def unexpected_symbol_error(
     char: str, char_position: int
 ) -> JsonParsingException:
+    return JsonParsingException(f"Unexpected symbol '{char}'.", char_position)
+
+
+def unexpected_end_of_stream_error(char_position: int) -> JsonParsingException:
+    return JsonParsingException("Unexpected end of stream.", char_position)
+
+
+def invalid_sequence_error(
+    expected_type: str, string: str, char_position: int
+) -> JsonParsingException:
     return JsonParsingException(
-        f"Failed to parse json string: unexpected symbol {char} at position {char_position}"
+        f"Can't parse {expected_type} from the string '{string}'.",
+        char_position,
     )
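For reference, the new `JsonParsingException.__init__` folds the character position into the message, so all three helpers above produce uniformly formatted errors. A quick check, assuming the module shown above is importable at the path used later in this commit:

```python
# The import path is an assumption inferred from the next file's imports.
from aidial_assistant.json_stream.exceptions import (
    invalid_sequence_error,
    unexpected_end_of_stream_error,
    unexpected_symbol_error,
)

print(unexpected_symbol_error("]", 7))
# Failed to parse json string at position 7: Unexpected symbol ']'.
print(unexpected_end_of_stream_error(12))
# Failed to parse json string at position 12: Unexpected end of stream.
print(invalid_sequence_error("boolean", "trie", 3))
# Failed to parse json string at position 3: Can't parse boolean from the string 'trie'.
```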
@@ -1,90 +1,94 @@
-from asyncio import Queue
 from collections.abc import AsyncIterator
 from typing import Any
 
 from typing_extensions import override
 
-from aidial_assistant.json_stream.characterstream import CharacterStream
-from aidial_assistant.json_stream.exceptions import unexpected_symbol_error
+from aidial_assistant.json_stream.chunked_char_stream import (
+    ChunkedCharStream,
+    skip_whitespaces,
+)
+from aidial_assistant.json_stream.exceptions import (
+    unexpected_end_of_stream_error,
+    unexpected_symbol_error,
+)
 from aidial_assistant.json_stream.json_node import (
-    ComplexNode,
+    CompoundNode,
     JsonNode,
-    NodeResolver,
+    NodeParser,
 )
-from aidial_assistant.json_stream.json_normalizer import JsonNormalizer
 
 
-class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]):
-    def __init__(self, char_position: int):
-        super().__init__(char_position)
-        self.listener = Queue[JsonNode | None]()
-        self.array: list[JsonNode] = []
+class JsonArray(CompoundNode[list[Any], JsonNode]):
+    def __init__(self, source: AsyncIterator[JsonNode], pos: int):
+        super().__init__(source, pos)
+        self._array: list[JsonNode] = []
 
     @override
     def type(self) -> str:
         return "array"
 
     @staticmethod
-    def token() -> str:
-        return "["
-
-    @override
-    def __aiter__(self) -> AsyncIterator[JsonNode]:
-        return self
-
-    @override
-    async def __anext__(self) -> JsonNode:
-        result = await self.listener.get()
-        if result is None:
-            raise StopAsyncIteration
-
-        self.array.append(result)
-        return result
-
-    @override
-    async def parse(
-        self, stream: CharacterStream, dependency_resolver: NodeResolver
-    ):
-        normalised_stream = JsonNormalizer(stream)
-        char = await anext(normalised_stream)
-        self._char_position = stream.char_position
-        if not char == JsonArray.token():
-            raise unexpected_symbol_error(char, stream.char_position)
-
-        separate = False
-        while True:
-            char = await normalised_stream.apeek()
-            if char == "]":
-                await anext(normalised_stream)
-                break
-
-            if char == ",":
-                if not separate:
-                    raise unexpected_symbol_error(char, stream.char_position)
-
-                await anext(normalised_stream)
-                separate = False
-            else:
-                value = await dependency_resolver.resolve(stream)
-                await self.listener.put(value)
-                if isinstance(value, ComplexNode):
-                    await value.parse(stream, dependency_resolver)
-                separate = True
-
-        await self.listener.put(None)
+    async def read(
+        stream: ChunkedCharStream, node_parser: NodeParser
+    ) -> AsyncIterator[JsonNode]:
+        try:
+            await skip_whitespaces(stream)
+            char = await anext(stream)
+            if not JsonArray.starts_with(char):
+                raise unexpected_symbol_error(char, stream.char_position)
+
+            is_comma_expected = False
+            while True:
+                await skip_whitespaces(stream)
+                char = await stream.apeek()
+                if char == "]":
+                    await stream.askip()
+                    break
+
+                if char == ",":
+                    if not is_comma_expected:
+                        raise unexpected_symbol_error(
+                            char, stream.char_position
+                        )
+
+                    await stream.askip()
+                    is_comma_expected = False
+                else:
+                    value = await node_parser.parse(stream)
+                    yield value
+
+                    if isinstance(value, CompoundNode):
+                        await value.read_to_end()
+                    is_comma_expected = True
+        except StopAsyncIteration:
+            raise unexpected_end_of_stream_error(stream.char_position)
 
     @override
-    async def to_string_chunks(self) -> AsyncIterator[str]:
-        yield JsonArray.token()
-        separate = False
+    async def to_chunks(self) -> AsyncIterator[str]:
+        yield "["
+        is_first_element = True
         async for value in self:
-            if separate:
+            if not is_first_element:
                 yield ", "
-            async for chunk in value.to_string_chunks():
+            async for chunk in value.to_chunks():
                 yield chunk
-            separate = True
+            is_first_element = False
         yield "]"
 
     @override
     def value(self) -> list[JsonNode]:
-        return [item.value() for item in self.array]
+        return [item.value() for item in self._array]
+
+    @override
+    def _accumulate(self, element: JsonNode):
+        self._array.append(element)
+
+    @classmethod
+    def parse(
+        cls, stream: ChunkedCharStream, node_parser: NodeParser
+    ) -> "JsonArray":
+        return cls(JsonArray.read(stream, node_parser), stream.char_position)
+
+    @staticmethod
+    def starts_with(char: str) -> bool:
+        return char == "["
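The `Queue`-based listener is gone: `JsonArray` now wraps the `read` generator in a `CompoundNode`, so elements are parsed on demand and accumulated as the caller iterates. A minimal, self-contained sketch of that consumption pattern (illustrative names, not the library's API):

```python
# Sketch of the wrap-a-generator-and-accumulate pattern; names are hypothetical.
import asyncio
from collections.abc import AsyncIterator


class LazyArray:
    def __init__(self, source: AsyncIterator[int]):
        self._source = source
        self._items: list[int] = []

    def __aiter__(self) -> "LazyArray":
        return self

    async def __anext__(self) -> int:
        item = await anext(self._source)  # pulls from the parser on demand
        self._items.append(item)          # _accumulate-style bookkeeping
        return item

    def value(self) -> list[int]:
        return list(self._items)


async def numbers() -> AsyncIterator[int]:
    for n in (1, 2, 3):
        yield n


async def main() -> None:
    array = LazyArray(numbers())
    async for item in array:  # elements arrive one by one
        print("got", item)
    print("materialized:", array.value())


asyncio.run(main())
```

Since `read` stops as soon as it peeks a `]`, any surplus closing brackets after the array are presumably never requested from the stream, which is what makes the parser tolerant of the malformed output the commit message mentions.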