Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable create_actor_pool to use elastic ip #95

Merged
merged 7 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ jobs:
working-directory: ./python

- name: Report coverage data
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
with:
working-directory: ./python
flags: unittests
22 changes: 18 additions & 4 deletions python/xoscar/backends/communication/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..._utils import to_binary
from ...constants import XOSCAR_UNIX_SOCKET_DIR
from ...serialization import AioDeserializer, AioSerializer, deserialize
from ...utils import classproperty, implements
from ...utils import classproperty, implements, is_v6_ip
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
from .utils import read_buffers, write_buffers
Expand Down Expand Up @@ -201,17 +201,31 @@ def client_type(self) -> Type["Client"]:
def channel_type(self) -> int:
return ChannelType.remote

@classmethod
def parse_config(cls, config: dict) -> dict:
return config

@staticmethod
@implements(Server.create)
async def create(config: Dict) -> "Server":
config = config.copy()
if "address" in config:
address = config.pop("address")
host, port = address.split(":", 1)
host, port = address.rsplit(":", 1)
port = int(port)
else:
host = config.pop("host")
port = int(config.pop("port"))
_host = host
if config.pop("listen_elastic_ip", False):
# The Actor.address will be announce to client, and is not on our host,
# cannot actually listen on it,
# so we have to keep SocketServer.host untouched to make sure Actor.address not changed
if is_v6_ip(host):
_host = "::"
else:
_host = "0.0.0.0"

handle_channel = config.pop("handle_channel")
if "start_serving" not in config:
config["start_serving"] = False
Expand All @@ -224,7 +238,7 @@ async def handle_connection(reader: StreamReader, writer: StreamWriter):

port = port if port != 0 else None
aio_server = await asyncio.start_server(
handle_connection, host=host, port=port, **config
handle_connection, host=_host, port=port, **config
)

# get port of the socket if not specified
Expand All @@ -250,7 +264,7 @@ class SocketClient(Client):
async def connect(
dest_address: str, local_address: str | None = None, **kwargs
) -> "Client":
host, port_str = dest_address.split(":", 1)
host, port_str = dest_address.rsplit(":", 1)
port = int(port_str)
(reader, writer) = await asyncio.open_connection(host=host, port=port, **kwargs)
channel = SocketChannel(
Expand Down
18 changes: 14 additions & 4 deletions python/xoscar/backends/communication/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...nvutils import get_cuda_context, get_index_and_uuid
from ...serialization import deserialize
from ...serialization.aio import BUFFER_SIZES_NAME, AioSerializer, get_header_length
from ...utils import classproperty, implements, is_cuda_buffer, lazy_import
from ...utils import classproperty, implements, is_cuda_buffer, is_v6_ip, lazy_import
from ..message import _MessageBase
from .base import Channel, ChannelType, Client, Server
from .core import register_client, register_server
Expand Down Expand Up @@ -401,11 +401,21 @@ async def create(config: Dict) -> "Server":
prefix = f"{UCXServer.scheme}://"
if address.startswith(prefix):
address = address[len(prefix) :]
host, port = address.split(":", 1)
host, port = address.rsplit(":", 1)
port = int(port)
else:
host = config.pop("host")
port = int(config.pop("port"))
_host = host
if config.pop("listen_elastic_ip", False):
# The Actor.address will be announce to client, and is not on our host,
# cannot actually listen on it,
# so we have to keep SocketServer.host untouched to make sure Actor.address not changed
if is_v6_ip(host):
_host = "::"
else:
_host = "0.0.0.0"

handle_channel = config.pop("handle_channel")

# init
Expand All @@ -414,7 +424,7 @@ async def create(config: Dict) -> "Server":
async def serve_forever(client_ucp_endpoint: "ucp.Endpoint"): # type: ignore
try:
await server.on_connected(
client_ucp_endpoint, local_address=server.address
client_ucp_endpoint, local_address="%s:%d" % (_host, port)
)
except ChannelClosed: # pragma: no cover
logger.exception("Connection closed before handshake completed")
Expand Down Expand Up @@ -498,7 +508,7 @@ async def connect(
prefix = f"{UCXClient.scheme}://"
if dest_address.startswith(prefix):
dest_address = dest_address[len(prefix) :]
host, port_str = dest_address.split(":", 1)
host, port_str = dest_address.rsplit(":", 1)
port = int(port_str)
kwargs = kwargs.copy()
ucx_config = kwargs.pop("config", dict()).get("ucx", dict())
Expand Down
5 changes: 3 additions & 2 deletions python/xoscar/backends/indigen/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def get_external_addresses(
"""Get external address for every process"""
assert n_process is not None
if ":" in address:
host, port_str = address.split(":", 1)
host, port_str = address.rsplit(":", 1)
port = int(port_str)
if ports:
if len(ports) != n_process:
Expand Down Expand Up @@ -324,6 +324,7 @@ async def append_sub_pool(
start_method: str | None = None,
kwargs: dict | None = None,
):
# external_address has port 0, subprocess will bind random port.
external_address = (
external_address
or MainActorPool.get_external_addresses(self.external_address, n_process=1)[
Expand Down Expand Up @@ -393,7 +394,7 @@ def start_pool_in_process():
content=self._config,
)
await self.handle_control_command(control_message)

# The actual port will return in process_status.
return process_status.external_addresses[0]

async def remove_sub_pool(
Expand Down
Loading
Loading