Skip to content

Commit

Permalink
Test Dask NN with UCX-Py
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Apr 10, 2024
1 parent ec2c2df commit 6046457
Showing 1 changed file with 139 additions and 28 deletions.
167 changes: 139 additions & 28 deletions python/cuml/tests/dask/test_dask_nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,33 +81,18 @@ def _scale_rows(client, nrows):
return n_workers * nrows


@pytest.mark.parametrize(
"nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
)
@pytest.mark.parametrize("ncols", [10, 30])
@pytest.mark.parametrize(
"nclusters", [unit_param(5), quality_param(10), stress_param(15)]
)
@pytest.mark.parametrize(
"n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
)
@pytest.mark.parametrize(
"n_parts",
[unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
)
@pytest.mark.parametrize(
"streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
)
def test_compare_skl(
def _test_compare_skl(
nrows,
ncols,
nclusters,
n_parts,
n_neighbors,
streams_per_handle,
reverse_worker_order,
client,
dask_client,
request,
):
client = request.getfixturevalue(dask_client)

from cuml.dask.neighbors import NearestNeighbors as daskNN

Expand Down Expand Up @@ -162,11 +147,89 @@ def test_compare_skl(
assert array_equal(y_hat, skl_y_hat)


@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
def test_batch_size(nrows, ncols, n_parts, batch_size, client):
@pytest.mark.parametrize(
"nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
)
@pytest.mark.parametrize("ncols", [10, 30])
@pytest.mark.parametrize(
"nclusters", [unit_param(5), quality_param(10), stress_param(15)]
)
@pytest.mark.parametrize(
"n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
)
@pytest.mark.parametrize(
"n_parts",
[unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
)
@pytest.mark.parametrize(
"streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
)
def test_compare_skl(
nrows,
ncols,
nclusters,
n_parts,
n_neighbors,
streams_per_handle,
reverse_worker_order,
request,
):
_test_compare_skl(
nrows,
ncols,
nclusters,
n_parts,
n_neighbors,
streams_per_handle,
reverse_worker_order,
"client",
request,
)


@pytest.mark.parametrize(
"nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
)
@pytest.mark.parametrize("ncols", [10, 30])
@pytest.mark.parametrize(
"nclusters", [unit_param(5), quality_param(10), stress_param(15)]
)
@pytest.mark.parametrize(
"n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
)
@pytest.mark.parametrize(
"n_parts",
[unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
)
@pytest.mark.parametrize(
"streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
)
@pytest.mark.ucx
def test_compare_skl_ucx(
nrows,
ncols,
nclusters,
n_parts,
n_neighbors,
streams_per_handle,
reverse_worker_order,
request,
):
_test_compare_skl(
nrows,
ncols,
nclusters,
n_parts,
n_neighbors,
streams_per_handle,
reverse_worker_order,
"ucx_client",
request,
)


def _test_batch_size(nrows, ncols, n_parts, batch_size, dask_client, request):
client = request.getfixturevalue(dask_client)

n_neighbors = 10
n_clusters = 5
Expand Down Expand Up @@ -202,7 +265,25 @@ def test_batch_size(nrows, ncols, n_parts, batch_size, client):
assert array_equal(y_hat, y)


def test_return_distance(client):
@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
def test_batch_size(nrows, ncols, n_parts, batch_size, request):
_test_batch_size(nrows, ncols, n_parts, batch_size, "client", request)


@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
@pytest.mark.ucx
def test_batch_size_ucx(nrows, ncols, n_parts, batch_size, request):
_test_batch_size(nrows, ncols, n_parts, batch_size, "ucx_client", request)


def _test_return_distance(dask_client, request):
client = request.getfixturevalue(dask_client)

n_samples = 50
n_feats = 50
Expand Down Expand Up @@ -233,7 +314,17 @@ def test_return_distance(client):
assert len(ret) == 2


def test_default_n_neighbors(client):
def test_return_distance(request):
_test_return_distance("client", request)


@pytest.mark.ucx
def test_return_distance_ucx(request):
_test_return_distance("ucx_client", request)


def _test_default_n_neighbors(dask_client, request):
client = request.getfixturevalue(dask_client)

n_samples = 50
n_feats = 50
Expand Down Expand Up @@ -269,7 +360,18 @@ def test_default_n_neighbors(client):
assert ret.shape[1] == k


def test_one_query_partition(client):
def test_default_n_neighbors(request):
_test_default_n_neighbors("client", request)


@pytest.mark.ucx
def test_default_n_neighbors_ucx(request):
_test_default_n_neighbors("ucx_client", request)


def _test_one_query_partition(dask_client, request):
client = request.getfixturevalue(dask_client) # noqa

from cuml.dask.neighbors import NearestNeighbors as daskNN
from cuml.dask.datasets import make_blobs

Expand All @@ -280,3 +382,12 @@ def test_one_query_partition(client):
cumlModel = daskNN(n_neighbors=4)
cumlModel.fit(X_train)
cumlModel.kneighbors(X_test)


def test_one_query_partition(request):
_test_one_query_partition("client", request)


@pytest.mark.ucx
def test_one_query_partition_ucx(request):
_test_one_query_partition("ucx_client", request)

0 comments on commit 6046457

Please sign in to comment.