Skip to content

Commit

Permalink
feat: read json lazily to tolerate incorrect closing brackets (#37)
Browse files Browse the repository at this point in the history
* Read json lazily to tolerate incorrect closing brackets that models occasionally generate.
  • Loading branch information
Oleksii-Klimov authored Dec 4, 2023
1 parent d42ee40 commit 248bee4
Show file tree
Hide file tree
Showing 16 changed files with 462 additions and 482 deletions.
2 changes: 1 addition & 1 deletion aidial_assistant/application/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def build(self, **kwargs) -> Template:
{
"command": "<command name>",
"args": [
// <array of arguments>
"<arg1>", "<arg2>", ...
]
}
]
Expand Down
62 changes: 31 additions & 31 deletions aidial_assistant/chain/command_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
skip_to_json_start,
)
from aidial_assistant.commands.base import Command, FinalCommand
from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_parser import JsonParser
Expand Down Expand Up @@ -166,39 +166,39 @@ async def _run_with_protocol_failure_retries(
async def _run_commands(
self, chunk_stream: AsyncIterator[str], callback: ChainCallback
) -> Tuple[list[CommandInvocation], list[CommandResult]]:
char_stream = CharacterStream(chunk_stream)
char_stream = ChunkedCharStream(chunk_stream)
await skip_to_json_start(char_stream)

async with JsonParser.parse(char_stream) as root_node:
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
command_name = await invocation.parse_name()
command = self._create_command(command_name)
args = invocation.parse_args()
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_string_chunks(),
callback.result_callback(),
)
break
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
)
root_node = await JsonParser().parse(char_stream)
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
command_name = await invocation.parse_name()
command = self._create_command(command_name)
args = invocation.parse_args()
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_chunks(),
callback.result_callback(),
)
break
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
)

commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)
commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)

return commands, responses
return commands, responses

def _create_command(self, name: str) -> Command:
if name not in self.command_dict:
Expand Down Expand Up @@ -237,7 +237,7 @@ async def _to_args(
arg_callback = args_callback.arg_callback()
arg_callback.on_arg_start()
result = ""
async for chunk in arg.to_string_chunks():
async for chunk in arg.to_chunks():
arg_callback.on_arg(chunk)
result += chunk
arg_callback.on_arg_end()
Expand Down
6 changes: 3 additions & 3 deletions aidial_assistant/chain/model_response_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import AsyncIterator

from aidial_assistant.json_stream.characterstream import AsyncPeekable
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.json_array import JsonArray
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_object import JsonObject
Expand All @@ -16,12 +16,12 @@ class AssistantProtocolException(Exception):
pass


async def skip_to_json_start(stream: AsyncPeekable[str]):
async def skip_to_json_start(stream: ChunkedCharStream):
# Some models tend to provide explanations for their replies regardless of what the prompt says.
try:
while True:
char = await stream.apeek()
if char == JsonObject.token():
if JsonObject.starts_with(char):
break

await anext(stream)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
from abc import ABC, abstractmethod
from abc import ABC
from collections.abc import AsyncIterator
from typing import Generic, TypeVar

from typing_extensions import override

T = TypeVar("T")


class AsyncPeekable(ABC, Generic[T], AsyncIterator[T]):
@abstractmethod
async def apeek(self) -> T:
pass

async def askip(self) -> None:
await anext(self)


class CharacterStream(AsyncPeekable[str]):
class ChunkedCharStream(ABC, AsyncIterator[str]):
def __init__(self, source: AsyncIterator[str]):
self._source = source
self._chunk: str = ""
Expand All @@ -33,18 +21,29 @@ async def __anext__(self) -> str:
self._next_char_offset += 1
return result

@override
async def apeek(self) -> str:
while self._next_char_offset == len(self._chunk):
self._chunk_position += len(self._chunk)
self._chunk = await anext(self._source) # type: ignore
self._next_char_offset = 0
return self._chunk[self._next_char_offset]

async def askip(self):
await anext(self)

@property
def chunk_position(self) -> int:
    """Offset of the current chunk, as tracked in ``_chunk_position``."""
    current_chunk_offset = self._chunk_position
    return current_chunk_offset

@property
def char_position(self) -> int:
    """Offset of the next character: chunk offset plus offset inside the chunk."""
    offset_in_chunk = self._next_char_offset
    return self._chunk_position + offset_in_chunk


async def skip_whitespaces(stream: ChunkedCharStream):
    """Advance *stream* past any run of whitespace characters.

    The first non-whitespace character is only peeked, never consumed.
    Anything raised by ``stream.apeek()`` or ``stream.askip()`` propagates
    to the caller.
    """
    while (await stream.apeek()).isspace():
        await stream.askip()
18 changes: 16 additions & 2 deletions aidial_assistant/json_stream/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
class JsonParsingException(Exception):
    """Error raised when a JSON string cannot be parsed.

    The exception message is prefixed with the character position at which
    parsing failed, followed by the supplied *message*.
    """

    def __init__(self, message: str, char_position: int):
        description = (
            f"Failed to parse json string at position {char_position}: {message}"
        )
        super().__init__(description)


def unexpected_symbol_error(
    char: str, char_position: int
) -> JsonParsingException:
    """Build the exception reporting the unexpected symbol *char*.

    The exception is returned, not raised; *char_position* is passed through
    to ``JsonParsingException``.
    """
    message = f"Unexpected symbol '{char}'."
    return JsonParsingException(message, char_position)


def unexpected_end_of_stream_error(char_position: int) -> JsonParsingException:
    """Build the exception reporting that the stream ended prematurely.

    The exception is returned, not raised; *char_position* is passed through
    to ``JsonParsingException``.
    """
    message = "Unexpected end of stream."
    return JsonParsingException(message, char_position)


def invalid_sequence_error(
    expected_type: str, string: str, char_position: int
) -> JsonParsingException:
    """Build the exception reporting that *string* is not a valid *expected_type*.

    The exception is returned, not raised; *char_position* is passed through
    to ``JsonParsingException``, which prefixes the message with it.
    """
    # The message must only mention this function's own parameters: the
    # previous "unexpected symbol {char}" text referenced a name that does
    # not exist here and would have been concatenated onto the real message.
    return JsonParsingException(
        f"Can't parse {expected_type} from the string '{string}'.",
        char_position,
    )
134 changes: 69 additions & 65 deletions aidial_assistant/json_stream/json_array.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,94 @@
from asyncio import Queue
from collections.abc import AsyncIterator
from typing import Any

from typing_extensions import override

from aidial_assistant.json_stream.characterstream import CharacterStream
from aidial_assistant.json_stream.exceptions import unexpected_symbol_error
from aidial_assistant.json_stream.chunked_char_stream import (
ChunkedCharStream,
skip_whitespaces,
)
from aidial_assistant.json_stream.exceptions import (
unexpected_end_of_stream_error,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import (
ComplexNode,
CompoundNode,
JsonNode,
NodeResolver,
NodeParser,
)
from aidial_assistant.json_stream.json_normalizer import JsonNormalizer


class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]):
def __init__(self, char_position: int):
super().__init__(char_position)
self.listener = Queue[JsonNode | None]()
self.array: list[JsonNode] = []
class JsonArray(CompoundNode[list[Any], JsonNode]):
def __init__(self, source: AsyncIterator[JsonNode], pos: int):
super().__init__(source, pos)
self._array: list[JsonNode] = []

@override
def type(self) -> str:
return "array"

@staticmethod
def token() -> str:
return "["

@override
def __aiter__(self) -> AsyncIterator[JsonNode]:
return self

@override
async def __anext__(self) -> JsonNode:
result = await self.listener.get()
if result is None:
raise StopAsyncIteration

self.array.append(result)
return result

@override
async def parse(
self, stream: CharacterStream, dependency_resolver: NodeResolver
):
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(char, stream.char_position)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)
# Async generator that reads the elements of a JSON array from `stream` and
# yields each element node as soon as `node_parser` has produced it.
#
# Expected input: optional whitespace, a character accepted by
# JsonArray.starts_with, then elements and "," separators up to the closing "]".
# Raises (via the helper factories):
#   - unexpected_symbol_error: the first non-whitespace character does not start
#     an array, or a "," appears where no element precedes it.
#   - unexpected_end_of_stream_error: the stream raised StopAsyncIteration before
#     the closing "]" was seen.
async def read(
stream: ChunkedCharStream, node_parser: NodeParser
) -> AsyncIterator[JsonNode]:
try:
await skip_whitespaces(stream)
# The opening character is consumed here, so char_position in the error
# below is read after that character has been taken from the stream.
char = await anext(stream)
if not JsonArray.starts_with(char):
raise unexpected_symbol_error(char, stream.char_position)

# True only right after an element has been read; used to reject a leading
# comma or two commas in a row.
is_comma_expected = False
while True:
await skip_whitespaces(stream)
# Peek rather than consume: a character that starts an element must stay
# in the stream for node_parser.
char = await stream.apeek()
if char == "]":
await stream.askip()
break

if char == ",":
if not is_comma_expected:
raise unexpected_symbol_error(
char, stream.char_position
)

await stream.askip()
is_comma_expected = False
else:
# Any other character is handed to node_parser as the start of an element.
# This branch does not check is_comma_expected, so a missing comma between
# two elements is not rejected here.
value = await node_parser.parse(stream)
yield value

# NOTE(review): presumably this drains whatever the consumer left unread of
# the nested node so the stream is positioned after it - confirm against
# CompoundNode.read_to_end.
if isinstance(value, CompoundNode):
await value.read_to_end()
is_comma_expected = True
except StopAsyncIteration:
# Running out of input anywhere above means the array was never closed.
raise unexpected_end_of_stream_error(stream.char_position)

@override
async def to_string_chunks(self) -> AsyncIterator[str]:
yield JsonArray.token()
separate = False
async def to_chunks(self) -> AsyncIterator[str]:
yield "["
is_first_element = True
async for value in self:
if separate:
if not is_first_element:
yield ", "
async for chunk in value.to_string_chunks():
async for chunk in value.to_chunks():
yield chunk
separate = True
is_first_element = False
yield "]"

@override
def value(self) -> list[JsonNode]:
return [item.value() for item in self.array]
return [item.value() for item in self._array]

@override
def _accumulate(self, element: JsonNode):
self._array.append(element)

@classmethod
def parse(
    cls, stream: ChunkedCharStream, node_parser: NodeParser
) -> "JsonArray":
    """Create an array node backed by a not-yet-started element reader.

    ``JsonArray.read`` is an async generator, so calling it consumes nothing
    from *stream*; the node records the stream position at creation time.
    """
    element_source = JsonArray.read(stream, node_parser)
    start_position = stream.char_position
    return cls(element_source, start_position)

@staticmethod
def starts_with(char: str) -> bool:
return char == "["
Loading

0 comments on commit 248bee4

Please sign in to comment.