
[Collective][PR 3.5/6] Send/Recv calls and some initial code for communicator caching (#12935)

* other collectives all work

* auto-linting

* manual linting #1

* manual linting 2

* bugfix

* add send/recv point-to-point calls

* add some initial code for communicator caching

* auto linting

* optimize imports

* minor fix

* fix failing tests

* support more dtypes

* rerun some distributed tests for send/recv

* linting
zhisbug authored Dec 28, 2020
1 parent c524f86 commit 18f5743
Showing 11 changed files with 359 additions and 63 deletions.
4 changes: 2 additions & 2 deletions python/ray/util/collective/__init__.py
@@ -1,11 +1,11 @@
 from ray.util.collective.collective import nccl_available, mpi_available, \
     is_group_initialized, init_collective_group, destroy_collective_group, \
     get_rank, get_world_size, allreduce, barrier, reduce, broadcast, \
-    allgather, reducescatter
+    allgather, reducescatter, send, recv

 __all__ = [
     "nccl_available", "mpi_available", "is_group_initialized",
     "init_collective_group", "destroy_collective_group", "get_rank",
     "get_world_size", "allreduce", "barrier", "reduce", "broadcast",
-    "allgather", "reducescatter"
+    "allgather", "reducescatter", "send", "recv"
 ]
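
With the exports updated, the new point-to-point calls sit alongside the existing collectives in the public API, e.g.:

    # send/recv are now importable from the package root.
    from ray.util.collective import send, recv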
64 changes: 41 additions & 23 deletions python/ray/util/collective/collective.py
@@ -4,7 +4,6 @@
 import numpy as np
 import ray
 from ray.util.collective import types
-from ray.util.collective.const import get_nccl_store_name

 _MPI_AVAILABLE = False
 _NCCL_AVAILABLE = True
@@ -16,7 +15,6 @@
 # _MPI_AVAILABLE = False
 try:
     from ray.util.collective.collective_group import NCCLGroup
-    from ray.util.collective.collective_group import nccl_util
 except ImportError:
     _NCCL_AVAILABLE = False

@@ -53,17 +51,6 @@ def create_collective_group(self, backend, world_size, rank, group_name):
         if backend == types.Backend.MPI:
             raise NotImplementedError()
         elif backend == types.Backend.NCCL:
-            # create the ncclUniqueID
-            if rank == 0:
-                # availability has been checked before entering here.
-                group_uid = nccl_util.get_nccl_unique_id()
-                store_name = get_nccl_store_name(group_name)
-                # Avoid a potential circular dependency in ray/actor.py
-                from ray.util.collective.util import NCCLUniqueIDStore
-                store = NCCLUniqueIDStore.options(
-                    name=store_name, lifetime="detached").remote(store_name)
-                ray.wait([store.set_id.remote(group_uid)])
-
             logger.debug("creating NCCL group: '{}'".format(group_name))
             g = NCCLGroup(world_size, rank, group_name)
             self._name_group_map[group_name] = g
@@ -89,19 +76,9 @@ def destroy_collective_group(self, group_name):

         # release the collective group resource
         g = self._name_group_map[group_name]
-        rank = g.rank
-        backend = g.backend()
-
         # clean up the dicts
         del self._group_name_map[g]
         del self._name_group_map[group_name]
-        if backend == types.Backend.NCCL:
-            # release the named actor
-            if rank == 0:
-                store_name = get_nccl_store_name(group_name)
-                store = ray.get_actor(store_name)
-                ray.wait([store.__ray_terminate__.remote()])
-                ray.kill(store)
         # Release the communicator resources
         g.destroy_group()

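
The rank-0 NCCLUniqueIDStore bookkeeping removed above leaves the group manager; per the commit message, it makes way for initial communicator caching. A minimal, hypothetical sketch of that caching idea (the names and structure here are illustrative, not the PR's actual code):

    # Hypothetical sketch: cache communicators per group name so repeated
    # collective calls reuse them instead of repeating the expensive NCCL
    # rendezvous and communicator construction.
    _comm_cache = {}

    def _get_or_create_comm(group_name, create_fn):
        # create_fn would build the communicator from the group's
        # ncclUniqueID; it runs only on a cache miss.
        comm = _comm_cache.get(group_name)
        if comm is None:
            comm = create_fn()
            _comm_cache[group_name] = comm
        return comm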
@@ -322,6 +299,46 @@ def reducescatter(tensor,
     g.reducescatter(tensor, tensor_list, opts)
 
 
+def send(tensor, dst_rank: int, group_name: str = "default"):
+    """Send a tensor to a remote process synchronously.
+
+    Args:
+        tensor: the tensor to send.
+        dst_rank (int): the rank of the destination process.
+        group_name (str): the name of the collective group.
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    g = _check_and_get_group(group_name)
+    _check_rank_valid(g, dst_rank)
+    if dst_rank == g.rank:
+        raise RuntimeError(
+            "The destination rank '{}' is self.".format(dst_rank))
+    g.send(tensor, dst_rank)
+
+
+def recv(tensor, src_rank: int, group_name: str = "default"):
+    """Receive a tensor from a remote process synchronously.
+
+    Args:
+        tensor: the received tensor.
+        src_rank (int): the rank of the source process.
+        group_name (str): the name of the collective group.
+
+    Returns:
+        None
+    """
+    _check_single_tensor_input(tensor)
+    g = _check_and_get_group(group_name)
+    _check_rank_valid(g, src_rank)
+    if src_rank == g.rank:
+        raise RuntimeError(
+            "The source rank '{}' is self.".format(src_rank))
+    g.recv(tensor, src_rank)
+
+
 def _check_and_get_group(group_name):
     """Check the existence and return the group handle."""
     _check_inside_actor()
@@ -368,6 +385,7 @@ def _check_inside_actor():


 def _check_rank_valid(g, rank: int):
+    """Check the rank: 0 <= rank < world_size."""
     if rank < 0:
         raise ValueError("rank '{}' is negative.".format(rank))
     if rank >= g.world_size:
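
A minimal usage sketch of the new send/recv API (hypothetical driver code: the actor layout, tensor shapes, and group setup are illustrative assumptions, not part of this diff):

    import cupy
    import ray
    import ray.util.collective as col

    @ray.remote(num_gpus=1)
    class Worker:
        def __init__(self):
            self.buf = cupy.ones((4, ), dtype=cupy.float32)

        def setup(self, world_size, rank):
            # Both workers join the same collective group before p2p calls.
            col.init_collective_group(world_size, rank, "nccl", "default")

        def do_send(self, dst_rank):
            col.send(self.buf, dst_rank)  # blocks until matched by a recv

        def do_recv(self, src_rank):
            col.recv(self.buf, src_rank)  # fills self.buf in place
            return cupy.asnumpy(self.buf)

    ray.init()
    w0, w1 = Worker.remote(), Worker.remote()
    ray.get([w0.setup.remote(2, 0), w1.setup.remote(2, 1)])
    ray.get([w0.do_send.remote(1), w1.do_recv.remote(0)])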
@@ -72,3 +72,11 @@ def reducescatter(self,
                       tensor_list,
                       reducescatter_options=ReduceScatterOptions()):
         raise NotImplementedError()
+
+    @abstractmethod
+    def send(self, tensor, dst_rank):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def recv(self, tensor, src_rank):
+        raise NotImplementedError()
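
A concrete backend would implement these hooks with NCCL's point-to-point primitives. A hypothetical sketch using cupy's NCCL bindings (assumes NCCL >= 2.7 for p2p support and float32 tensors; the communicator and stream plumbing is assumed, not taken from this PR):

    from cupy.cuda import nccl

    class NCCLGroupSketch:
        def __init__(self, comm, stream):
            self._comm = comm        # a cupy.cuda.nccl.NcclCommunicator
            self._stream = stream    # a cupy.cuda.Stream

        def send(self, tensor, dst_rank):
            # NCCL p2p send: device pointer, element count, dtype, peer rank.
            self._comm.send(tensor.data.ptr, tensor.size,
                            nccl.NCCL_FLOAT32, dst_rank, self._stream.ptr)

        def recv(self, tensor, src_rank):
            # Receives into the pre-allocated tensor buffer in place.
            self._comm.recv(tensor.data.ptr, tensor.size,
                            nccl.NCCL_FLOAT32, src_rank, self._stream.ptr)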
