Skip to content

Commit

Permalink
Enable Dask tests with UCX-Py/UCXX in CI (#5697)
Browse files Browse the repository at this point in the history
Enable Dask tests in CI with TCP/UCX-Py/UCXX comms. The heavy lifting is done entirely in `raft-dask`, so no functional changes are required in `cuml`.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Jake Awe (https://github.com/AyodeAwe)
  - Divye Gala (https://github.com/divyegala)

URL: #5697
  • Loading branch information
pentschev authored May 18, 2024
1 parent 68d4336 commit 4dec229
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 42 deletions.
1 change: 0 additions & 1 deletion ci/release/update-version.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ CURRENT_SHORT_TAG=${CURRENT_MAJOR}.${CURRENT_MINOR}
NEXT_MAJOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[1]}')
NEXT_MINOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[2]}')
NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR}
NEXT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG}).*"

# Need to distutils-normalize the original version
NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))")
Expand Down
9 changes: 8 additions & 1 deletion ci/run_cuml_dask_pytests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,11 @@
# Support invoking run_cuml_dask_pytests.sh outside the script directory
cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cuml/tests/dask

python -m pytest --cache-clear "$@" .
rapids-logger "pytest cuml-dask (No UCX-Py/UCXX)"
timeout 2h python -m pytest --cache-clear "$@" .

rapids-logger "pytest cuml-dask (UCX-Py only)"
timeout 5m python -m pytest --cache-clear --run_ucx "$@" .

rapids-logger "pytest cuml-dask (UCXX only)"
timeout 5m python -m pytest --cache-clear --run_ucxx "$@" .
2 changes: 1 addition & 1 deletion ci/test_python_dask.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ EXITCODE=0
trap "EXITCODE=1" ERR
set +e

rapids-logger "pytest cuml-dask"
# Run tests
./ci/run_cuml_dask_pytests.sh \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cuml-dask.xml" \
--cov-config=../../../.coveragerc \
Expand Down
71 changes: 60 additions & 11 deletions python/cuml/tests/dask/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import pytest

Expand Down Expand Up @@ -34,18 +34,8 @@ def client(cluster):

@pytest.fixture(scope="module")
def ucx_cluster():
initialize.initialize(
create_cuda_context=True,
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
)
cluster = LocalCUDACluster(
protocol="ucx",
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
worker_class=IncreasedCloseTimeoutNanny,
)
yield cluster
cluster.close()
Expand All @@ -57,3 +47,62 @@ def ucx_client(ucx_cluster):
client = Client(ucx_cluster)
yield client
client.close()


@pytest.fixture(scope="module")
def ucxx_cluster():
    """Yield a module-scoped ``LocalCUDACluster`` using the ``ucxx`` protocol.

    The cluster is created once per test module with
    ``IncreasedCloseTimeoutNanny`` as its worker class and is closed after
    the last test in the module that requested it has finished.

    Tests are skipped when ``distributed_ucxx`` cannot be imported.
    """
    # pytest resolves fixture dependencies before the requesting fixture's
    # body runs, so a skip check placed only in ``ucxx_client`` fires after
    # this fixture has already tried to build the cluster. Guard here so a
    # missing ``distributed_ucxx`` yields a skip rather than (presumably) a
    # cluster start-up error.
    pytest.importorskip("distributed_ucxx")

    cluster = LocalCUDACluster(
        protocol="ucxx",
        worker_class=IncreasedCloseTimeoutNanny,
    )
    yield cluster
    cluster.close()


@pytest.fixture(scope="function")
def ucxx_client(ucxx_cluster):
    """Yield a per-test Dask ``Client`` connected to ``ucxx_cluster``.

    The test is skipped when ``distributed_ucxx`` cannot be imported. The
    client is closed once the test finishes.
    """
    pytest.importorskip("distributed_ucxx")

    dask_client = Client(ucxx_cluster)
    yield dask_client
    dask_client.close()


def pytest_addoption(parser):
    """Register the ``--run_ucx`` and ``--run_ucxx`` command-line flags.

    Both flags are ``store_true`` options added to the
    "Dask cuML Custom Options" group of ``parser``.
    """
    custom_options = parser.getgroup("Dask cuML Custom Options")

    # (flag, help text) pairs, registered in this order.
    flag_specs = (
        ("--run_ucx", "run _only_ UCX-Py tests"),
        ("--run_ucxx", "run _only_ UCXX tests"),
    )
    for flag, help_text in flag_specs:
        custom_options.addoption(flag, action="store_true", help=help_text)


def pytest_collection_modifyitems(config, items):
    """Skip collected tests according to ``--run_ucx`` / ``--run_ucxx``.

    For each of the two flags, in order:

    * flag given     -> every item *without* the matching keyword
      (``ucx`` / ``ucxx``) gets a skip marker;
    * flag not given -> every item *with* the matching keyword gets a skip
      marker.

    The two flags are evaluated independently, so an item may receive a skip
    marker from either pass (or both).
    """
    # (option, keyword, skip reason when option is set,
    #  skip reason when option is not set)
    selections = (
        (
            "--run_ucx",
            "ucx",
            "only runs when --run_ucx is not specified",
            "requires --run_ucx to run",
        ),
        (
            "--run_ucxx",
            "ucxx",
            "only runs when --run_ucxx is not specified",
            "requires --run_ucxx to run",
        ),
    )

    for option, keyword, reason_if_set, reason_if_unset in selections:
        selected = bool(config.getoption(option))
        skip_marker = pytest.mark.skip(
            reason=reason_if_set if selected else reason_if_unset
        )
        for item in items:
            # Option set: skip items lacking the keyword.
            # Option unset: skip items carrying the keyword.
            if (keyword in item.keywords) != selected:
                item.add_marker(skip_marker)
Loading

0 comments on commit 4dec229

Please sign in to comment.