Support naive mesh in multi-slice env. (#234)
* Support naive mesh in multi-slice env.

* update

* update

* retrigger checks

---------

Co-authored-by: Xianzhi Du <[email protected]>
xianzhidu and Xianzhi Du authored Dec 11, 2023
1 parent 77165cf commit c84f50e
Showing 2 changed files with 101 additions and 16 deletions.
86 changes: 70 additions & 16 deletions axlearn/common/utils.py
@@ -7,6 +7,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").

"""Common utilities."""
import collections
import contextlib
import copy
import dataclasses
@@ -1006,6 +1007,72 @@ def register_per_param_settings(settings: NestedTree, *, description: str):
    return settings


def build_standard_mesh(mesh_shape: Sequence[int], *, devices: np.ndarray) -> np.ndarray:
    """Builds a device mesh of shape `mesh_shape` from `devices`, falling back to a naive mesh.

    If mesh_utils.create_device_mesh cannot handle `mesh_shape`, the devices are simply
    reshaped to `mesh_shape`, which may reduce performance.
    """
    logging.info("Building device mesh.")
    try:
        return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
    except NotImplementedError as e:
        logging.warning(
            "mesh_utils.create_device_mesh cannot handle shape %s: %s. "
            "Falling back to the naive mesh. Performance may be reduced.",
            mesh_shape,
            e,
        )
        return devices.reshape(mesh_shape)


def create_hybrid_device_mesh(
    mesh_shape: Sequence[int],
    *,
    dcn_mesh_shape: Sequence[int],
    devices: Optional[Sequence[Any]] = None,
    process_is_granule: bool = False,
) -> np.ndarray:
    """Extends mesh_utils.create_hybrid_device_mesh with an option to fall back to the naive mesh.

    Reference:
    https://github.com/google/jax/blob/1189d61bc086fcfb548e73235a601ec46c3623c5/jax/experimental/mesh_utils.py#L324

    Args:
        mesh_shape: shape of the logical mesh for the faster/inner network, ordered
            by increasing network intensity, e.g. [data, fsdp, model] where model has
            the most network communication requirements.
        dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in
            the same order as mesh_shape.
        devices: optionally, the devices to construct a mesh for. Defaults to jax.devices().
        process_is_granule: if True, this function will treat processes as the units
            of the slower/outer network. Otherwise it will look for slice_index
            attributes on devices and use slices as the units. Enabling this is meant
            as a fallback for platforms (e.g., GPU) that don't set slice_index.

    Raises:
        ValueError: if the number of slices to which the `devices` belong doesn't
            equal the product of `dcn_mesh_shape`, or if the number of devices
            belonging to any single slice does not equal the product of `mesh_shape`.

    Returns:
        A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
        that can be fed into jax.sharding.Mesh for hybrid parallelism.
    """
    if devices is None:
        # Default to all available devices, as documented above.
        devices = jax.devices()
    attr = "process_index" if process_is_granule else "slice_index"
    assert hasattr(devices[0], attr)
    granule_dict = collections.defaultdict(list)
    for dev in devices:
        granule_dict[getattr(dev, attr)].append(dev)
    granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
    if np.prod(dcn_mesh_shape) != len(granules):
        raise ValueError(
            f"Number of slices {len(granules)} must equal the product of "
            f"dcn_mesh_shape {dcn_mesh_shape}"
        )
    per_granule_meshes = [
        build_standard_mesh(mesh_shape, devices=np.asarray(granule)) for granule in granules
    ]
    granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
    blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
    device_mesh = np.block(blocks.tolist())
    return device_mesh


def create_device_mesh(
    mesh_shape: Sequence[int], *, devices: Optional[Sequence[Any]] = None
) -> np.ndarray:
@@ -1038,19 +1105,6 @@ def create_device_mesh(
        devices = jax.devices()
    devices = np.asarray(devices)

    def build_standard_mesh():
        logging.info("Building device mesh.")
        try:
            return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
        except NotImplementedError as e:
            logging.warning(
                "mesh_utils.create_device_mesh cannot handle shape %s: %s. "
                "Falling back to the naive mesh. Performance may be reduced.",
                mesh_shape,
                e,
            )
            return devices.reshape(mesh_shape)

    # Check if the devices are part of a multi-granule configuration.
    # <https://github.com/google/jax/blob/b81b79c1b0d2ec/jax/experimental/mesh_utils.py#L313>
    device_platform = devices[0].platform
@@ -1061,15 +1115,15 @@ def build_standard_mesh():

    # Return standard mesh if not a multi-slice/granule env.
    if not is_multi_granule_env:
        return build_standard_mesh()
        return build_standard_mesh(mesh_shape, devices=devices)

    ici_mesh_shape = mesh_shape
    num_granules = max([getattr(el, attr) for el in devices.flatten()]) + 1

    # Return standard mesh if on GPU with incompatible multi-slice/granule mesh.
    if device_platform == "gpu" and ici_mesh_shape[0] % num_granules != 0:
        logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
        return build_standard_mesh()
        return build_standard_mesh(mesh_shape, devices=devices)

    # We only break the first device axis (the least communication intensive) across granules.
    assert (
@@ -1088,7 +1142,7 @@ def build_standard_mesh():
f"Num devices {len(devices)} does not match the product of "
f"inter and intra slice/granule parallelism {total_parallelism}."
)
return mesh_utils.create_hybrid_device_mesh(
return create_hybrid_device_mesh(
ici_mesh_shape,
dcn_mesh_shape=dcn_mesh_shape,
devices=devices,
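The hybrid mesh above is assembled by building one per-slice (ICI) mesh per granule and then stitching those meshes together with np.block according to dcn_mesh_shape. Below is a minimal sketch of that composition using plain integers in place of JAX devices; the shapes are illustrative assumptions, not values taken from this commit.

import numpy as np

# Per-slice (ICI) mesh shape and across-slice (DCN) mesh shape; illustrative only.
mesh_shape = (1, 2, 2)
dcn_mesh_shape = (2, 1, 1)  # only the first axis is split across slices

# Two "slices" of four devices each, with integer ids standing in for devices.
granules = [np.arange(4), np.arange(4, 8)]
per_granule_meshes = [g.reshape(mesh_shape) for g in granules]

# Arrange the per-slice meshes according to dcn_mesh_shape and stitch them together.
granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
device_mesh = np.block(blocks.tolist())

assert device_mesh.shape == (2, 2, 2)  # elementwise product of the two mesh shapes
assert (device_mesh[0] < 4).all() and (device_mesh[1] >= 4).all()  # each outer index holds one slice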
31 changes: 31 additions & 0 deletions axlearn/common/utils_test.py
@@ -1131,6 +1131,37 @@ def test_create_device_mesh_multi_slice_tpuv4(self, logical_mesh: Sequence[int]):
        for ix, sub_mesh in enumerate(device_mesh):
            self.assertTrue(all(el.slice_index == ix for el in sub_mesh.flatten()))

    @parameterized.parameters(
        {"logical_mesh": (2, 128, 2)},
        {"logical_mesh": (2, 16, 16)},
    )
    def test_create_device_mesh_multi_slice_tpuv5e(self, logical_mesh: Sequence[int]):
        slice_physical_mesh = (16, 16, 1)
        num_slices = 2
        coords = [
            (x, y, z)
            for x in range(slice_physical_mesh[0])
            for y in range(slice_physical_mesh[1])
            for z in range(slice_physical_mesh[2])
        ]
        devices = [
            DummyMultiSliceTpuDevice(
                platform="tpu",
                device_kind="TPU v5litepod",
                process_index=(len(coords) * slice_index + ix) // 4,
                coords=coord,
                slice_index=slice_index,
            )
            for ix, coord in enumerate(coords)
            for slice_index in range(num_slices)
        ]
        # Check that the constructed mesh has the expected shape.
        device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
        self.assertEqual(device_mesh.shape, logical_mesh)
        # Check that the sub_mesh along the first axis only contains devices from one of the slices.
        for ix, sub_mesh in enumerate(device_mesh):
            self.assertTrue(all(el.slice_index == ix for el in sub_mesh.flatten()))

    @parameterized.parameters(
        {"logical_mesh": (8, 2, 4)},
        {"logical_mesh": (16, 4)},
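For reference, a minimal sketch of how the updated entry point would typically be invoked in a multi-slice environment; the mesh shape and axis names are assumptions for illustration, not part of this commit.

import jax

from axlearn.common.utils import create_device_mesh

# e.g. (data, fsdp, model) across 2 slices of a 16x16 TPU v5e pod; illustrative only.
device_mesh = create_device_mesh(mesh_shape=(2, 16, 16))
mesh = jax.sharding.Mesh(device_mesh, ("data", "fsdp", "model"))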
