Add Goodput & Badput recording and monitoring support. #783

Open · wants to merge 11 commits into main
62 changes: 59 additions & 3 deletions axlearn/cloud/gcp/measurement.py
@@ -5,24 +5,39 @@
import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import measurement
from axlearn.common.config import maybe_set_config
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config


@measurement.register_recorder("goodput")
class GoodputRecorder(measurement.Recorder):
"""Records overall training goodput."""

Config = measurement.Recorder.Config
@config_class
class Config(measurement.Recorder.Config):
"""Configures GoodputRecorder.

Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
"""Converts flags to a recorder.

`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values.
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in the following TensorBoard configs in the recorder_spec:
- upload_dir: The directory to write TensorBoard data to.
- upload_interval: The time interval in seconds at which to query and upload data
  to TensorBoard.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
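For illustration, the resulting flag-to-config flow can be exercised as below (a sketch mirroring the unit test in measurement_test.py; the job name and gs:// path are made up):

    from absl import flags

    from axlearn.cloud.gcp.measurement import GoodputRecorder
    from axlearn.common import measurement

    fv = flags.FlagValues()
    measurement.define_flags(flag_values=fv)
    fv.set_default(
        "recorder_spec",
        ["name=my-job", "upload_dir=gs://my-bucket/tb", "upload_interval=30"],
    )
    fv.mark_as_parsed()

    recorder = GoodputRecorder.from_flags(fv)
    # recorder.config now carries name="my-job", upload_dir="gs://my-bucket/tb",
    # and upload_interval="30" (cast to int where the monitor is constructed).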
@@ -32,6 +47,7 @@ def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
@@ -49,10 +65,50 @@ def record(self, event: measurement.Event, *args, **kwargs):
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
)
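For context on the new events, a trainer is expected to bracket each phase with paired start/end records, roughly as follows (a sketch only; the actual call sites live in the trainer and are not part of this diff, and input_iterator is hypothetical):

    recorder.record(measurement.Event.START_ACCELERATOR_INIT)
    # ... device and mesh initialization ...
    recorder.record(measurement.Event.END_ACCELERATOR_INIT)

    recorder.record(measurement.Event.START_DATA_LOADING)
    batch = next(input_iterator)  # Hypothetical input pipeline.
    recorder.record(measurement.Event.END_DATA_LOADING)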

def start_monitoring(self, *args, **kwargs):
"""Starts Monitoring of Goodput.

Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
directory.
Note: This function requires initialization of distributed JAX before it is called.
"""
if self._monitor is None:
cfg: GoodputRecorder.Config = self.config
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
monitoring_enabled=(jax.process_index() == 0),
include_badput_breakdown=True,
)
if not self._monitor:
# This could happen if there are internal errors (such as access errors) from GCP services such as Cloud Logging or Cloud Storage.
Contributor: BTW, I just triggered the CI, sorry for not doing so early. (I suspect lines like this will fail pylint for being too long.)

Author: I will re-run pre-commit checks.

logging.log_first_n(
logging.WARNING,
"Goodput upload could not be started. Please check GoodputMonitor logs.",
1,
)
self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard in the background!")
Contributor commented on lines +106 to +114:

Suggested change -- from:

    if not self._monitor:
        # This could happen if there are internal errors (such as access errors)
        # from GCP services such as Cloud Logging or Cloud Storage.
        logging.log_first_n(
            logging.WARNING,
            "Goodput upload could not be started. Please check GoodputMonitor logs.",
            1,
        )
    self._monitor.start_goodput_uploader(*args, **kwargs)
    logging.info("Started Goodput upload to Tensorboard in the background!")

to:

    if self._monitor:
        self._monitor.start_goodput_uploader(*args, **kwargs)
        logging.info("Started Goodput upload to Tensorboard in the background!")
    else:
        # This could happen if there are internal errors (such as access errors)
        # from GCP services such as Cloud Logging or Cloud Storage.
        logging.log_first_n(
            logging.WARNING,
            "Goodput upload could not be started. Please check GoodputMonitor logs.",
            1,
        )

So that we check that self._monitor is valid before invoking start_goodput_uploader.

BTW, it's still unclear to me how self._monitor can be None after we construct the instance of GoodputMonitor. Are we missing a try/catch somewhere? Does the __init__ method of GoodputMonitor raise an exception (that seems a bit unexpected)?
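If GoodputMonitor.__init__ can indeed fail (e.g. on access errors from Cloud Logging or Cloud Storage -- an assumption, not confirmed here), one explicit way to surface that would be a guard like the following sketch inside start_monitoring:

    try:
        self._monitor = goodput_monitoring.GoodputMonitor(
            job_name=cfg.name,
            logger_name=f"goodput_logger_{cfg.name}",
            tensorboard_dir=cfg.upload_dir,
            upload_interval=int(cfg.upload_interval),
            monitoring_enabled=(jax.process_index() == 0),
            include_badput_breakdown=True,
        )
    except Exception as e:  # E.g. GCP service access errors (assumed failure mode).
        self._monitor = None
        logging.warning("Could not construct GoodputMonitor: %s", e)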

41 changes: 38 additions & 3 deletions axlearn/cloud/gcp/measurement_test.py
@@ -16,7 +16,9 @@
class GoodputRecorderTest(parameterized.TestCase):
"""Tests GoodputRecorder."""

@parameterized.parameters(None, ["name=test-name"])
@parameterized.parameters(
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
)
def test_from_flags(self, spec):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
@@ -34,13 +36,46 @@ def test_from_flags(self, spec):
# Recorder is not instantiated until first event.
self.assertIsNone(recorder._recorder)

def test_record(self):
def test_record_and_monitor(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default("recorder_spec", ["name=test-name"])
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._recorder = mock.MagicMock()
recorder.record(measurement.Event.START_JOB)
self.assertTrue(recorder._recorder.record_job_start_time.called)

def test_start_monitoring(self):
Contributor: Does this test the failure scenario?

fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
mock_monitor_instance = mock_goodput_monitor.return_value
recorder.start_monitoring()

# Check that GoodputMonitor was instantiated
mock_goodput_monitor.assert_called_once_with(
job_name="test-name",
logger_name="goodput_logger_test-name",
tensorboard_dir="/test/path/to/upload",
upload_interval=15,
monitoring_enabled=True,
include_badput_breakdown=True,
)

# Ensure that start_goodput_uploader is called on the monitor instance
mock_monitor_instance.start_goodput_uploader.assert_called_once()
self.assertIsNotNone(recorder._monitor)
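A failure-scenario test along the lines the reviewer asks about above could look like this (a sketch; it assumes the suggested guard is adopted so that the uploader is only started when self._monitor is valid):

    def test_start_monitoring_failure(self):
        fv = flags.FlagValues()
        measurement.define_flags(flag_values=fv)
        fv.set_default(
            "recorder_spec",
            ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
        )
        fv.mark_as_parsed()

        recorder = GoodputRecorder.from_flags(fv)
        # Simulate GoodputMonitor construction yielding no usable monitor.
        with mock.patch(
            "ml_goodput_measurement.monitoring.GoodputMonitor", return_value=None
        ):
            recorder.start_monitoring()
        # With the guard in place, no uploader is started and a warning is logged.
        self.assertIsNone(recorder._monitor)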
1 change: 1 addition & 0 deletions axlearn/common/launch_trainer_main.py
@@ -13,6 +13,7 @@ def main(_):
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
launch_trainer.run_trainer(trainer_config)


29 changes: 29 additions & 0 deletions axlearn/common/measurement.py
@@ -18,11 +18,23 @@ class Event(enum.Enum):
START_JOB: Start of job.
END_JOB: End of job.
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
START_TRAINING_PREPARATION: Start of training preparation.
END_TRAINING_PREPARATION: End of training preparation.
START_DATA_LOADING: Start of data loading.
END_DATA_LOADING: End of data loading.
"""

START_JOB = "START_JOB"
END_JOB = "END_JOB"
START_STEP = "START_STEP"
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
START_DATA_LOADING = "START_DATA_LOADING"
END_DATA_LOADING = "END_DATA_LOADING"


class Recorder(Configurable):
@@ -47,6 +59,10 @@ def record(self, event: Event, *args, **kwargs):
"""Records an event with the given name."""
raise NotImplementedError(type(self))

def start_monitoring(self, **kwargs):
"""Starts computing and uploading metrics at some configured interval in the background."""
pass
Contributor: Let's raise NotImplementedError(type(self)) and let subclasses decide whether to implement -- it should be fairly straightforward for a subclass to decide to not monitor, but we want the decision to be explicit.

Contributor: Hmmm, it was originally raising NotImplementedError(type(self)), but I suggested changing it to pass to avoid breaking subclasses upon an axlearn bump.

Do we have tests to catch such breakage?

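For reference, the explicit opt-in variant under discussion would read (a sketch of the reviewer's suggestion, not what this PR currently does):

    def start_monitoring(self, **kwargs):
        """Starts computing and uploading metrics at some configured interval in the background."""
        raise NotImplementedError(type(self))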


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
@@ -120,3 +136,16 @@ def record_event(event: Event):
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
else:
global_recorder.record(event)


def start_monitoring():
"""Begins monitoring events as per global monitor functionality."""
if global_recorder is None:
logging.log_first_n(
logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
)
else:
global_recorder.start_monitoring()
logging.info(
"Starting monitoring of events using global recorder's monitor: %s", global_recorder
)
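Taken together with the launch_trainer_main.py change above, the intended top-level sequence is roughly the following (a sketch; measurement.initialize is assumed to be the existing helper that builds global_recorder from flags):

    from axlearn.common import measurement

    measurement.initialize(fv)  # Assumed existing setup of measurement.global_recorder.
    measurement.record_event(measurement.Event.START_JOB)
    measurement.start_monitoring()  # Logs and no-ops if no recorder is configured.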
7 changes: 7 additions & 0 deletions axlearn/common/measurement_test.py
@@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
measurement.record_event(measurement.Event.START_JOB)
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])

# Ensure that start_monitoring does not fail.
with mock.patch.object(
measurement.global_recorder, "start_monitoring"
) as mock_start_monitoring:
measurement.start_monitoring()
mock_start_monitoring.assert_called_once()
@@ -10,4 +10,8 @@ class DummyRecorder(measurement.Recorder):
@classmethod
def from_flags(cls, fv) -> measurement.Recorder:
del fv
return cls.default_config().set(name="dummy_recorder").instantiate()
return (
cls.default_config()
.set(name="dummy_recorder", upload_dir="/dummy/upload_dir", upload_interval=15)
.instantiate()
)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -92,7 +92,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml_goodput_measurement==0.0.2",
"ml-goodput-measurement==0.0.4",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info