Skip to content

Commit

Permalink
BUG: Fix leaking clients in actor caller (#115)
Browse files Browse the repository at this point in the history
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
codingl2k1 and mergify[bot] authored Nov 28, 2024
1 parent 7b9f181 commit b1fd262
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 22 deletions.
38 changes: 28 additions & 10 deletions python/xoscar/backends/communication/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

import asyncio
import concurrent.futures as futures
import logging
import weakref
from typing import Any, Callable, Coroutine, Dict, Type
from typing import Any, Callable, Coroutine, Dict, Optional, Type
from urllib.parse import urlparse

from ...errors import ServerClosed
Expand All @@ -29,13 +30,15 @@

DEFAULT_DUMMY_ADDRESS = "dummy://0"

logger = logging.getLogger(__name__)


class DummyChannel(Channel):
"""
Channel for communications in same process.
"""

__slots__ = "_in_queue", "_out_queue", "_closed"
__slots__ = "__weakref__", "_in_queue", "_out_queue", "_closed"

name = "dummy"

Expand Down Expand Up @@ -100,8 +103,8 @@ class DummyServer(Server):
_address_to_instances: weakref.WeakValueDictionary[str, "DummyServer"] = (
weakref.WeakValueDictionary()
)
_channels: list[ChannelType]
_tasks: list[asyncio.Task]
_channels: weakref.WeakSet[Channel]
_tasks: set[asyncio.Task]
scheme: str | None = "dummy"

def __init__(
Expand All @@ -111,8 +114,8 @@ def __init__(
):
super().__init__(address, channel_handler)
self._closed = asyncio.Event()
self._channels = []
self._tasks = []
self._channels = weakref.WeakSet()
self._tasks = set()

@classmethod
def get_instance(cls, address: str):
Expand Down Expand Up @@ -178,7 +181,7 @@ async def on_connected(self, *args, **kwargs):
f"{type(self).__name__} got unexpected "
f'arguments: {",".join(kwargs)}'
)
self._channels.append(channel)
self._channels.add(channel)
await self.channel_handler(channel)

@implements(Server.stop)
Expand All @@ -203,6 +206,7 @@ def __init__(
self, local_address: str | None, dest_address: str | None, channel: Channel
):
super().__init__(local_address, dest_address, channel)
self._task: Optional[asyncio.Task] = None

@staticmethod
@implements(Client.connect)
Expand Down Expand Up @@ -232,11 +236,25 @@ async def connect(
task = asyncio.create_task(conn_coro)
client = DummyClient(local_address, dest_address, client_channel)
client._task = task
server._tasks.append(task)
server._tasks.add(task)

def _discard(t):
server._tasks.discard(t)
logger.info("Channel exit: %s", server_channel.info)

task.add_done_callback(_discard)
return client

@implements(Client.close)
async def close(self):
    """Close the client and cancel the connection task attached to it, if any.

    The parent ``close()`` is awaited first. The task stored in ``self._task``
    is then cancelled and the reference dropped, so a repeated ``close()`` (or
    closing a client that never had a task attached) is a no-op for the task.
    """
    await super().close()
    # Only touch the task when one is attached: ``_task`` starts out as None
    # and is reset to None below, so an unconditional cancel() would raise
    # AttributeError on the second close.
    if self._task is not None:
        task_loop = self._task.get_loop()
        if task_loop is not None:
            if not task_loop.is_running():
                # The cancellation is still requested, but it is being issued
                # against a loop that is not running; surface that in the log.
                logger.warning(
                    "Dummy channel cancel task on a stopped loop, dest address: %s.",
                    self.dest_address,
                )
        self._task.cancel()
        self._task = None
16 changes: 12 additions & 4 deletions python/xoscar/backends/communication/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def closed(self):
class _BaseSocketServer(Server, metaclass=ABCMeta):
__slots__ = "_aio_server", "_channels"

_channels: list[ChannelType]
_channels: set[Channel]

def __init__(
self,
Expand All @@ -124,7 +124,7 @@ def __init__(
super().__init__(address, channel_handler)
# asyncio.Server
self._aio_server = aio_server
self._channels = []
self._channels = set()

@implements(Server.start)
async def start(self):
Expand Down Expand Up @@ -170,9 +170,16 @@ async def on_connected(self, *args, **kwargs):
dest_address=dest_address,
channel_type=self.channel_type,
)
self._channels.append(channel)
self._channels.add(channel)
# handle over channel to some handlers
await self.channel_handler(channel)
try:
await self.channel_handler(channel)
finally:
if not channel.closed:
await channel.close()
# Remove channel if channel exit
self._channels.discard(channel)
logger.debug("Channel exit: %s", channel.info)

@implements(Server.stop)
async def stop(self):
Expand All @@ -185,6 +192,7 @@ async def stop(self):
await asyncio.gather(
*(channel.close() for channel in self._channels if not channel.closed)
)
self._channels.clear()

@property
@implements(Server.stopped)
Expand Down
17 changes: 12 additions & 5 deletions python/xoscar/backends/communication/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ class UCXServer(Server):
scheme = "ucx"

_ucp_listener: "ucp.Listener" # type: ignore
_channels: List[UCXChannel]
_channels: set[UCXChannel]

def __init__(
self,
Expand All @@ -381,7 +381,7 @@ def __init__(
self.host = host
self.port = port
self._ucp_listener = ucp_listener
self._channels = []
self._channels = set()
self._closed = asyncio.Event()

@classproperty
Expand Down Expand Up @@ -469,9 +469,16 @@ async def on_connected(self, *args, **kwargs):
channel = UCXChannel(
ucp_endpoint, local_address=local_address, dest_address=dest_address
)
self._channels.append(channel)
self._channels.add(channel)
# handle over channel to some handlers
await self.channel_handler(channel)
try:
await self.channel_handler(channel)
finally:
if not channel.closed:
await channel.close()
# Remove channel if channel exit
self._channels.discard(channel)
logger.debug("Channel exit: %s", channel.info)

@implements(Server.stop)
async def stop(self):
Expand All @@ -480,7 +487,7 @@ async def stop(self):
await asyncio.gather(
*(channel.close() for channel in self._channels if not channel.closed)
)
self._channels = []
self._channels.clear()
self._ucp_listener = None
self._closed.set()

Expand Down
41 changes: 39 additions & 2 deletions python/xoscar/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import asyncio
import copy
import logging
import threading
import weakref
from typing import Type, Union

from .._utils import Timer
Expand All @@ -31,8 +33,8 @@
logger = logging.getLogger(__name__)


class ActorCaller:
__slots__ = "_client_to_message_futures", "_clients", "_profiling_data"
class ActorCallerThreadLocal:
__slots__ = ("_client_to_message_futures", "_clients", "_profiling_data")

_client_to_message_futures: dict[Client, dict[bytes, asyncio.Future]]
_clients: dict[Client, asyncio.Task]
Expand Down Expand Up @@ -193,6 +195,7 @@ async def call(
return await self.call_with_client(client, message, wait)

async def stop(self):
logger.debug("Actor caller stop.")
try:
await asyncio.gather(*[client.close() for client in self._clients])
except (ConnectionError, ServerClosed):
Expand All @@ -202,3 +205,37 @@ async def stop(self):
def cancel_tasks(self):
    """Request cancellation of the task registered for every known client."""
    # cancel listening for all clients
    for listener in self._clients.values():
        listener.cancel()


class ActorCaller:
    """Per-thread facade in front of ``ActorCallerThreadLocal``.

    Every attribute that is not found on this object is forwarded to an
    ``ActorCallerThreadLocal`` instance private to the calling thread; the
    instance is created lazily on the first access from that thread. When the
    thread-local storage holding it is released, the caller's ``stop()``
    coroutine is scheduled on a shared background event loop so its clients
    get closed.
    """

    __slots__ = "_thread_local"

    class _RefHolder:
        """Plain weak-referenceable marker kept in thread-local storage.

        A ``weakref.finalize`` is attached to it, so its collection triggers
        the cleanup of the per-thread caller.
        """

    # Shared loop, served by a daemon thread that is started when this class
    # body is executed; it runs the ``stop()`` coroutines scheduled by cleanup.
    _close_loop = asyncio.new_event_loop()
    _close_thread = threading.Thread(target=_close_loop.run_forever, daemon=True)
    _close_thread.start()

    def __init__(self):
        self._thread_local = threading.local()

    def __getattr__(self, item):
        """Forward ``item`` to the calling thread's ``ActorCallerThreadLocal``.

        Raises:
            AttributeError: if ``_thread_local`` itself is not set yet (the
                instance was created without ``__init__``), or if the
                per-thread caller has no attribute named ``item``.
        """
        # ``__getattr__`` only runs after normal lookup failed. If the slot is
        # unset, reading ``self._thread_local`` below would re-enter this
        # method forever; fail with a regular AttributeError instead.
        if item == "_thread_local":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        # Fetched outside the ``try`` so that a missing slot is not mistaken
        # for "no caller created for this thread yet".
        local = self._thread_local
        try:
            actor_caller = local.actor_caller
        except AttributeError:
            thread_info = str(threading.current_thread())
            logger.debug("Creating a new actor caller for thread: %s", thread_info)
            actor_caller = local.actor_caller = ActorCallerThreadLocal()
            ref = local.ref = ActorCaller._RefHolder()
            # When the thread exits, clean up the related actor callers and channels.

            def _cleanup():
                asyncio.run_coroutine_threadsafe(actor_caller.stop(), self._close_loop)
                logger.debug(
                    "Clean up the actor caller due to thread exit: %s", thread_info
                )

            weakref.finalize(ref, _cleanup)

        return getattr(actor_caller, item)
94 changes: 93 additions & 1 deletion python/xoscar/backends/test/tests/test_actor_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import gc
import os
import sys
import threading

import pytest

import xoscar as mo

from ...communication.dummy import DummyServer
from ...router import Router


class DummyActor(mo.Actor):
def __init__(self, value):
Expand Down Expand Up @@ -60,3 +65,90 @@ async def test_simple(actor_pool_context):
allocate_strategy=mo.allocate_strategy.RandomSubPool(),
)
assert await actor_ref.add(1) == 101


def _cancel_all_tasks(loop):
    """Cancel every unfinished task of ``loop`` and wait until all have ended.

    ``loop`` must not be running, because the wait is driven with
    ``run_until_complete``. A task that ended with an exception other than
    cancellation is reported through the loop's exception handler.
    """
    pending = asyncio.all_tasks(loop)
    if pending:
        for job in pending:
            job.cancel()

        # return_exceptions=True: collect failures instead of aborting the wait.
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

        for job in pending:
            if job.cancelled():
                continue
            error = job.exception()
            if error is None:
                continue
            context = {
                "message": "unhandled exception during asyncio.run() shutdown",
                "exception": error,
                "task": job,
            }
            loop.call_exception_handler(context)


def _run_forever(loop):
    """Thread target: serve ``loop`` until it is stopped, then cancel its tasks."""
    loop.run_forever()
    # run_forever() has returned, so the loop is idle again and the helper can
    # drive it to finish cancelling whatever tasks are still pending.
    _cancel_all_tasks(loop)


@pytest.mark.asyncio
async def test_channel_cleanup(actor_pool_context):
    """Channels and tasks opened from worker threads are released on thread exit.

    Ten threads, each running its own event loop, send messages to one actor.
    After the threads have been stopped and joined, the dummy server is polled
    until only two channels/tasks remain. Both polling phases are bounded by a
    deadline so that a regression fails the test instead of hanging it.
    """
    wait_timeout = 60.0  # seconds allowed for each polling phase
    main_loop = asyncio.get_running_loop()

    pool = actor_pool_context
    actor_ref = await mo.create_actor(
        DummyActor,
        0,
        address=pool.external_address,
        allocate_strategy=mo.allocate_strategy.RandomSubPool(),
    )

    curr_router = Router.get_instance()
    server_address = curr_router.get_internal_address(actor_ref.address)
    dummy_server = DummyServer.get_instance(server_address)

    async def inc():
        await asyncio.gather(*(actor_ref.add.tell(1) for _ in range(10)))

    loops = []
    threads = []
    futures = []
    for _ in range(10):
        loop = asyncio.new_event_loop()
        t = threading.Thread(target=_run_forever, args=(loop,))
        t.start()
        loops.append(loop)
        threads.append(t)
        fut = asyncio.run_coroutine_threadsafe(inc(), loop=loop)
        futures.append(fut)

    for fut in futures:
        fut.result()

    # 10 threads x 10 ``tell(1)`` calls: wait until all 100 increments landed.
    deadline = main_loop.time() + wait_timeout
    while await actor_ref.add(0) != 100:
        assert (
            main_loop.time() < deadline
        ), "actor did not receive all increments in time"
        await asyncio.sleep(0.01)

    # Presumably one channel per worker thread plus the two that remain at the
    # end of the test (see the comment in the final loop).
    assert len(dummy_server._channels) == 12
    assert len(dummy_server._tasks) == 12

    for loop in loops:
        loop.call_soon_threadsafe(loop.stop)

    for t in threads:
        t.join()
    threads.clear()

    curr_router = Router.get_instance()
    server_address = curr_router.get_internal_address(actor_ref.address)
    dummy_server = DummyServer.get_instance(server_address)

    deadline = main_loop.time() + wait_timeout
    while True:
        gc.collect()
        # Two channels left:
        # 1. from the main pool to the actor
        # 2. from current main thread to the actor.
        if len(dummy_server._channels) == 2 and len(dummy_server._tasks) == 2:
            break
        assert main_loop.time() < deadline, (
            f"channels/tasks were not cleaned up in time: "
            f"{len(dummy_server._channels)} channels, "
            f"{len(dummy_server._tasks)} tasks"
        )
        # Yield to the event loop instead of busy-spinning on gc.collect().
        await asyncio.sleep(0.1)

0 comments on commit b1fd262

Please sign in to comment.