diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 5c10a2be9..ff5c7cce3 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -7,6 +7,7 @@
 # Licensed under the Apache License, Version 2.0 (the "License").
 
 """Common utilities."""
+import collections
 import contextlib
 import copy
 import dataclasses
@@ -1006,6 +1007,72 @@ def register_per_param_settings(settings: NestedTree, *, description: str):
     return settings
 
 
+def build_standard_mesh(mesh_shape: Sequence[int], *, devices: np.ndarray) -> np.ndarray:
+    logging.info("Building device mesh.")
+    try:
+        return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
+    except NotImplementedError as e:
+        logging.warning(
+            "mesh_utils.create_device_mesh cannot handle shape %s: %s. "
+            "Falling back to the naive mesh. Performance may be reduced.",
+            mesh_shape,
+            e,
+        )
+        return devices.reshape(mesh_shape)
+
+
+def create_hybrid_device_mesh(
+    mesh_shape: Sequence[int],
+    *,
+    dcn_mesh_shape: Sequence[int],
+    devices: Optional[Sequence[Any]] = None,
+    process_is_granule: bool = False,
+) -> np.ndarray:
+    """Extends the method to have an option to fall back to naive mesh.
+    Reference:
+    https://github.com/google/jax/blob/1189d61bc086fcfb548e73235a601ec46c3623c5/jax/experimental/mesh_utils.py#L324
+
+    Args:
+        mesh_shape: shape of the logical mesh for the faster/inner network, ordered
+            by increasing network intensity, e.g. [data, fsdp, model] where model has
+            the most network communication requirements.
+        dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in
+            the same order as mesh_shape.
+        devices: optionally, the devices to construct a mesh for. Defaults to jax.devices().
+        process_is_granule: if True, this function will treat processes as the units
+            of the slower/outer network. Otherwise it will look for slice_index
+            attributes on devices and use slices as the units. Enabling this is meant
+            as a fallback for platforms (e.g., GPU) that don't set slice_index.
+
+    Raises:
+        ValueError: if the number of slices to which the `devices` belong doesn't
+            equal the product of `dcn_mesh_shape`, or if the number of devices
+            belonging to any single slice does not equal the product of `mesh_shape`.
+
+    Returns:
+        A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape
+        that can be fed into jax.sharding.Mesh for hybrid parallelism.
+    """
+    attr = "process_index" if process_is_granule else "slice_index"
+    assert hasattr(devices[0], attr)
+    granule_dict = collections.defaultdict(list)
+    for dev in devices:
+        granule_dict[getattr(dev, attr)].append(dev)
+    granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
+    if np.prod(dcn_mesh_shape) != len(granules):
+        raise ValueError(
+            f"Number of slices {len(granules)} must equal the product of "
+            f"dcn_mesh_shape {dcn_mesh_shape}"
+        )
+    per_granule_meshes = [
+        build_standard_mesh(mesh_shape, devices=np.asarray(granule)) for granule in granules
+    ]
+    granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
+    blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
+    device_mesh = np.block(blocks.tolist())
+    return device_mesh
+
+
 def create_device_mesh(
     mesh_shape: Sequence[int], *, devices: Optional[Sequence[Any]] = None
 ) -> np.ndarray:
@@ -1038,19 +1105,6 @@ def create_device_mesh(
         devices = jax.devices()
     devices = np.asarray(devices)
 
-    def build_standard_mesh():
-        logging.info("Building device mesh.")
-        try:
-            return mesh_utils.create_device_mesh(mesh_shape, devices=devices)
-        except NotImplementedError as e:
-            logging.warning(
-                "mesh_utils.create_device_mesh cannot handle shape %s: %s. "
-                "Falling back to the naive mesh. Performance may be reduced.",
-                mesh_shape,
-                e,
-            )
-            return devices.reshape(mesh_shape)
-
     # Check if the devices are part of a multi-granule configuration.
     #
     device_platform = devices[0].platform
@@ -1061,7 +1115,7 @@ def build_standard_mesh():
 
     # Return standard mesh if not a multi-slice/granule env.
     if not is_multi_granule_env:
-        return build_standard_mesh()
+        return build_standard_mesh(mesh_shape, devices=devices)
 
     ici_mesh_shape = mesh_shape
     num_granules = max([getattr(el, attr) for el in devices.flatten()]) + 1
@@ -1069,7 +1123,7 @@ def build_standard_mesh():
     # Return standard mesh if on GPU with incompatible multi-slice/granule mesh.
     if device_platform == "gpu" and ici_mesh_shape[0] % num_granules != 0:
         logging.warning("Falling back to ICI-only mesh on GPU, performance may be reduced.")
-        return build_standard_mesh()
+        return build_standard_mesh(mesh_shape, devices=devices)
 
     # We only break the first device axis (the least communication intensive) across granules.
     assert (
@@ -1088,7 +1142,7 @@ def build_standard_mesh():
             f"Num devices {len(devices)} does not match the product of "
             f"inter and intra slice/granule parallelism {total_parallelism}."
         )
-    return mesh_utils.create_hybrid_device_mesh(
+    return create_hybrid_device_mesh(
         ici_mesh_shape,
         dcn_mesh_shape=dcn_mesh_shape,
         devices=devices,
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 4bd94439f..0d2ab3eb0 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -1131,6 +1131,37 @@ def test_create_device_mesh_multi_slice_tpuv4(self, logical_mesh: Sequence[int])
         for ix, sub_mesh in enumerate(device_mesh):
             self.assertTrue(all(el.slice_index == ix for el in sub_mesh.flatten()))
 
+    @parameterized.parameters(
+        {"logical_mesh": (2, 128, 2)},
+        {"logical_mesh": (2, 16, 16)},
+    )
+    def test_create_device_mesh_multi_slice_tpuv5e(self, logical_mesh: Sequence[int]):
+        slice_physical_mesh = (16, 16, 1)
+        num_slices = 2
+        coords = [
+            (x, y, z)
+            for x in range(slice_physical_mesh[0])
+            for y in range(slice_physical_mesh[1])
+            for z in range(slice_physical_mesh[2])
+        ]
+        devices = [
+            DummyMultiSliceTpuDevice(
+                platform="tpu",
+                device_kind="TPU v5litepod",
+                process_index=(len(coords) * slice_index + ix) // 4,
+                coords=coord,
+                slice_index=slice_index,
+            )
+            for ix, coord in enumerate(coords)
+            for slice_index in range(num_slices)
+        ]
+        # Check that the constructed mesh has the expected shape.
+        device_mesh = create_device_mesh(mesh_shape=logical_mesh, devices=devices)
+        self.assertEqual(device_mesh.shape, logical_mesh)
+        # Check that the sub_mesh along the first axis only contains devices from one of the slices.
+        for ix, sub_mesh in enumerate(device_mesh):
+            self.assertTrue(all(el.slice_index == ix for el in sub_mesh.flatten()))
+
     @parameterized.parameters(
         {"logical_mesh": (8, 2, 4)},
         {"logical_mesh": (16, 4)},
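Note for reviewers: the heart of the new create_hybrid_device_mesh is the np.vectorize/np.block step that stitches per-slice meshes into the final hybrid mesh. The sketch below is illustrative only and not part of the patch; it replays that composition with plain integers standing in for JAX devices, using shapes that mirror the (2, 128, 2) TPU v5e case in the new test, where only the first (least communication-intensive) mesh axis is split across the two slices.

import numpy as np

# Stand-ins for JAX devices: integers 0..511, split into two "slices" of 256 devices each.
num_slices, devices_per_slice = 2, 256
granules = [
    np.arange(s * devices_per_slice, (s + 1) * devices_per_slice) for s in range(num_slices)
]

# For a logical mesh of (2, 128, 2) with dcn_mesh_shape (2, 1, 1), each slice holds a
# (1, 128, 2) sub-mesh; only the first axis spans slices.
ici_mesh_shape = (1, 128, 2)
dcn_mesh_shape = (2, 1, 1)

# Naive per-granule mesh (the reshape fallback in build_standard_mesh).
per_granule_meshes = [g.reshape(ici_mesh_shape) for g in granules]

# Arrange granule indices in the DCN shape, replace each index with its sub-mesh,
# and block-concatenate, mirroring the last few lines of create_hybrid_device_mesh.
granule_mesh = np.arange(num_slices).reshape(dcn_mesh_shape)
blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
device_mesh = np.block(blocks.tolist())

assert device_mesh.shape == (2, 128, 2)
# Each sub-mesh along the first axis contains devices from exactly one slice.
assert device_mesh[0].max() < devices_per_slice <= device_mesh[1].min()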