Optimizer offloading through weight-only offload #867

Open · wants to merge 4 commits into main
7 changes: 4 additions & 3 deletions axlearn/common/factorized_rms_test.py
@@ -12,11 +12,12 @@
from axlearn.common import factorized_rms
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.optimizer_base import (
NestedOptStateSpec,
Nested,
OptParam,
OptStateSpec,
PartitionedGradientTransformation,
)
from axlearn.common.optimizers import OptStateSpec, with_partition_fn
from axlearn.common.optimizers import with_partition_fn
from axlearn.common.test_utils import TestCase
from axlearn.common.utils import PartitionSpec, flatten_items

@@ -59,7 +60,7 @@ def testParity(self, factored, dtype):

# The 'exp' optimizer is partitioned according to the mesh_axes of parameters and
# factorization spec.
exp_partition: NestedOptStateSpec = exp.partition(param_specs)
exp_partition: Nested[OptStateSpec] = exp.partition(param_specs)
# Used for `count`.
count_spec = OptStateSpec(
dtype=jnp.int32,
8 changes: 3 additions & 5 deletions axlearn/common/optimizer_base.py
@@ -16,14 +16,13 @@
- weight_decay_scale: control the weight decay rate.
"""
import dataclasses
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional, Union

import optax
import typing_extensions

from axlearn.common.base_layer import FactorizationSpec, NestedParameterSpec
from axlearn.common.utils import Tensor, TensorSpec
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.utils import Nested, Tensor, TensorSpec


@dataclasses.dataclass
@@ -66,8 +65,7 @@ def __call__(

# Specification of an optimizer state array.
OptStateSpec = TensorSpec
NestedOptStateSpec = Union[OptStateSpec, dict, Sequence]
TransformPartitionSpecFn = Callable[[NestedParameterSpec], NestedOptStateSpec]
TransformPartitionSpecFn = Callable[[Nested[ParameterSpec]], Nested[OptStateSpec]]


class PartitionedGradientTransformation(NamedTuple):
143 changes: 131 additions & 12 deletions axlearn/common/optimizers.py
@@ -36,10 +36,11 @@
import typing_extensions
from absl import logging
from jax import numpy as jnp
from jax._src.sharding_impls import TransferToMemoryKind
from optax._src import numerics

from axlearn.common import schedule, struct
from axlearn.common.base_layer import NestedParameterSpec, ParameterSpec, PartitionSpec
from axlearn.common.base_layer import ParameterSpec, PartitionSpec
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.factorized_rms import scale_by_factored_rms
from axlearn.common.module import current_context
@@ -51,8 +52,8 @@
TransformPartitionSpecFn,
)
from axlearn.common.utils import (
MemoryKind,
Nested,
NestedPartitionSpec,
NestedTensor,
NestedTree,
Tensor,
@@ -139,19 +140,40 @@ def update_fn(
return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)


def copy_partition(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def copy_partition(
specs: Nested[OptStateSpec],
*,
pattern: Union[None, str, re.Pattern] = None,
memory_kind: Optional[MemoryKind] = None,
) -> Nested[OptStateSpec]:
"""Copies OptStateSpec and optionally assigns with a different memory kind.

Args:
specs: Nested[OptStateSpec] to copy from.
pattern: Regex to match the full path of each spec. Matched specs will have their memory
kind replaced with `memory_kind`.
memory_kind: New memory kind. Defaults to None.

Returns:
A Nested[OptStateSpec] with possibly a different memory kind.
"""
return jax.tree.map(
lambda param_spec: OptStateSpec(
dtype=param_spec.dtype, shape=param_spec.shape, mesh_axes=param_spec.mesh_axes
lambda path, spec: OptStateSpec(
dtype=spec.dtype,
shape=spec.shape,
mesh_axes=spec.mesh_axes,
memory_kind=memory_kind
if pattern and re.fullmatch(pattern, path)
else spec.memory_kind,
),
param_specs,
tree_paths(specs),
specs,
)
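
To make the new `pattern` / `memory_kind` behavior concrete, here is a small illustrative sketch (an editor's note, not part of this diff). It assumes `OptStateSpec`, `PartitionSpec`, and `copy_partition` are importable from `axlearn.common.optimizers`, and that `OptStateSpec` carries the `memory_kind` field used elsewhere in this PR:

```python
# Illustrative sketch only; dict keys and spec values are made up.
from jax import numpy as jnp

from axlearn.common.optimizers import OptStateSpec, PartitionSpec, copy_partition

specs = {
    "mu": OptStateSpec(dtype=jnp.float32, shape=[4], mesh_axes=PartitionSpec("model")),
    "count": OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
}
offloaded = copy_partition(specs, pattern="mu", memory_kind="pinned_host")
# The path "mu" fully matches the pattern, so its memory kind is replaced:
#   offloaded["mu"].memory_kind == "pinned_host"
# The path "count" does not match, so its memory kind is left unchanged.
```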


def trace_partition(
base: optax.GradientTransformation,
) -> PartitionedGradientTransformation:
def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.TraceState(trace=copy_partition(param_specs))

return with_partition_fn(base, partition_fn)
@@ -160,7 +182,7 @@
def adam_partition(base: optax.GradientTransformation) -> PartitionedGradientTransformation:
state: optax.ScaleByAdamState = base.init({})

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return optax.ScaleByAdamState(
count=OptStateSpec(
dtype=state.count.dtype, shape=state.count.shape, mesh_axes=PartitionSpec()
@@ -950,7 +972,7 @@ def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _U
)
return updates, new_state

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
def get_ema_partition(param_spec: ParameterSpec) -> OptStateSpec:
# Store momentum in accumulator_dtype if it is set and p is not scalar.
if param_spec.shape and accumulator_dtype is not None:
@@ -1412,7 +1434,7 @@ def _is_valid_step(
drop_stats=new_drop_stats,
)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
if use_adaptive_drop_norm:
one = jnp.ones([], jnp.float32)
dict_thresholds = drop_norm(count=one, mean=one, stddev=one)
@@ -1571,7 +1593,7 @@ def update_fn(updates, state, params):
)
return updates, ParamEmaState(count=count_inc, ema=new_ema)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return ParamEmaState(
count=OptStateSpec(dtype=jnp.int32, shape=[], mesh_axes=PartitionSpec()),
ema=copy_partition(param_specs),
@@ -1617,7 +1639,7 @@ def update_fn(updates, state, params=None):
updates = jax.tree.map(lambda g, m: jnp.sign((1.0 - b1) * g + b1 * m), updates, state.mu)
return updates, ScaleByLionState(count=count_inc, mu=mu)

def partition_fn(param_specs: NestedParameterSpec) -> NestedPartitionSpec:
def partition_fn(param_specs: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
mu_specs = param_specs
if mu_dtype is not None:
mu_specs = jax.tree.map(
@@ -1993,3 +2015,100 @@ def _update2(u: Tensor, param: OptParam):
partition=lambda _: OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec()),
)
return named_chain(**tx)


def offload_optimizer(
optimizer: ConfigOr[PartitionedGradientTransformation],
*,
pattern: Union[str, re.Pattern] = ".*",
offload_src: MemoryKind = "device",
offload_dst: MemoryKind = "pinned_host",
) -> PartitionedGradientTransformation:
"""Offload the state of the wrapped optimizer that matches `pattern` to `offload_dst`.

Args:
optimizer: The optimizer to offload.
pattern: Regex used to match the full path of each optimizer state. Fully matched states
will be offloaded. Defaults to a regex that matches all states.
offload_src: Offload-from memory kind. Defaults to "device".
offload_dst: Offload-to memory kind. Defaults to "pinned_host".

Returns:
An optimizer whose matched states are placed on `offload_dst` and that performs the same
computation as `optimizer`.

Raises:
ValueError: If the `update` function of the returned optimizer is called outside of a jit
context.

This function returns a new `PartitionedGradientTransformation` that
1. Puts matched states of the wrapped optimizer on `offload_dst` through the partition function
during state initialization in the trainer.
2. Copies the matched states to `offload_src` before `optimizer.update` is called.
3. Copies the matched updated states to `offload_dst` after `optimizer.update` is called.

The regex pattern is matched against the full path of each optimizer state. An example full
path is optimizer/1/0/mu/decoder/transformer/repeat/layer/feed_forward/linear1_0. If the
pattern should not depend on the model structure, you can use ".*/mu/.*" to offload all `mu`
states.

The `update` function of the returned `PartitionedGradientTransformation` must be called within
a jit function.

Example usage:
```python
your_opt = adamw_optimizer(...)
offloaded_opt = offload_optimizer(your_opt)
```

When using `skip_and_clip_by_global_norm` with this offload optimizer, you must wrap the entire
`skip_and_clip_by_global_norm` transformation inside the offload. Do not wrap only its `inner`
optimizer, or you will get errors. Correct example:
```
offloaded_opt = offload_optimizer(skip_and_clip_by_global_norm(inner=adamw_optimizer(...)))
```
The reason is that `skip_and_clip_by_global_norm` conditionally chooses between the previous
optimizer state and the updated optimizer state using `jnp.where`, which doesn't support tensors
in the `pinned_host` memory space.
"""
optimizer = maybe_instantiate(optimizer)
if offload_src is None or offload_dst is None:
raise ValueError(
"offload_src and offload_dst cannot be None when using optimizer offloading."
)

logging.info("Optimizer offloading from %s to %s enabled.", offload_src, offload_dst)

def init_fn(params: NestedOptParam):
return optimizer.init(params)

def _move_fn(state: optax.OptState, dst: MemoryKind) -> optax.OptState:
# TransferToMemoryKind lets us change the memory kind of tensors without specifying the full
# sharding (i.e. jax.sharding.NamedSharding). Although there's no documentation about it,
# it's specified in the API signature. Reference:
# https://github.com/jax-ml/jax/blob/21f8885a9e104b8828c9a8b721eed0c68b622691/jax/_src/api.py#L2220
# Note: device_put doesn't move everything at once. When we pass a pytree of arrays to
# device_put, each array in the pytree is moved independently of the others. The exact order
# is decided by the latency hiding scheduler, which will try to overlap the transfers of
# each state with the state update on TPU whenever possible. There is a memory spike due to
# the temporary state in HBM, but the spike is much smaller than the full memory usage of
# all states. Moreover, when the optimizer runs, all activations have been released, so
# there is less memory pressure at that point in time.
return jax.tree.map(
lambda path, tensor: jax.device_put(tensor, TransferToMemoryKind(dst))
if re.fullmatch(pattern, path)
else tensor,
tree_paths(state),
state,
)

def update_fn(updates: optax.Updates, state: optax.OptState, params: NestedOptParam):
state = _move_fn(state, offload_src)
updates, state = optimizer.update(updates, state, params)
state = _move_fn(state, offload_dst)
return updates, state

def partition_fn(param_spec: Nested[ParameterSpec]) -> Nested[OptStateSpec]:
return copy_partition(
optimizer.partition(param_spec), pattern=pattern, memory_kind=offload_dst
)

return PartitionedGradientTransformation(init=init_fn, update=update_fn, partition=partition_fn)
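
For reviewers who want to try the new wrapper end to end, here is a minimal usage sketch (an illustration, not part of this diff). The optimizer hyperparameter names and values are placeholders; only the wrapping order and the `pattern` follow the docstring above:

```python
# Hypothetical sketch; hyperparameter values are illustrative only.
import jax

from axlearn.common.optimizers import (
    adamw_optimizer,
    offload_optimizer,
    skip_and_clip_by_global_norm,
)

# Per the docstring: wrap the *entire* skip_and_clip_by_global_norm inside the offload
# wrapper, and offload only the first-moment ("mu") states to pinned host memory.
opt = offload_optimizer(
    skip_and_clip_by_global_norm(
        inner=adamw_optimizer(
            learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4
        ),
        max_norm=1.0,
        drop_norm=None,
        grad_norm_ema_decay=0.99,
    ),
    pattern=r".*/mu/.*",
)

# The returned optimizer's `update` must run under jit so that XLA can schedule the
# device <-> pinned_host transfers.
@jax.jit
def jit_update(grads, state, params):
    return opt.update(grads, state=state, params=params)
```
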
59 changes: 45 additions & 14 deletions axlearn/common/optimizers_test.py
@@ -40,6 +40,7 @@
ema,
l2_regularizer,
lion_optimizer,
offload_optimizer,
opt_param_values,
param_ema,
per_param_scale_by_path,
@@ -379,12 +380,25 @@ def _check_dtypes(x, y, z):
jax.tree.map(_check_dtypes, init_state, partition_state, update_state)

def _test_optimizer(self, optimizer):
params = OptParam(
value=jnp.asarray([0, 1, 2, -3], dtype=jnp.float32),
factorization_spec=None,
weight_decay_scale=1.0,
)
state = optimizer.init(params)
self._test_optimizer_helper(optimizer, True)
self._test_optimizer_helper(optimizer, False)

def _test_optimizer_helper(self, optimizer, offload):
if offload:
optimizer = offload_optimizer(optimizer)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)

def create_opt_params(x):
return jax.tree.map(
lambda y: OptParam(
value=y,
factorization_spec=None,
weight_decay_scale=1.0,
),
x,
)

state = optimizer.init(create_opt_params(params))

param_spec = ParameterSpec(shape=[4], mesh_axes=PartitionSpec("model"), factorization=None)
state_partition_spec = optimizer.partition(param_spec)
@@ -399,13 +413,23 @@ def check_partition_spec(spec: OptStateSpec, tree):

jax.tree.map(check_partition_spec, state_partition_spec, state)

def compute_loss(x):
return -jax.nn.log_softmax(x)[1]
@jax.jit
def jit_fn(params, state):
def compute_loss(x):
return -jax.nn.log_softmax(x)[1]

loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
new_loss = compute_loss(updated_params)
params = create_opt_params(params)
loss, grads = jax.value_and_grad(compute_loss)(params.value)
updates, _ = optimizer.update(grads, state=state, params=params)
updated_params = optax.apply_updates(params.value, updates)
return loss, compute_loss(updated_params)

if offload:
self.assertIn(
"TransferToMemoryKind(memory_kind='pinned_host')",
str(jax.make_jaxpr(jit_fn)(params, state)),
)
loss, new_loss = jit_fn(params, state)
self.assertLess(new_loss, loss)

@parameterized.product(
@@ -788,14 +812,17 @@ def loss_fn(x):
config_for_function(drop_norm_by_grad_norm_ema).set(multipliers=[0.1, 1]),
config_for_function(drop_norm_by_grad_norm_stddev).set(multipliers=[20, 40]),
),
offload=(True, False),
)
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm):
def test_gradient_skipping_and_clipping(self, max_norm, drop_norm, offload):
clip = skip_and_clip_by_global_norm(
inner=_counter(),
drop_norm=drop_norm,
max_norm=max_norm,
grad_norm_ema_decay=0.99,
)
if offload:
clip = offload_optimizer(clip)
params = jnp.asarray([0, 1, 2, -3], dtype=jnp.float32)
state = clip.init(params)
init_ema = state.grad_norm_ema
@@ -821,7 +848,11 @@ def loss_fn(x):
else:
is_valid_step = drop_norm is None or g_norm < drop_norm

updates, state = clip.update(grads, state=state, params=params)
@jax.jit
def jit_fn(grads, state, params):
return clip.update(grads, state=state, params=params)

updates, state = jit_fn(grads, state, params)
if is_valid_step:
if max_norm is None or g_norm < max_norm:
np.testing.assert_allclose(updates, grads, atol=1e-6)