diff --git a/python/cuml/tests/dask/test_dask_nearest_neighbors.py b/python/cuml/tests/dask/test_dask_nearest_neighbors.py index 9dbd4dc010..4511978252 100644 --- a/python/cuml/tests/dask/test_dask_nearest_neighbors.py +++ b/python/cuml/tests/dask/test_dask_nearest_neighbors.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -81,24 +81,7 @@ def _scale_rows(client, nrows): return n_workers * nrows -@pytest.mark.parametrize( - "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)] -) -@pytest.mark.parametrize("ncols", [10, 30]) -@pytest.mark.parametrize( - "nclusters", [unit_param(5), quality_param(10), stress_param(15)] -) -@pytest.mark.parametrize( - "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)] -) -@pytest.mark.parametrize( - "n_parts", - [unit_param(1), unit_param(5), quality_param(7), stress_param(50)], -) -@pytest.mark.parametrize( - "streams_per_handle,reverse_worker_order", [(5, True), (10, False)] -) -def test_compare_skl( +def _test_compare_skl( nrows, ncols, nclusters, @@ -106,8 +89,10 @@ def test_compare_skl( n_neighbors, streams_per_handle, reverse_worker_order, - client, + dask_client, + request, ): + client = request.getfixturevalue(dask_client) from cuml.dask.neighbors import NearestNeighbors as daskNN @@ -162,11 +147,89 @@ def test_compare_skl( assert array_equal(y_hat, skl_y_hat) -@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)]) -@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)]) -@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)]) -@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)]) -def test_batch_size(nrows, ncols, n_parts, batch_size, client): +@pytest.mark.parametrize( + "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)] +) +@pytest.mark.parametrize("ncols", [10, 30]) +@pytest.mark.parametrize( + "nclusters", [unit_param(5), quality_param(10), stress_param(15)] +) +@pytest.mark.parametrize( + "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)] +) +@pytest.mark.parametrize( + "n_parts", + [unit_param(1), unit_param(5), quality_param(7), stress_param(50)], +) +@pytest.mark.parametrize( + "streams_per_handle,reverse_worker_order", [(5, True), (10, False)] +) +def test_compare_skl( + nrows, + ncols, + nclusters, + n_parts, + n_neighbors, + streams_per_handle, + reverse_worker_order, + request, +): + _test_compare_skl( + nrows, + ncols, + nclusters, + n_parts, + n_neighbors, + streams_per_handle, + reverse_worker_order, + "client", + request, + ) + + +@pytest.mark.parametrize( + "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)] +) +@pytest.mark.parametrize("ncols", [10, 30]) +@pytest.mark.parametrize( + "nclusters", [unit_param(5), quality_param(10), stress_param(15)] +) +@pytest.mark.parametrize( + "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)] +) +@pytest.mark.parametrize( + "n_parts", + [unit_param(1), unit_param(5), quality_param(7), stress_param(50)], +) +@pytest.mark.parametrize( + "streams_per_handle,reverse_worker_order", [(5, True), (10, False)] +) +@pytest.mark.ucx +def test_compare_skl_ucx( + nrows, + ncols, + nclusters, + n_parts, + n_neighbors, + streams_per_handle, + reverse_worker_order, + request, +): + _test_compare_skl( + nrows, + ncols, + nclusters, + n_parts, + n_neighbors, + streams_per_handle, + reverse_worker_order, + "ucx_client", + request, + ) + + +def _test_batch_size(nrows, ncols, n_parts, batch_size, dask_client, request): + client = request.getfixturevalue(dask_client) n_neighbors = 10 n_clusters = 5 @@ -202,7 +265,25 @@ def test_batch_size(nrows, ncols, n_parts, batch_size, client): assert array_equal(y_hat, y) -def test_return_distance(client): +@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)]) +@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)]) +@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)]) +@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)]) +def test_batch_size(nrows, ncols, n_parts, batch_size, request): + _test_batch_size(nrows, ncols, n_parts, batch_size, "client", request) + + +@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)]) +@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)]) +@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)]) +@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)]) +@pytest.mark.ucx +def test_batch_size_ucx(nrows, ncols, n_parts, batch_size, request): + _test_batch_size(nrows, ncols, n_parts, batch_size, "ucx_client", request) + + +def _test_return_distance(dask_client, request): + client = request.getfixturevalue(dask_client) n_samples = 50 n_feats = 50 @@ -233,7 +314,17 @@ def test_return_distance(client): assert len(ret) == 2 -def test_default_n_neighbors(client): +def test_return_distance(request): + _test_return_distance("client", request) + + +@pytest.mark.ucx +def test_return_distance_ucx(request): + _test_return_distance("ucx_client", request) + + +def _test_default_n_neighbors(dask_client, request): + client = request.getfixturevalue(dask_client) n_samples = 50 n_feats = 50 @@ -269,7 +360,18 @@ def test_default_n_neighbors(client): assert ret.shape[1] == k -def test_one_query_partition(client): +def test_default_n_neighbors(request): + _test_default_n_neighbors("client", request) + + +@pytest.mark.ucx +def test_default_n_neighbors_ucx(request): + _test_default_n_neighbors("ucx_client", request) + + +def _test_one_query_partition(dask_client, request): + client = request.getfixturevalue(dask_client) # noqa + from cuml.dask.neighbors import NearestNeighbors as daskNN from cuml.dask.datasets import make_blobs @@ -280,3 +382,12 @@ def test_one_query_partition(client): cumlModel = daskNN(n_neighbors=4) cumlModel.fit(X_train) cumlModel.kneighbors(X_test) + + +def test_one_query_partition(request): + _test_one_query_partition("client", request) + + +@pytest.mark.ucx +def test_one_query_partition_ucx(request): + _test_one_query_partition("ucx_client", request)