Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Goodput & Badput recording and monitoring support. #783

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ def record(self, event: measurement.Event, *args, **kwargs):
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
Expand Down
64 changes: 64 additions & 0 deletions axlearn/cloud/gcp/monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright © 2024 Apple Inc.

"""Goodput & Badput computation and monitoring utils for GCP."""

import jax
from absl import flags, logging
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import monitoring
from axlearn.common.config import maybe_set_config


@monitoring.register_monitor("GoodputMonitor")
class GoodputMonitor(monitoring.Monitor):
"""Computes and uploads overall training goodput and optionally badput."""

Config = monitoring.Monitor.Config

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputMonitor":
"""Converts flags to a GoodputMonitor.

`fv.monitor_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values. A GoodputMonitor can
additionally take in following Tensorboard configs in the monitor_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
to Tensorboard.
"""
cfg: monitoring.Monitor.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.monitor_spec, delimiter="="))
return cfg.instantiate()

def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputMonitor.Config = self.config
self._monitor = None

def start_monitoring(self, *args, **kwargs):
# Instantiate ml-goodput-measurement's GoodputMonitor
# to asynchronously calculate goodput and badput at
# the upload_interval and upload to the specified
# tensorboard directory.
if self._monitor is None:
cfg: GoodputMonitor.Config = self.config
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
monitoring_enabled=(jax.process_index() == 0),
include_badput_breakdown=True,
)

if self._monitor:
self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard in the background!")
else:
logging.log_first_n(
logging.WARNING,
"Goodput upload could not be started. Please check GoodputMonitor logs.",
1,
)
5 changes: 4 additions & 1 deletion axlearn/common/launch_trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@

from absl import app, flags

from axlearn.common import launch, launch_trainer, measurement
from axlearn.common import launch, launch_trainer, measurement, monitoring
from axlearn.common.config import config_for_function


def main(_):
measurement.initialize(flags.FLAGS)
monitoring.initialize(flags.FLAGS)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
monitoring.start_monitoring()
launch_trainer.run_trainer(trainer_config)


if __name__ == "__main__":
measurement.define_flags()
monitoring.define_flags()
app.run(main)
12 changes: 12 additions & 0 deletions axlearn/common/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@ class Event(enum.Enum):
START_JOB: Start of job.
END_JOB: End of job.
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
START_TRAINING_PREPARATION: Start of training preparation.
END_TRAINING_PREPARATION: End of training preparation.
START_DATA_LOADING: Start of data loading.
END_DATA_LOADING: End of data loading.
"""

START_JOB = "START_JOB"
END_JOB = "END_JOB"
START_STEP = "START_STEP"
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
START_DATA_LOADING = "START_DATA_LOADING"
END_DATA_LOADING = "END_DATA_LOADING"


class Recorder(Configurable):
Expand Down
113 changes: 113 additions & 0 deletions axlearn/common/monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright © 2024 Apple Inc.

"""Asynchronously compute and monitor metrics like goodput and badput."""

import importlib
from typing import Optional, TypeVar

from absl import flags, logging

from axlearn.common.config import REQUIRED, Configurable, Required, config_class


class Monitor(Configurable):
"""The base interface for computing and monitoring metrics."""

@config_class
class Config(Configurable.Config):
"""Configures any type of Monitor.

Attributes:
name: Name of the monitor (example: GoodputMonitor).
upload_dir: Storage directory where metrics are uploaded.
upload_interval: Time interval (seconds) at which to query and upload metrics.
"""

name: Required[str] = REQUIRED
upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Monitor":
"""Converts flags to a monitor."""
raise NotImplementedError(cls)

def start_monitoring(self, **kwargs):
"""Starts computing and uploading metrics at some configured interval in the background."""
raise NotImplementedError(type(self))


_monitors: dict[str, type] = {}
_T = TypeVar("_T")


def register_monitor(name: str):
def fn(cls: _T) -> _T:
"""Registers a monitor into a dict of global monitors with reference to its class type."""
if name in _monitors:
raise ValueError(f"Monitor {name} is already registered.")
_monitors[name] = cls
return cls

return fn


def define_flags(**kwargs):
"""Common monitoring flags."""

flags.DEFINE_string(
"monitor_type",
None,
"The monitor type. It can be a monitor name, e.g. `GoodputMonitor`, or "
"a module paired with a monitor name, e.g. `my.module:my_monitor`.",
**kwargs,
)
flags.DEFINE_multi_string(
"monitor_spec",
[],
"Monitor spec provided as key=value. "
"Refer to each monitor's `from_flags` method docstring for details.",
**kwargs,
)


global_monitor: Optional[Monitor] = None


def initialize(fv: flags.FlagValues):
"""Initializes the monitor from flags."""
global global_monitor
if not fv.monitor_type:
logging.info("No monitor type specified, skipping monitoring initialize().")
return
if global_monitor is None:
# Infer module from monitor_type.
parts = fv.monitor_type.split(":", 1)
if len(parts) > 1:
logging.info("Registering monitors in %s", parts[0])
importlib.import_module(parts[0])
if monitor_class := _monitors.get(parts[-1], None):
# This will instantiate a specific monitor of monitor_type if supported.
global_monitor = monitor_class.from_flags(fv=fv)
else:
raise NotImplementedError(
f"Monitor type: {fv.monitor_type} is not supported. "
f"Supported types are: {sorted(list(_monitors.keys()))}\n"
"You can also specify a specific module to identify the monitor "
"(e.g., `my.module:my_monitor`)."
)
logging.info("Initialized global monitor: %s", global_monitor)
else:
logging.warning(
"Monitor %s is already initialized, ignoring monitoring initialize().",
global_monitor,
)


def start_monitoring():
"""Begins monitoring events as per global monitor functionality."""
if global_monitor is None:
logging.log_first_n(logging.INFO, "No Monitor configured, no events will be monitored.", 1)
else:
global_monitor.start_monitoring()
logging.info("Starting monitoring of events using global monitor: %s", global_monitor)
8 changes: 8 additions & 0 deletions axlearn/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__(
utils.validate_float_dtype(cfg.train_dtype)

# Create the device mesh.
self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT)
if devices is None:
self._step_log(
"Devices: global=%s local=%s %s",
Expand Down Expand Up @@ -304,6 +305,7 @@ def __init__(
model=self.model,
model_param_partition_specs=model_param_partition_specs,
)
self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved

@property
def step(self):
Expand Down Expand Up @@ -736,6 +738,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
# Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`.
self.restore_checkpoint(restore_step=None)

self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION)
if self.step is None:
# If we didn't restore from checkpoint, attempt to build initial state according
# to `cfg.init_state_builder` and initialize the remaining parameters.
Expand All @@ -751,6 +754,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
f.write(str(jax.tree_util.tree_structure(self._trainer_state)))

self._log_trainer_state_stats()
self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
# Log config.
self.summary_writer.log_config(cfg, step=self.step)

Expand Down Expand Up @@ -787,6 +791,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
restore_input_iter = cfg.save_input_iterator
try:
# Try to restore with `input_iter`.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(
Expand All @@ -800,13 +805,15 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
except ValueError as e:
logging.warning(
"Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
restore_input_iter,
e,
)
# Restore with a different restore_input_iter setting.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
restore_input_iter = not restore_input_iter
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
Expand All @@ -821,6 +828,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
if step is not None:
self._step = step
self._trainer_state = TrainerState(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml_goodput_measurement==0.0.2",
"ml-goodput-measurement==0.0.4",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
]
Expand Down