
added streaming support. Maybe we need to refactor into LLMSession and StreamingLLMSession.
Alex-Karmazin committed Jun 7, 2024
1 parent cf8ab85 commit 5a2be2b
127 changes: 118 additions & 9 deletions just_agents/llm_session.py
@@ -1,6 +1,7 @@
from pathlib import Path

from litellm import ModelResponse, completion, Message
from litellm.utils import ChatCompletionMessageToolCall, Function
from litellm import ModelResponse, completion, acompletion, Message
from typing import Any, Dict, List, Optional, Callable
import litellm
import json
@@ -9,9 +10,101 @@

from just_agents.llm_options import LLAMA3
from just_agents.memory import Memory
from starlette.responses import ContentStream
import time

OnCompletion = Callable[[ModelResponse], None]

class FunctionParser:
    """Accumulates one streamed tool call until its JSON arguments are complete."""
    id: str = ""
    name: str = ""
    arguments: str = ""

    def __init__(self, id: str):
        self.id = id

    def parsed(self, name: str, arguments: str) -> bool:
        # The stream delivers the function name and arguments in fragments;
        # concatenate them and report True once the arguments JSON is closed.
        if name:
            self.name += name
        if arguments:
            self.arguments += arguments
        return len(self.name) > 0 and len(self.arguments) > 0 and self.arguments.endswith("}")
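# Editor's illustration (not part of the commit): how FunctionParser accumulates
# a tool call that litellm streams in fragments. The names and values below are
# made up for the example.
#
#   p = FunctionParser(id="call_abc123")
#   p.parsed("get_weather", None)      # -> False: no arguments yet
#   p.parsed(None, '{"city": ')        # -> False: arguments JSON still open
#   p.parsed(None, '"Paris"}')         # -> True: arguments now end with "}"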


def get_chunk(i: int, delta: str, options: Dict) -> str:
    """Builds one OpenAI-style chat.completion.chunk object as a JSON string."""
    chunk = {
        "id": i,
        "object": "chat.completion.chunk",
        "created": int(time.time()),  # OpenAI timestamps are integer unix seconds
        "model": options["model"],
        "choices": [{"delta": {"content": delta}}],
    }
    return json.dumps(chunk)
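# Editor's illustration (not part of the commit): for i=0, delta="Hello" and
# options={"model": "llama3-8b"} (a made-up model name), get_chunk returns:
#
#   {"id": 0, "object": "chat.completion.chunk", "created": 1717718400,
#    "model": "llama3-8b", "choices": [{"delta": {"content": "Hello"}}]}
#
# which the generator below wraps into an SSE line: "data: {...}\n\n".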


def process_function(parser: FunctionParser, available_tools: Dict[str, Callable]) -> Message:
    """Executes a fully parsed tool call and wraps its result in a tool message."""
    function_args = json.loads(parser.arguments)
    function_to_call = available_tools[parser.name]
    try:
        function_response = function_to_call(**function_args)
    except Exception as e:
        function_response = str(e)
    message = Message(role="tool", content=function_response, name=parser.name,
                      tool_call_id=parser.id)  # TODO: need to track arguments, arguments=function_args
    return message


def get_tool_call_message(parsers: list[FunctionParser]) -> Message:
    """Rebuilds the assistant message that announced the streamed tool calls."""
    tool_calls = []
    for parser in parsers:
        tool_calls.append({"type": "function",
                           "id": parser.id, "function": {"name": parser.name, "arguments": parser.arguments}})
    return Message(role="assistant", content=None, tool_calls=tool_calls)
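# Editor's illustration (not part of the commit): with one parsed call, the
# message above would look like this (values made up):
#
#   Message(role="assistant", content=None, tool_calls=[
#       {"type": "function", "id": "call_abc123",
#        "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}])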


async def _resp_async_generator(memory: Memory, options: Dict, available_tools: Dict[str, Callable]):
    # NOTE: acompletion is imported above but not yet used; this is litellm's
    # synchronous completion() iterated inside an async generator.
    response: ModelResponse = completion(messages=memory.messages, stream=True, **options)
    parser: Optional[FunctionParser] = None
    tool_messages: list[Message] = []
    parsers: list[FunctionParser] = []
    deltas: list[str] = []
    for i, part in enumerate(response):
        delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
        if delta:
            deltas.append(delta)
            yield f"data: {get_chunk(i, delta, options)}\n\n"

        tool_calls = part["choices"][0]["delta"].get("tool_calls")
        if tool_calls and (available_tools is not None):
            if not parser:
                parser = FunctionParser(id=tool_calls[0].id)
            if parser.parsed(tool_calls[0].function.name, tool_calls[0].function.arguments):
                tool_messages.append(process_function(parser, available_tools))
                parsers.append(parser)
                parser = None

    if len(tool_messages) > 0:
        # Record the tool calls and their results, then query the model again.
        memory.add_message(get_tool_call_message(parsers))
        for message in tool_messages:
            memory.add_message(message)
        response = completion(messages=memory.messages, stream=True, **options)
        deltas = []
        for i, part in enumerate(response):
            delta: str = part["choices"][0]["delta"].get("content")  # type: ignore
            if delta:
                deltas.append(delta)
                yield f"data: {get_chunk(i, delta, options)}\n\n"
        memory.add_message(Message(role="assistant", content="".join(deltas)))
    elif len(deltas) > 0:
        memory.add_message(Message(role="assistant", content="".join(deltas)))

    yield "data: [DONE]\n\n"

@dataclass(kw_only=True)
class LLMSession:
llm_options: Dict[str, Any] = field(default_factory=lambda: LLAMA3)
@@ -51,28 +144,43 @@ def instruct(self, prompt: str):
self.memory.add_message(system_instruction, True)
return system_instruction

def query(self, prompt: str, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
def query(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
"""
        Query the large language model
:param prompt:
:param stream:
:param run_callbacks:
:param output:
:return:
"""

question = Message(role="user", content=prompt)
self.memory.add_message(question, run_callbacks)
return self._query(stream, run_callbacks, output)
return self._query(run_callbacks, output)


def query_all(self, messages: list, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
self.memory.add_messages(messages, run_callbacks)
return self._query(run_callbacks, output)


def query_all(self, messages: list, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
def stream_all(self, messages: list, run_callbacks: bool = True) -> ContentStream:
self.memory.add_messages(messages, run_callbacks)
return self._query(stream, run_callbacks, output)
return self._stream()


def stream(self, prompt: str, run_callbacks: bool = True, output: Optional[Path] = None) -> ContentStream:
question = Message(role="user", content=prompt)
self.memory.add_message(question, run_callbacks)
return self._stream()


def _query(self, stream: bool = False, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
def _stream(self) -> ContentStream:
return _resp_async_generator(self.memory, self.llm_options, self.available_tools)


def _query(self, run_callbacks: bool = True, output: Optional[Path] = None) -> str:
options: Dict = self.llm_options
response: ModelResponse = completion(messages=self.memory.messages, stream=stream, **options)
response: ModelResponse = completion(messages=self.memory.messages, stream=False, **options)
self._process_response(response)
executed_response = self._process_function_calls(response)
if executed_response is not None:
@@ -90,7 +198,7 @@ def _query(self, stream: bool = False, run_callbacks: bool = True, output: Optio
def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResponse]:
"""
        Processes the function calls contained in the response
        :param response:
:return:
"""
response_message = response.choices[0].message
Expand All @@ -100,6 +208,7 @@ def _process_function_calls(self, response: ModelResponse) -> Optional[ModelResp
message = self.message_from_response(response)
self.memory.add_message(message)
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_to_call = self.available_tools[function_name]
            function_args = json.loads(tool_call.function.arguments)
            print(f"Calling function {function_name}({function_args})")
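Editor's note: a minimal sketch of how the new streaming entry points could be served
over SSE with Starlette, whose ContentStream type the commit already imports. The
route, the request shape, and the assumption that LLMSession() is constructible with
its defaults are illustrative, not part of the commit:

    from starlette.applications import Starlette
    from starlette.requests import Request
    from starlette.responses import StreamingResponse
    from starlette.routing import Route

    from just_agents.llm_session import LLMSession

    async def chat(request: Request):
        prompt = (await request.json())["prompt"]
        session = LLMSession()  # assumes default llm_options (LLAMA3) and memory
        # stream() returns the async generator of "data: ..." SSE lines
        return StreamingResponse(session.stream(prompt), media_type="text/event-stream")

    app = Starlette(routes=[Route("/chat", chat, methods=["POST"])])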
