-
Notifications
You must be signed in to change notification settings - Fork 281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Refactor JobSet for Pathways #918
base: main
Are you sure you want to change the base?
Conversation
…ific tagged version
axlearn/cloud/gcp/tpu.py
Outdated
@@ -718,7 +718,7 @@ def infer_tpu_workers(tpu_type: str) -> int: | |||
tpu_version, tpu_cores = match.groups() | |||
if tpu_version in {"v3", "v4", "v5p"}: | |||
return int(tpu_cores) // 8 | |||
if tpu_version in {"v5litepod"}: | |||
if tpu_version in {"v5litepod", "v6e"}: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you rebase main? v6e is already supported
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
and "jax_backend proxy" in self.config.command | ||
) | ||
|
||
def _import_modules(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why we need to import pathwaysutils
on the host that submits GKE jobs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like a workaround from previous iterations of testing. The host doesn't need to import pathwaysutils (this library is only needed during the main training loop execution), removing it from here.
] | ||
|
||
if self.using_pathways: | ||
staging_location = f"{cfg.output_dir}/pathways-staging" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this location for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During runtime, Pathways maintains a compilation cache. Part of its internal processing is to create a lifecycle rule in the bucket where this cache is uploaded to automatically delete the data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - I assume the GCS read/write permission would be sufficient? In my earlier testing, I got the message
Persistent compilation cache cannot initialize storage "gcs": PERMISSION_DENIED: Permanent error, with a last message of [REDACTED_SERVICE_ACCOUNT] does not have storage.buckets.update access to the Google Cloud Storage bucket. Permission 'storage.buckets.update' denied on resource (or it may not exist).
The issue was only resolved after I granted the storage.buckets.update
permission to the service account. Can you please look into this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would be related to the object lifecycle rule that Pathways implements in the backend. It does require storage.buckets.update to create the rule in the specified bucket.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AFAIU we don't typically grant storage.buckets.update
to compute service accounts because of security concerns. Can you elaborate on the lifecycle rules and why are they required for Pathways?
cc @Ethanlm
axlearn/cloud/gcp/job.py
Outdated
""" | ||
cfg: TPUGKEJob.Config = self.config | ||
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type] | ||
staging_location = f"{cfg.output_dir}/pathways-staging/tmp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this different from the staging_location
above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, both should point to the same location
request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE | ||
resources["limits"]["memory"] = f"{machine_memory_gi}Gi" | ||
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"} | ||
if not self.using_pathways: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate why this change is necessary?
axlearn/cloud/gcp/job.py
Outdated
# brittle implementation | ||
return ( | ||
"pathwaysutils" in self.config.import_modules | ||
and "jax_backend proxy" in self.config.command |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "jax_backend proxy" in self.config.command sufficient? Why do we need "pathwaysutils" in self.config.import_modules
?
On the other hand, using "jax_backend proxy" in self.config.command
can be error-prone. For example, the command could include --jax_backend=proxy
. The check here will not work properly.
I think "using pathways or not" should be an explicit flag. It is much more clear. Another benefit is that we can still get a pathways enabled jobset, even if the command doesn't contain "jax_backend proxy", for example, "sleep infinity", for debugging and other purposes
def _is_pathways_used(self) -> bool: | ||
# identify if a job is configured to use pathways by | ||
# checking jax_backend flag and optional import for pathways utils | ||
return "jax_backend=proxy" in self.config.command.replace(" ", "=") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a flag on TPUGKEJob to decide if pathways should be used, instead of checking the command? This is quite hacky.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, this requires user binaries to define this flag. Even though they don't use it.
The previous PR #746 should now be deprecated.
The current PR does the following:
Please note this PR is still undergoing testing. Opening draft to unblock @changlan