Draft: Refactor JobSet for Pathways #918

Draft
wants to merge 64 commits into base: main

jiya-zhang
Contributor

The previous PR #746 should now be deprecated.

The current PR does the following:

  • Adds support for v6e. Users can now submit jobs on v6e.
  • Adds support for Pathways jobs.
  • Includes additional flags for testing purposes. For example, adding "--pdbs=1" to your command sets the per-device batch size.

Please note that this PR is still undergoing testing. Opening it as a draft to unblock @changlan.

jesus-orozco and others added 30 commits September 9, 2024 17:06
@@ -718,7 +718,7 @@ def infer_tpu_workers(tpu_type: str) -> int:
     tpu_version, tpu_cores = match.groups()
     if tpu_version in {"v3", "v4", "v5p"}:
         return int(tpu_cores) // 8
-    if tpu_version in {"v5litepod"}:
+    if tpu_version in {"v5litepod", "v6e"}:
Contributor

Can you rebase main? v6e is already supported

Contributor Author

Done

@jiya-zhang
Contributor Author

cc @jesus-orozco

@jiya-zhang changed the title from "Draft: Refactored JobSet for Pathways and v6e" to "Draft: Refactor JobSet for Pathways" on Jan 14, 2025
axlearn/experiments/text/gpt/fuji.py (outdated, resolved)
axlearn/common/trainer.py (outdated, resolved)
axlearn/cloud/gcp/utils.py (outdated, resolved)
axlearn/common/launch.py (outdated, resolved)
axlearn/common/launch_trainer.py (outdated, resolved)
and "jax_backend proxy" in self.config.command
)

def _import_modules(self):
Contributor

Can you explain why we need to import pathwaysutils on the host that submits GKE jobs?

Contributor

This seems like a workaround left over from previous iterations of testing. The host doesn't need to import pathwaysutils, since the library is only needed while the main training loop executes. Removing it from here.

]

if self.using_pathways:
staging_location = f"{cfg.output_dir}/pathways-staging"
Contributor

What is this location for?

Contributor

At runtime, Pathways maintains a compilation cache. As part of its internal processing, it creates a lifecycle rule on the bucket where this cache is uploaded, so that the cached data is deleted automatically.

Contributor

@changlan changlan Jan 14, 2025

Thanks. I assume GCS read/write permission would be sufficient? In my earlier testing, I got this message:

Persistent compilation cache cannot initialize storage "gcs": PERMISSION_DENIED: Permanent error, with a last message of [REDACTED_SERVICE_ACCOUNT] does not have storage.buckets.update access to the Google Cloud Storage bucket. Permission 'storage.buckets.update' denied on resource (or it may not exist).

The issue was only resolved after I granted the storage.buckets.update permission to the service account. Can you please look into this?

Contributor

That would be related to the object lifecycle rule that Pathways implements in the backend. It does require storage.buckets.update to create the rule in the specified bucket.
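
For context, an object lifecycle rule is bucket-level configuration, which is why setting one needs storage.buckets.update rather than object read/write access. Below is a minimal sketch of what such a rule could look like, written as a plain dictionary in the JSON shape GCS uses for lifecycle configuration. The helper name, the 30-day age, and the prefix are assumptions for illustration; the values Pathways actually uses are not stated in this thread.

```python
import json


def build_delete_rule(age_days: int, prefix: str) -> dict:
    """Build a GCS lifecycle configuration that deletes old objects under a prefix.

    Hypothetical helper: the age and prefix that Pathways actually uses are
    not stated in this thread.
    """
    return {
        "rule": [
            {
                "action": {"type": "Delete"},
                "condition": {"age": age_days, "matchesPrefix": [prefix]},
            }
        ]
    }


# Assumed values, chosen only to make the example concrete.
lifecycle_config = build_delete_rule(age_days=30, prefix="pathways-staging/")
print(json.dumps(lifecycle_config, indent=2))
```

Applying a configuration like this to a bucket is a bucket metadata update, which matches the PERMISSION_DENIED error quoted above.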

Contributor

AFAIU we don't typically grant storage.buckets.update to compute service accounts because of security concerns. Can you elaborate on the lifecycle rules and why they are required for Pathways?

cc @Ethanlm

"""
cfg: TPUGKEJob.Config = self.config
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
staging_location = f"{cfg.output_dir}/pathways-staging/tmp"
Contributor

Why is this different from the staging_location above?

Contributor

Good catch, both should point to the same location.
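
One way to keep the two in sync is to derive the path in a single place. A minimal sketch, assuming a hypothetical helper named `pathways_staging_location` and a made-up bucket path; only the `pathways-staging` suffix comes from the diff:

```python
def pathways_staging_location(output_dir: str) -> str:
    """Return the single staging location derived from the job's output dir.

    Hypothetical helper; the name is not from the PR.
    """
    return f"{output_dir.rstrip('/')}/pathways-staging"


# Both call sites go through the helper, so the two paths cannot drift apart.
first_staging = pathways_staging_location("gs://my-bucket/my-job")
second_staging = pathways_staging_location("gs://my-bucket/my-job/")
print(first_staging)
```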

request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
if not self.using_pathways:
Contributor

Can you elaborate on why this change is necessary?

# brittle implementation
return (
"pathwaysutils" in self.config.import_modules
and "jax_backend proxy" in self.config.command
Contributor

@Ethanlm Ethanlm Jan 15, 2025

Is "jax_backend proxy" in self.config.command sufficient? Why do we also need "pathwaysutils" in self.config.import_modules?

On the other hand, checking "jax_backend proxy" in self.config.command is error-prone. For example, if the command passes the flag as --jax_backend=proxy, the check will not match.

I think "using pathways or not" should be an explicit flag, which is much clearer. Another benefit is that we can still get a Pathways-enabled JobSet even if the command doesn't contain "jax_backend proxy" (for example, "sleep infinity"), for debugging and other purposes.
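
To illustrate the concern, here is a small sketch that runs the substring check from the diff against a few plausible command strings. The sample commands are made up for illustration, and the check is lifted out of the class into a free function.

```python
def detects_pathways(command: str) -> bool:
    """The substring check under review, lifted out of the class for illustration."""
    return "jax_backend proxy" in command


# Made-up commands covering the cases raised in the review.
commands = [
    "python3 -m trainer --jax_backend proxy",  # space-separated flag value
    "python3 -m trainer --jax_backend=proxy",  # equals-separated flag value
    "sleep infinity",                          # debugging job without the flag
]
results = {cmd: detects_pathways(cmd) for cmd in commands}
for cmd, detected in results.items():
    print(f"{detected!s:5}  {cmd}")
```

Only the first command is detected. The second uses Pathways but is missed, and the third cannot opt in at all.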

def _is_pathways_used(self) -> bool:
    # identify if a job is configured to use pathways by
    # checking jax_backend flag and optional import for pathways utils
    return "jax_backend=proxy" in self.config.command.replace(" ", "=")
Contributor

@Ethanlm Ethanlm Jan 16, 2025

Can we add a flag on TPUGKEJob to decide if pathways should be used, instead of checking the command? This is quite hacky.

+1. This requires user binaries to define this flag even though they don't use it.
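
A minimal sketch of the explicit-flag approach suggested in this thread, using a plain dataclass as a stand-in for the job config. The class names and the `enable_pathways` field are hypothetical and are not axlearn's actual API.

```python
from dataclasses import dataclass


@dataclass
class JobConfig:
    """Stand-in for the job config; not the real TPUGKEJob.Config."""

    command: str
    # Hypothetical explicit switch; off by default so existing jobs are unaffected.
    enable_pathways: bool = False


class Job:
    def __init__(self, config: JobConfig):
        self.config = config

    @property
    def using_pathways(self) -> bool:
        # Decided by configuration alone, independent of the command string.
        return self.config.enable_pathways


debug_job = Job(JobConfig(command="sleep infinity", enable_pathways=True))
plain_job = Job(JobConfig(command="python3 -m trainer --jax_backend=proxy"))
print(debug_job.using_pathways, plain_job.using_pathways)
```

With a switch like this, the user binary no longer has to define a jax_backend flag just so the launcher can detect Pathways, and a debugging command can still get a Pathways-enabled JobSet.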

5 participants