[JAX] Test_multiprocessing_encoder with process spawn in bash (#1394)
* add test_multiprocessing_encoder with process spawning in bash

---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Jan 11, 2025
1 parent 7b861e7 commit a65ad37
Showing 5 changed files with 67 additions and 49 deletions.
7 changes: 7 additions & 0 deletions examples/jax/encoder/common.py
@@ -12,3 +12,10 @@ def is_bf16_supported():
"""Return if BF16 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 80


@lru_cache
def is_fp8_supported():
"""Return if FP8 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 90
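
Note on the pattern above: both helpers gate a feature on the GPU's compute capability (SM 8.0 for BF16, SM 9.0 for FP8) and cache the answer with lru_cache so the device is queried only once. A minimal self-contained sketch of the same idea, where query_compute_capability is a hypothetical stand-in for transformer_engine's get_device_compute_capability:

# Sketch of capability-gated feature checks; `query_compute_capability` is a
# hypothetical stand-in for transformer_engine's get_device_compute_capability.
from functools import lru_cache

def query_compute_capability(device_id: int) -> int:
    # Placeholder: a real implementation would ask the CUDA runtime,
    # returning e.g. 80 for SM 8.0 or 90 for SM 9.0.
    return 90

@lru_cache
def is_bf16_supported() -> bool:
    """BF16 needs compute capability 8.0 (Ampere) or newer."""
    return query_compute_capability(0) >= 80

@lru_cache
def is_fp8_supported() -> bool:
    """FP8 needs compute capability 9.0 (Hopper) or newer."""
    return query_compute_capability(0) >= 90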
20 changes: 20 additions & 0 deletions examples/jax/encoder/conftest.py
@@ -0,0 +1,20 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""config for test_multiprocessing_encoder"""
import pytest


def pytest_addoption(parser):
"""Pytest hook for test_multiprocessing_encoder"""
parser.addoption("--num-process", action="store", default=0)
parser.addoption("--process-id", action="store", default=0)


@pytest.fixture(autouse=True)
def multiprocessing_parses(request):
"""Fixture for querying num-process and process-id"""
if request.cls:
request.cls.num_process = int(request.config.getoption("--num-process"))
request.cls.process_id = int(request.config.getoption("--process-id"))
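
For context: pytest_addoption registers the two custom command-line flags, and the autouse fixture copies their values onto the requesting test class before each test runs. A hedged sketch of a test consuming them (TestExample is illustrative, not part of the commit, and assumes the conftest.py above is on the pytest path):

# Run as, e.g.: pytest test_example.py --num-process=4 --process-id=0
import unittest
import pytest

@pytest.mark.usefixtures("multiprocessing_parses")
class TestExample(unittest.TestCase):
    num_process = 0  # overwritten by the fixture from --num-process
    process_id = 0   # overwritten by the fixture from --process-id

    def test_prints_rank(self):
        print(f"process {self.process_id} of {self.num_process}")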
17 changes: 17 additions & 0 deletions examples/jax/encoder/run_test_multiprocessing_encoder.sh
@@ -0,0 +1,17 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
done
wait

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
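
The script launches one pytest worker per GPU in the background and waits for the whole group to finish, first for the BF16 case and then for FP8. A hypothetical Python equivalent of the same launcher, using subprocess (names and paths mirror the script; nothing here is part of the commit):

# Hypothetical Python version of the bash launcher: one pytest process per
# GPU, run concurrently, then joined before the next test case starts.
import os
import subprocess

def run_case(test_name: str, num_gpus: int, te_path: str) -> None:
    test_file = f"{te_path}/examples/jax/encoder/test_multiprocessing_encoder.py"
    procs = [
        subprocess.Popen([
            "pytest", "-c", f"{te_path}/tests/jax/pytest.ini", "-v",
            f"{test_file}::TestEncoder::{test_name}",
            f"--num-process={num_gpus}", f"--process-id={i}",
        ])
        for i in range(num_gpus)
    ]
    for p in procs:  # the equivalent of the script's `wait`
        assert p.wait() == 0, "a worker process failed"

if __name__ == "__main__":
    te_path = os.environ["TE_PATH"]
    num_gpus = int(os.environ.get("NUM_GPUS", "1"))
    run_case("test_te_bf16", num_gpus, te_path)
    run_case("test_te_fp8", num_gpus, te_path)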
70 changes: 22 additions & 48 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -3,10 +3,10 @@
# See LICENSE for license information.
"""Encoder training with multi-GPU, multiprocessing, and tensor parallelism"""
import argparse
import multiprocessing as mp
import os
import unittest
from functools import partial
import pytest

import flax
import jax
@@ -21,10 +21,10 @@
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

from common import is_bf16_supported, is_fp8_supported
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
@@ -252,7 +252,6 @@ def eval_model(

def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
@@ -342,6 +341,9 @@ def replace_params(x):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
if args.process_id == 0:
nltk.download("punkt_tab")

train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)

jax.distributed.initialize(
@@ -551,69 +553,41 @@ def encoder_parser(args):
return parser.parse_args(args)


def query_gpu(q):
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])


def unittest_query_gpu():
r"""
It is only used by TestEncoder.
The `jax.distributed.initialize` must be called before any other JAX or Flax API,
otherwise `jax.local_devices` will be incorrect.
Thus, fork another process to query number of GPUs and FP8 capability.
"""
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, gpu_has_bf16, reason


@pytest.mark.usefixtures("multiprocessing_parses")
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()
gpu_has_fp8 = is_fp8_supported()
gpu_has_bf16 = is_bf16_supported()

def exec(self, use_fp8):
"""Run 3 epochs for testing"""
num_gpu = self.num_gpu
args = encoder_parser([])

num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size
batch_size = 64 // dp_size

arg_list = []
for i in range(num_gpu):
args = encoder_parser([])
args.num_process = num_gpu
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.process_id = i
arg_list.append(args)

with mp.Pool(self.num_gpu) as p:
results = p.map(train_and_evaluate, arg_list)
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.num_process = num_gpu
args.process_id = self.process_id

return results
return train_and_evaluate(args)

@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
result = self.exec(False)
assert result[0] < 0.45 and result[1] > 0.79

@unittest.skipIf(not gpu_has_fp8, reason)
@unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
results = self.exec(True)
actual = results[0]
assert actual[0] < 0.45 and actual[1] > 0.79
result = self.exec(True)
assert result[0] < 0.45 and result[1] > 0.79


if __name__ == "__main__":
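
The substance of this diff: process spawning moves out of Python (a multiprocessing.Pool inside the test) and into bash, so each pytest worker is a full JAX process that calls jax.distributed.initialize before touching any device — which is why the old fork-a-process-to-query-GPUs workaround could be deleted. A sketch of the per-process setup each worker performs; the coordinator address and port below are assumptions, not values from the commit:

# Per-process JAX distributed setup, as each spawned worker would do it.
# The coordinator address/port below are assumed for illustration.
import jax

def init_distributed(num_process: int, process_id: int) -> None:
    # Must run before any other JAX call; otherwise jax.local_devices()
    # would report every GPU instead of the one assigned to this process.
    jax.distributed.initialize(
        coordinator_address="127.0.0.1:12345",  # assumed local coordinator
        num_processes=num_process,
        process_id=process_id,
        local_device_ids=[process_id],  # pin this worker to one GPU
    )

init_distributed(num_process=1, process_id=0)
print(jax.local_devices())  # exactly one device per worker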
2 changes: 1 addition & 1 deletion qa/L0_jax_distributed_unittest/test.sh
@@ -12,4 +12,4 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh
