Skip to content

Commit

Permalink
feat: Distinguish between streaming and non streaming internally (#121)
Browse files Browse the repository at this point in the history
* feat: Distinguish between streaming and non streaming internally

* clean up
  • Loading branch information
KenyonY authored Apr 11, 2024
1 parent 8322a46 commit 8bd6f27
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions openai_forward/forward/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ async def aiter_bytes(
route_path: str,
uid: str,
cache_key: str | None = None,
stream: bool | None = None,
):
"""
Asynchronously iterates through the bytes in the given aiohttp.ClientResponse object
Expand All @@ -506,38 +507,54 @@ async def aiter_bytes(
route_path (str): The API route path.
uid (str): Unique identifier for the request.
cache_key (bytes): The cache key.
stream (bool): Whether the response is a stream.
Returns:
AsyncGenerator[bytes]: Each chunk of bytes from the server's response.
"""

queue_is_complete = False

queue = Queue()
# todo:
task = asyncio.create_task(self.read_chunks(r, queue))
yield_completed = False
chunk_list = []
try:
while True:
chunk = await queue.get()
if not isinstance(chunk, bytes):
queue.task_done()
queue_is_complete = True
break
if CACHE_OPENAI:
chunk_list.append(chunk)
chunk = None
if stream:
queue = Queue()
# todo:
task = asyncio.create_task(self.read_chunks(r, queue))
try:
while True:
chunk = await queue.get()
if not isinstance(chunk, bytes):
queue.task_done()
yield_completed = True
break
if CACHE_OPENAI:
chunk_list.append(chunk)
yield chunk
except Exception as e:
logger.warning(
f"aiter_bytes error:{e}\nhost:{request.client.host} method:{request.method}: "
f"{traceback.format_exc()}"
)
finally:
if not task.done():
task.cancel()
else:
try:
chunk = await r.read()
yield chunk
except Exception:
logger.warning(
f"aiter_bytes error:\nhost:{request.client.host} method:{request.method}: {traceback.format_exc()}"
)
finally:
if not task.done():
task.cancel()
r.release()
chunk_list.append(chunk)
chunk = bytearray(chunk)
yield_completed = True
except Exception as e:
logger.warning(
f"aiter_bytes error:{e}\nhost:{request.client.host} method:{request.method}: "
f"{traceback.format_exc()}"
)

r.release()

if uid:
if r.ok and queue_is_complete:
if r.ok and yield_completed:
target_info = self._handle_result(
chunk, uid, route_path, request.method
)
Expand Down Expand Up @@ -595,6 +612,7 @@ async def reverse_proxy(self, request: Request):
request, route_path, model_set
)
uid = payload_info["uid"]
stream = payload_info.get('stream', None)

cached_response, cache_key = get_cached_response(
payload,
Expand All @@ -610,7 +628,7 @@ async def reverse_proxy(self, request: Request):

r = await self.send(client_config, data=payload)
return StreamingResponse(
self.aiter_bytes(r, request, route_path, uid, cache_key),
self.aiter_bytes(r, request, route_path, uid, cache_key, stream),
status_code=r.status,
media_type=r.headers.get("content-type"),
)

0 comments on commit 8bd6f27

Please sign in to comment.