diff --git a/Dockerfile b/Dockerfile
index 4a8021af5..7d170d206 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -97,12 +97,62 @@ COPY . .
 # GPU container spec.                                                          #
 ################################################################################
 
-FROM base AS gpu
+# NOTE: Building this stage on top of `base` fails at runtime with
+# "INTERNAL: No valid engine configs for Matmul", so the GPU image is built
+# from the NVIDIA CUDA image instead. The previous stage is kept for reference:
+# FROM base AS gpu
+#
+# RUN apt-get update && apt-get install -y ibverbs-utils
+# # TODO(markblee): Support extras.
+# ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+# RUN pip install .[core,gpu]
+# RUN pip install -U "jax[gpu]==0.4.37" "jax==0.4.37" "jaxlib==0.4.36" \
+#     -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+# COPY . .
+
+FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 AS gpu
+
+# Begin copy of the original `base` stage.
+RUN apt-get update
+RUN apt-get install -y apt-transport-https ca-certificates gnupg curl \
+    gcc g++ python3 python3-venv ibverbs-utils
+RUN ln -s /usr/bin/python3 /usr/bin/python
+
+# Install git.
+RUN apt-get install -y git
+
+# Install gcloud. https://cloud.google.com/sdk/docs/install
+RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
+    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
+    apt-get update -y && apt-get install google-cloud-cli -y
+
+# Install screen and other utils for launch script.
+RUN apt-get install -y jq screen ca-certificates
+
+# Setup.
+RUN mkdir -p /root
+WORKDIR /root
+# Introduce the minimum set of files for install.
+COPY README.md README.md
+COPY pyproject.toml pyproject.toml
+RUN mkdir axlearn && touch axlearn/__init__.py
+# Setup venv to suppress pip warnings.
+ENV VIRTUAL_ENV=/opt/venv
+RUN python -m venv $VIRTUAL_ENV
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
+# Install dependencies.
+RUN pip install flit
+RUN pip install --upgrade pip
+# End copy of the original `base` stage.
+
+RUN apt-get update -y && apt-get install -y google-perftools glibc-tools
 
 # TODO(markblee): Support extras.
 ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 RUN pip install .[core,gpu]
 COPY . .
+RUN pip install -U "jax[gpu]==0.4.38" "jax==0.4.38" "jaxlib==0.4.38"
 
 ################################################################################
 # Final target spec.                                                           #
diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py
index b1ad5e357..4abe9f30e 100644
--- a/axlearn/cloud/gcp/job.py
+++ b/axlearn/cloud/gcp/job.py
@@ -862,8 +862,8 @@ def _execute(self) -> Any:
         )
 
 
-class GPUGKEJob(GKEJob):
-    """A GPU job represented as a k8s JobSet.
+class BaseGPUGKEJob(GKEJob):
+    """The base class for GPU jobs represented as a k8s JobSet.
 
     See also `gke_runner` as an example.
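+
+    Subclasses such as `GPUGKEA3HighJob` and `GPUGKEA3UltraJob` override
+    `_build_main_container`, `_build_init_containers`, and `_build_volumes`
+    to customize networking and mounts per instance type.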
""" @@ -886,7 +886,7 @@ def define_flags(cls, fv: flags.FlagValues): @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: - cfg: GPUGKEJob.Config = super().from_flags(fv, **kwargs) + cfg: BaseGPUGKEJob.Config = super().from_flags(fv, **kwargs) cfg.accelerator.set(instance_type=fv.instance_type, num_replicas=fv.num_replicas) return cfg @@ -896,18 +896,269 @@ def __init__(self, cfg: Config): if bundler_cfg is None or not issubclass(bundler_cfg.klass, BaseDockerBundler): raise NotImplementedError(f"Only docker bundler supported, got: {bundler_cfg}") super().__init__(cfg) - if cfg.gcsfuse_mount: - raise NotImplementedError("GCSFuse is not supported on GKE with GPU.") if cfg.enable_pre_provisioner: raise NotImplementedError("Pre-provisioner is not supported on GKE with GPU.") - instance_type = cfg.accelerator.instance_type - if not instance_type.startswith("gpu-a3-highgpu"): - raise NotImplementedError( - f"The instance type {instance_type} is not supported on GKE with GPU. " - "Only gpu-a3-highgpu-8g is supported." - ) - def _build_a3_sidecar_container(self) -> Nested[Any]: + def _build_init_containers(self) -> list[Nested[Any]]: + """Builds a sidecar container which is required by A3 + for GPU to GPU RDMA like networking. + + Returns: + A nested dict of the sidecar container. + """ + return [] + + def _build_main_container(self) -> Nested[Any]: + return NotImplementedError() + + def _build_volumes(self) -> list[Nested[Any]]: + """Builds a config for volumes.""" + volumes = [ + { + "name": "shared-memory", + "emptyDir": {"medium": "Memory"}, + }, + { + "name": "nvidia-install-dir-host", + "hostPath": {"path": "/home/kubernetes/bin/nvidia/lib64"}, + }, + ] + return volumes + + def _pod_annotations(self) -> dict[str, str]: + # By default try to use compact placement using Kueue Topology Aware scheduling. + # Docs: https://kueue.sigs.k8s.io/docs/concepts/topology_aware_scheduling/ + return {"kueue.x-k8s.io/podset-preferred-topology": "kubernetes.io/hostname"} + + def _build_pod(self) -> Nested[Any]: + """Builds a config for a single Pod, which is a set of containers. + + https://kubernetes.io/docs/concepts/workloads/pods + + Returns: + A nested dict corresponding to a k8s Pod template, including the pod metadata and spec. + """ + cfg: BaseGPUGKEJob.Config = self.config + volumes = self._build_volumes() + annotations = self._pod_annotations() + containers = [self._build_main_container()] + init_containers = self._build_init_containers() + + return dict( + metadata=dict(annotations=annotations), + spec=dict( + terminationGracePeriodSeconds=60, + # Fail if any pod fails, and allow retries to happen at JobSet level. + restartPolicy="Never", + initContainers=init_containers, + hostNetwork=True, + dnsPolicy="ClusterFirstWithHostNet", + containers=containers, + serviceAccountName=cfg.service_account, + volumes=volumes, + ), + ) + + def _build_job(self) -> Nested[Any]: + """Builds a config for a single Job, which is a set of Pods. + + https://kubernetes.io/docs/concepts/workloads/controllers/job/ + + Returns: + A nested dict corresponding to a k8s Job config, including the job metadata and spec. + """ + cfg: BaseGPUGKEJob.Config = self.config + + return dict( + spec=dict( + parallelism=cfg.accelerator.num_replicas, + completions=cfg.accelerator.num_replicas, + backoffLimit=0, # Fail the job if any node fails. Retries happen at JobSet level. + template=self._build_pod(), + ), + ) + + def _build_jobset(self) -> Nested[Any]: + """Builds a config for a JobSet, which is a set of Jobs. 
+
+        https://github.com/kubernetes-sigs/jobset/blob/d49514bee57da8ac9aec2fcea06c3a13c21afeae/docs/concepts/README.md
+
+        Returns:
+            A nested dict corresponding to a k8s JobSet config.
+        """
+        cfg: BaseGPUGKEJob.Config = self.config
+        annotations = {}
+        if cfg.queue:
+            annotations["kueue.x-k8s.io/queue-name"] = cfg.queue
+
+        return dict(
+            metadata=dict(
+                name=cfg.name,
+                annotations=annotations,
+            ),
+            spec=dict(
+                failurePolicy=dict(maxRestarts=cfg.max_tries - 1),
+                replicatedJobs=[
+                    # NOTE: the suffix here impacts how long job names can be.
+                    dict(
+                        name="job",
+                        replicas=1,
+                        template=self._build_job(),
+                    ),
+                ],
+            ),
+        )
+
+    def _delete(self):
+        cfg: BaseGPUGKEJob.Config = self.config
+        # Issues a delete request for the JobSet and proactively deletes its descendants. This is
+        # not fully blocking; after the call returns there can be a delay before everything is
+        # deleted.
+        delete_k8s_jobset(cfg.name, namespace=cfg.namespace)
+
+    def _execute(self) -> Any:
+        """Submits a JobSet to the cluster."""
+        cfg: BaseGPUGKEJob.Config = self.config
+        api_kwargs = custom_jobset_kwargs()
+        custom_object = dict(
+            apiVersion=f"{api_kwargs['group']}/{api_kwargs['version']}",
+            kind="JobSet",
+            **self._build_jobset(),
+        )
+        logging.info("Submitting JobSet body=%s api_kwargs=%s", custom_object, api_kwargs)
+        return k8s.client.CustomObjectsApi().create_namespaced_custom_object(
+            namespace=cfg.namespace,
+            body=custom_object,
+            **api_kwargs,
+        )
+
+
+class GPUGKEA3UltraJob(BaseGPUGKEJob):
+    """An a3-ultragpu-8g GPU job represented as a k8s JobSet."""
+
+    def _build_volumes(self) -> list[Nested[Any]]:
+        volumes = super()._build_volumes()
+        volumes += [
+            {
+                "name": "gib",
+                "hostPath": {"path": "/home/kubernetes/bin/gib"},
+            },
+        ]
+        return volumes
+
+    def _build_main_container(self) -> Nested[Any]:
+        """Builds the config for the container running the job.
+
+        Returns:
+            A nested dict corresponding to a k8s Container config.
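+
+        The container mounts the host NVIDIA and gIB directories and sets XLA and
+        NCCL env vars tuned for a3-ultragpu-8g; user-provided `env_vars` take
+        precedence over these defaults.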
+        """
+        cfg: GPUGKEA3UltraJob.Config = self.config
+
+        volume_mounts = [
+            {"name": "shared-memory", "mountPath": "/dev/shm"},
+            {"name": "nvidia-install-dir-host", "mountPath": "/usr/local/nvidia/lib64"},
+            {"name": "gib", "mountPath": "/usr/local/gib"},
+        ]
+
+        env_vars: dict[str, str] = {}
+        env_vars["DISTRIBUTED_COORDINATOR"] = f"{cfg.name}-job-0-0.{cfg.name}:8080"
+        env_vars["NUM_PROCESSES"] = f"{cfg.accelerator.num_replicas}"
+        env_vars["LD_LIBRARY_PATH"] = "/usr/local/nvidia/lib64"
+
+        default_xla_flags = [
+            # Maxtext XLA flags:
+            # https://github.com/AI-Hypercomputer/gpu-recipes/blob/dc6ef1afc1492f05e5741356f00cf645a9f1b795/src/helm-charts/a3ultra/maxtext-training/templates/maxtext-configmap.yaml#L26-L38
+            "--xla_gpu_enable_latency_hiding_scheduler=true",
+            "--xla_gpu_enable_triton_gemm=false",
+            "--xla_gpu_graph_level=0",
+            "--xla_gpu_all_reduce_combine_threshold_bytes=2147483648",
+            "--xla_gpu_all_gather_combine_threshold_bytes=2147483648",
+            "--xla_gpu_reduce_scatter_combine_threshold_bytes=16777216",
+            "--xla_gpu_enable_pipelined_all_gather=true",
+            "--xla_gpu_enable_pipelined_reduce_scatter=true",
+            "--xla_gpu_enable_pipelined_all_reduce=true",
+            "--xla_gpu_enable_all_gather_combine_by_dim=false",
+            "--xla_gpu_enable_reduce_scatter_combine_by_dim=false",
+            "--xla_disable_hlo_passes=rematerialization",
+            "--xla_gpu_enable_while_loop_double_buffering=true",
+        ]
+        env_vars["XLA_FLAGS"] = " ".join(default_xla_flags)
+
+        # JAX and NCCL settings needed for A3 Ultra networking and performance.
+        env_vars.update(
+            {
+                # Enable automatic PGLE, available since jax 0.4.33.
+                "JAX_ENABLE_PGLE": "True",
+                "JAX_PGLE_PROFILING_RUNS": "3",
+                # Needed for flash attention and auto PGLE to work together.
+                "JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY": "True",
+                # Used in the XLA team's flag sets; its exact effect here is not yet
+                # understood. See maxtext_wrapper.py.
+                "CUDA_DEVICE_MAX_CONNECTIONS": "1",
+                "NVTE_FUSED_ATTN": "1",
+                # Needed to help resolve GPU OOM on fuji v2 70B.
+                "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.85",
+                "TF_FORCE_GPU_ALLOW_GROWTH": "true",
+                "NCCL_DEBUG": "INFO",
+                # The "=" prefix makes NCCL match the interface names exactly.
+                "NCCL_SOCKET_IFNAME": "=eth0,eth1",
+                "NCCL_CROSS_NIC": "0",
+                "NCCL_NET_GDR_LEVEL": "PIX",
+                "NCCL_P2P_NET_CHUNKSIZE": "131072",
+                "NCCL_P2P_PCI_CHUNKSIZE": "131072",
+                "NCCL_P2P_NVL_CHUNKSIZE": "524288",
+                "NCCL_NVLS_CHUNKSIZE": "524288",
+                "NCCL_IB_GID_INDEX": "3",
+                "NCCL_IB_ADAPTIVE_ROUTING": "1",
+                "NCCL_IB_QPS_PER_CONNECTION": "4",
+                "NCCL_IB_TC": "52",
+                "NCCL_IB_FIFO_TC": "84",
+                "NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE": (
+                    "/usr/local/gib/configs/guest_config.txtpb"
+                ),
+                "NCCL_TUNER_CONFIG_PATH": "/usr/local/gib/configs/tuner_config.txtpb",
+            }
+        )
+
+        # Override env vars with user-provided env vars.
+        env_vars.update(cfg.env_vars)
+        # K8s expects each env variable to be a dict.
+        k8s_env_vars = [{"name": name, "value": value} for name, value in env_vars.items()]
+        k8s_env_vars.append(
+            {
+                "name": "PROCESS_ID",
+                "valueFrom": {
+                    "fieldRef": {
+                        "fieldPath": (
+                            "metadata.annotations['batch.kubernetes.io/job-completion-index']"
+                        ),
+                    }
+                },
+            },
+        )
+
+        # An alternative is to source the NCCL env script that ships with gIB
+        # instead of setting the NCCL env vars above:
+        # command = ["bash", "-c", f"source /usr/local/gib/scripts/set_nccl_env.sh; {cfg.command}"]
+        command = ["bash", "-c", cfg.command]
+        return dict(
+            name=cfg.name,
+            image=self._bundler.id(cfg.name),
+            ports=[
+                dict(containerPort=8080),  # Port for MXLA coordinator.
+            ],
+            securityContext=dict(privileged=True),
+            # TODO(markblee): Improve SIGTERM behavior for command.
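+            # `cfg.command` runs under `bash -c`, so it may contain shell syntax.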
+            command=command,
+            resources=dict(limits={"nvidia.com/gpu": "8"}),
+            env=k8s_env_vars,
+            volumeMounts=volume_mounts,
+        )
+
+
+class GPUGKEA3HighJob(BaseGPUGKEJob):
+    """An a3-highgpu-8g GPU job represented as a k8s JobSet."""
+
+    def _build_init_containers(self) -> list[Nested[Any]]:
+        return [
+            self._build_tcpx_sidecar_container(),
+            self._build_tcpx_nccl_plugin_init_container(),
+        ]
+
+    def _build_tcpx_sidecar_container(self) -> Nested[Any]:
         """Builds a sidecar container which is required by A3
         for GPU to GPU RDMA like networking.
 
@@ -941,6 +1192,7 @@ def _build_a3_sidecar_container(self) -> Nested[Any]:
             command=command,
             env=[{"name": "LD_LIBRARY_PATH", "value": "/usr/local/nvidia/lib64"}],
             volumeMounts=volume_mounts,
+            restartPolicy="Always",  # Run as a k8s native sidecar container.
         )
 
     def _build_main_container(self) -> Nested[Any]:
@@ -949,7 +1201,7 @@ def _build_main_container(self) -> Nested[Any]:
 
         Returns:
             A nested dict corresponding to a k8s Container config.
         """
-        cfg: GPUGKEJob.Config = self.config
+        cfg: GPUGKEA3HighJob.Config = self.config
 
         volume_mounts = [
             {"name": "shared-memory", "mountPath": "/dev/shm"},
@@ -1054,7 +1306,7 @@ def _build_main_container(self) -> Nested[Any]:
             volumeMounts=volume_mounts,
         )
 
-    def _build_a3_init_container(self) -> Nested[Any]:
+    def _build_tcpx_nccl_plugin_init_container(self) -> Nested[Any]:
         """Builds a config for a single container."""
         volume_mounts = [
             {
@@ -1074,17 +1326,11 @@ def _build_a3_init_container(self) -> Nested[Any]:
             volumeMounts=volume_mounts,
         )
 
-    def _build_volumes(self) -> Nested[Any]:
+    def _build_volumes(self) -> list[Nested[Any]]:
         """Builds a config for volumes."""
-        volumes = [
-            {
-                "name": "shared-memory",
-                "emptyDir": {"medium": "Memory"},
-            },
-            {
-                "name": "nvidia-install-dir-host",
-                "hostPath": {"path": "/home/kubernetes/bin/nvidia/lib64"},
-            },
+
+        volumes = super()._build_volumes()
+        volumes += [
             {
                 "name": "tcpx-socket",
                 "emptyDir": {},
@@ -1097,110 +1343,6 @@ def _build_volumes(self) -> Nested[Any]:
 
         return volumes
 
-    def _build_pod(self) -> Nested[Any]:
-        """Builds a config for a single Pod, which is a set of containers.
-
-        https://kubernetes.io/docs/concepts/workloads/pods
-
-        Returns:
-            A nested dict corresponding to a k8s Pod template, including the pod metadata and spec.
-        """
-        cfg: GPUGKEJob.Config = self.config
-        volumes = self._build_volumes()
-        annotations = {
-            "kubectl.kubernetes.io/default-container": cfg.name,
-        }
-
-        containers = [self._build_main_container(), self._build_a3_sidecar_container()]
-        init_containers = [self._build_a3_init_container()]
-
-        return dict(
-            metadata=dict(annotations=annotations),
-            spec=dict(
-                terminationGracePeriodSeconds=60,
-                # Fail if any pod fails, and allow retries to happen at JobSet level.
-                restartPolicy="Never",
-                initContainers=init_containers,
-                hostNetwork=True,
-                dnsPolicy="ClusterFirstWithHostNet",
-                containers=containers,
-                serviceAccountName=cfg.service_account,
-                volumes=volumes,
-            ),
-        )
-
-    def _build_job(self) -> Nested[Any]:
-        """Builds a config for a single Job, which is a set of Pods.
-
-        https://kubernetes.io/docs/concepts/workloads/controllers/job/
-
-        Returns:
-            A nested dict corresponding to a k8s Job config, including the job metadata and spec.
-        """
-        cfg: GPUGKEJob.Config = self.config
-
-        return dict(
-            spec=dict(
-                parallelism=cfg.accelerator.num_replicas,
-                completions=cfg.accelerator.num_replicas,
-                backoffLimit=0,  # Fail the job if any node fails. Retries happen at JobSet level.
- template=self._build_pod(), - ), - ) - - def _build_jobset(self) -> Nested[Any]: - """Builds a config for a JobSet, which is a set of Jobs. - - https://github.com/kubernetes-sigs/jobset/blob/d49514bee57da8ac9aec2fcea06c3a13c21afeae/docs/concepts/README.md - - Returns: - A nested dict corresponding to a k8s JobSet config. - """ - cfg: GPUGKEJob.Config = self.config - annotations = {} - if cfg.queue: - annotations["kueue.x-k8s.io/queue-name"] = cfg.queue - - return dict( - metadata=dict( - name=cfg.name, - annotations=annotations, - ), - spec=dict( - failurePolicy=dict(maxRestarts=cfg.max_tries - 1), - replicatedJobs=[ - # NOTE: the suffix here impacts how long job names can be. - dict( - name="job", - replicas=1, - template=self._build_job(), - ), - ], - ), - ) - - def _delete(self): - cfg: GPUGKEJob.Config = self.config - # Issues a delete request for the JobSet and proactively delete its descendants. This is not - # fully blocking; after the call returns there can be a delay before everything is deleted. - delete_k8s_jobset(cfg.name, namespace=cfg.namespace) - - def _execute(self) -> Any: - """Submits a JobSet to the cluster.""" - cfg: GPUGKEJob.Config = self.config - api_kwargs = custom_jobset_kwargs() - custom_object = dict( - apiVersion=f"{api_kwargs['group']}/{api_kwargs['version']}", - kind="JobSet", - **self._build_jobset(), - ) - logging.info("Submitting JobSet body=%s api_kwargs=%s", custom_object, api_kwargs) - return k8s.client.CustomObjectsApi().create_namespaced_custom_object( - namespace=cfg.namespace, - body=custom_object, - **api_kwargs, - ) - class CPUJob(GCPJob): """Executes arbitrary commands on CPU VMs.""" diff --git a/axlearn/cloud/gcp/job_test.py b/axlearn/cloud/gcp/job_test.py index bec1c43ee..5e883fdf1 100644 --- a/axlearn/cloud/gcp/job_test.py +++ b/axlearn/cloud/gcp/job_test.py @@ -592,13 +592,13 @@ def _job_config( ): with mock_gcp_settings([job.__name__, bundler.__name__], self._mock_settings): fv = flags.FlagValues() - job.GPUGKEJob.define_flags(fv) + job.GPUGKEA3HighJob.define_flags(fv) if service_account: fv.set_default("service_account", service_account) if num_replicas: fv.set_default("num_replicas", num_replicas) fv.mark_as_parsed() - cfg = job.GPUGKEJob.from_flags(fv) + cfg = job.GPUGKEA3HighJob.from_flags(fv) cfg.bundler = bundler_cls.from_spec([], fv=fv).set(image="test-image") cfg.accelerator.instance_type = "gpu-a3-highgpu-8g-256" cfg.queue = queue @@ -646,8 +646,8 @@ class Config(Bundler.Config): retry_interval=1, name="test", ) - gke_job: job.GPUGKEJob = cfg.instantiate() - job_cfg: job.GPUGKEJob.Config = gke_job.config + gke_job: job.GPUGKEA3HighJob = cfg.instantiate() + job_cfg: job.GPUGKEA3HighJob.Config = gke_job.config self.assertEqual("gpu-a3-highgpu-8g-256", job_cfg.accelerator.instance_type) if num_replicas is None: self.assertEqual(1, job_cfg.accelerator.num_replicas) @@ -666,7 +666,7 @@ def test_build_pod( num_replicas: Optional[int] = None, ): with self._job_config(bundler_cls, env_vars=env_vars, num_replicas=num_replicas) as cfg: - gke_job: job.GPUGKEJob = cfg.set( + gke_job: job.GPUGKEA3HighJob = cfg.set( name="test", ).instantiate() # pylint: disable-next=protected-access @@ -700,7 +700,7 @@ def test_build_jobset( queue: Optional[str] = None, ): with self._job_config(bundler_cls, queue=queue) as cfg: - gke_job: job.GPUGKEJob = cfg.set( + gke_job: job.GPUGKEA3HighJob = cfg.set( name="test", ).instantiate() # pylint: disable-next=protected-access diff --git a/axlearn/cloud/gcp/jobs/gke_runner.py b/axlearn/cloud/gcp/jobs/gke_runner.py 
index a1dc18392..d1c1fbaab 100644
--- a/axlearn/cloud/gcp/jobs/gke_runner.py
+++ b/axlearn/cloud/gcp/jobs/gke_runner.py
@@ -48,7 +48,14 @@
 from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler
 from axlearn.cloud.gcp.config import gcp_settings
 from axlearn.cloud.gcp.event_queue import event_queue_from_config
-from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GCPJob, GKEJob, GPUGKEJob, TPUGKEJob
+from axlearn.cloud.gcp.job import (
+    BASTION_JOB_VERSION_LABEL,
+    GCPJob,
+    GKEJob,
+    GPUGKEA3HighJob,
+    GPUGKEA3UltraJob,
+    TPUGKEJob,
+)
 from axlearn.cloud.gcp.jobs import runner_utils
 from axlearn.cloud.gcp.jobs.tpu_runner import with_tpu_training_defaults
 from axlearn.cloud.gcp.node_pool import (
@@ -530,14 +537,20 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
 
 class GPUGKERunnerJob(GKERunnerJob):
-    """A GKERunnerJob that uses GPUGKEJob."""
+    """A GKERunnerJob whose inner job class is selected by GPU instance type."""
 
-    inner = GPUGKEJob
+    @classmethod
+    def class_from_instance_type(cls, instance_type: str):
+        # NOTE: mutates the class attribute `inner` based on the instance type.
+        if instance_type.startswith("gpu-a3-high"):
+            cls.inner = GPUGKEA3HighJob
+        elif instance_type.startswith("gpu-a3-ultra"):
+            cls.inner = GPUGKEA3UltraJob
+        else:
+            raise app.UsageError(f"Unsupported GPU instance_type {instance_type}")
+        return cls
 
 
 def _get_runner_or_exit(instance_type: str):
     if instance_type.startswith("tpu"):
         return TPUGKERunnerJob
     elif instance_type.startswith("gpu-a3"):
-        return GPUGKERunnerJob
+        return GPUGKERunnerJob.class_from_instance_type(instance_type)
     else:
         raise app.UsageError(f"Unknown instance_type {instance_type}")
diff --git a/axlearn/cloud/gcp/jobs/gke_runner_test.py b/axlearn/cloud/gcp/jobs/gke_runner_test.py
index ea59ec28d..5dde16929 100644
--- a/axlearn/cloud/gcp/jobs/gke_runner_test.py
+++ b/axlearn/cloud/gcp/jobs/gke_runner_test.py
@@ -14,7 +14,7 @@
 from axlearn.cloud.common.bastion import BASTION_JOB_VERSION_ENV_VAR
 from axlearn.cloud.gcp import bundler, node_pool_provisioner
-from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GPUGKEJob, TPUGKEJob
+from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GPUGKEA3HighJob, TPUGKEJob
 from axlearn.cloud.gcp.jobs import gke_runner
 from axlearn.cloud.gcp.jobs.bastion_vm_test import _mock_job
 from axlearn.cloud.gcp.jobs.gke_runner import (
@@ -112,7 +112,7 @@ def test_from_flags(self, name, cluster, service_account, gcsfuse_mount_spec):
         self.assertEqual(cfg.cluster, cluster or mock_settings["gke_cluster"])
         self.assertEqual(cfg.service_account, service_account or "default")
         if gcsfuse_mount_spec:
-            fuse = cast(GPUGKEJob.Config, cfg.inner).gcsfuse_mount
+            fuse = cast(GPUGKEA3HighJob.Config, cfg.inner).gcsfuse_mount
             self.assertEqual(fuse.gcs_path, "my-test-path")
 
 @parameterized.product(
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index bbd769dad..97b223ff0 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -504,6 +504,12 @@ def get_trainer_kwargs(
             "gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
             mesh_shape_from_axes(data=-1, fsdp=128),
         ),
+        # v2 on a3-ultragpu-8g-256 (8 GPUs x 32 replicas), step time 15.493s.
+        # v2 on a3-ultragpu-8g-512 (8 GPUs x 64 replicas), step time 8.184s.
+        (
+            "gpu-(a3-ultragpu-8g)-(256|512|1024)",
+            mesh_shape_from_axes(data=-1, fsdp=64),
+        ),
         (
             "neuron-(trn2|trn2n).48xlarge-64",
             ChainConfigModifier.default_config().set(