Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Goodput & Badput recording and monitoring support. #783

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
63 changes: 60 additions & 3 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,39 @@
import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import measurement
from axlearn.common.config import maybe_set_config
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config


@measurement.register_recorder("goodput")
class GoodputRecorder(measurement.Recorder):
"""Records overall training goodput."""

Config = measurement.Recorder.Config
@config_class
class Config(measurement.Recorder.Config):
"""Configures GoodputRecorder.

Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
"""Converts flags to a recorder.

`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values.
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in the following Tensorboard configs in the recorder_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
to Tensorboard.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
Expand All @@ -32,6 +47,7 @@ def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
Expand All @@ -49,10 +65,51 @@ def record(self, event: measurement.Event, *args, **kwargs):
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
)

def start_monitoring(self, *args, **kwargs):
    """Starts monitoring of Goodput.

    Lazily instantiates ml-goodput-measurement's GoodputMonitor on the first call, configured
    from `self.config` (`name`, `upload_dir`, `upload_interval`), and then starts its uploader
    so that Goodput and Badput are calculated asynchronously at the upload interval and
    uploaded to the specified TensorBoard directory. Subsequent calls reuse the same monitor.

    Note: This function requires initialization of distributed JAX before it is called, since
    it queries `jax.process_index()` to enable monitoring only on process 0.

    Args:
        *args: Positional arguments forwarded to `GoodputMonitor.start_goodput_uploader`.
        **kwargs: Keyword arguments forwarded to `GoodputMonitor.start_goodput_uploader`.
    """
    if self._monitor is None:
        cfg: GoodputRecorder.Config = self.config
        self._monitor = goodput_monitoring.GoodputMonitor(
            job_name=cfg.name,
            logger_name=f"goodput_logger_{cfg.name}",
            tensorboard_dir=cfg.upload_dir,
            # Flag-parsed values may arrive as strings, so coerce to int explicitly.
            upload_interval=int(cfg.upload_interval),
            monitoring_enabled=(jax.process_index() == 0),
            include_badput_breakdown=True,
        )

    # At this point the monitor always exists: construction above either succeeded or raised.
    # A separate "could not be started" fallback branch would therefore be unreachable.
    self._monitor.start_goodput_uploader(*args, **kwargs)
    logging.info("Started Goodput upload to Tensorboard in the background!")
markblee marked this conversation as resolved.
Show resolved Hide resolved
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
41 changes: 38 additions & 3 deletions axlearn/cloud/gcp/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class GoodputRecorderTest(parameterized.TestCase):
"""Tests GoodputRecorder."""

@parameterized.parameters(None, ["name=test-name"])
@parameterized.parameters(
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
)
def test_from_flags(self, spec):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
Expand All @@ -34,13 +36,46 @@ def test_from_flags(self, spec):
# Recorder is not instantiated until first event.
self.assertIsNone(recorder._recorder)

def test_record(self):
def test_record_and_monitor(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default("recorder_spec", ["name=test-name"])
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._recorder = mock.MagicMock()
recorder.record(measurement.Event.START_JOB)
self.assertTrue(recorder._recorder.record_job_start_time.called)

def test_start_monitoring(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test the failure scenario?

fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
mock_monitor_instance = mock_goodput_monitor.return_value
recorder.start_monitoring()

# Check that GoodputMonitor was instantiated
mock_goodput_monitor.assert_called_once_with(
job_name="test-name",
logger_name="goodput_logger_test-name",
tensorboard_dir="/test/path/to/upload",
upload_interval=15,
monitoring_enabled=True,
include_badput_breakdown=True,
)

# Ensure that start_goodput_uploader is called on the monitor instance
mock_monitor_instance.start_goodput_uploader.assert_called_once()
self.assertIsNotNone(recorder._monitor)
1 change: 1 addition & 0 deletions axlearn/common/launch_trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main(_):
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
launch_trainer.run_trainer(trainer_config)


Expand Down
29 changes: 29 additions & 0 deletions axlearn/common/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@ class Event(enum.Enum):
START_JOB: Start of job.
END_JOB: End of job.
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
START_TRAINING_PREPARATION: Start of training preparation.
END_TRAINING_PREPARATION: End of training preparation.
START_DATA_LOADING: Start of data loading.
END_DATA_LOADING: End of data loading.
"""

START_JOB = "START_JOB"
END_JOB = "END_JOB"
START_STEP = "START_STEP"
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
START_DATA_LOADING = "START_DATA_LOADING"
END_DATA_LOADING = "END_DATA_LOADING"


class Recorder(Configurable):
Expand All @@ -47,6 +59,10 @@ def record(self, event: Event, *args, **kwargs):
"""Records an event with the given name."""
raise NotImplementedError(type(self))

def start_monitoring(self, *args, **kwargs):
    """Starts computing and uploading metrics at some configured interval in the background.

    Subclasses must override this method. Positional and keyword arguments are accepted so
    that the base signature is compatible with overrides that forward them to an underlying
    monitor implementation.

    Raises:
        NotImplementedError: Always, with the concrete type of `self` as the argument.
    """
    raise NotImplementedError(type(self))
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
Expand Down Expand Up @@ -120,3 +136,16 @@ def record_event(event: Event):
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
else:
global_recorder.record(event)


def start_monitoring():
    """Starts monitoring via the global recorder, if one has been configured.

    When no global recorder is set, this is a no-op apart from a one-time informational log.
    """
    if global_recorder is not None:
        global_recorder.start_monitoring()
        logging.info(
            "Starting monitoring of events using global recorder's monitor: %s", global_recorder
        )
        return

    # No recorder configured: log once and do nothing.
    logging.log_first_n(
        logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
    )
7 changes: 7 additions & 0 deletions axlearn/common/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
measurement.record_event(measurement.Event.START_JOB)
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])

# Ensure that start_monitoring does not fail.
with mock.patch.object(
measurement.global_recorder, "start_monitoring"
) as mock_start_monitoring:
measurement.start_monitoring()
mock_start_monitoring.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ class DummyRecorder(measurement.Recorder):
@classmethod
def from_flags(cls, fv) -> measurement.Recorder:
del fv
return cls.default_config().set(name="dummy_recorder").instantiate()
return (
cls.default_config()
.set(name="dummy_recorder", upload_dir="/dummy/upload_dir", upload_interval=15)
.instantiate()
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml_goodput_measurement==0.0.2",
"ml-goodput-measurement==0.0.4",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info
Expand Down