diff --git a/python/cuml/tests/dask/test_dask_nearest_neighbors.py b/python/cuml/tests/dask/test_dask_nearest_neighbors.py
index 9dbd4dc010..4511978252 100644
--- a/python/cuml/tests/dask/test_dask_nearest_neighbors.py
+++ b/python/cuml/tests/dask/test_dask_nearest_neighbors.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2023, NVIDIA CORPORATION.
+# Copyright (c) 2020-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -81,24 +81,7 @@ def _scale_rows(client, nrows):
     return n_workers * nrows
 
 
-@pytest.mark.parametrize(
-    "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
-)
-@pytest.mark.parametrize("ncols", [10, 30])
-@pytest.mark.parametrize(
-    "nclusters", [unit_param(5), quality_param(10), stress_param(15)]
-)
-@pytest.mark.parametrize(
-    "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
-)
-@pytest.mark.parametrize(
-    "n_parts",
-    [unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
-)
-@pytest.mark.parametrize(
-    "streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
-)
-def test_compare_skl(
+def _test_compare_skl(
     nrows,
     ncols,
     nclusters,
@@ -106,8 +89,10 @@ def test_compare_skl(
     n_neighbors,
     streams_per_handle,
     reverse_worker_order,
-    client,
+    dask_client,
+    request,
 ):
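+    # Look up the requested Dask client fixture ("client" or "ucx_client")
+    # by name, so the same test body runs against either cluster type.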
+    client = request.getfixturevalue(dask_client)
 
     from cuml.dask.neighbors import NearestNeighbors as daskNN
 
@@ -162,11 +147,89 @@ def test_compare_skl(
     assert array_equal(y_hat, skl_y_hat)
 
 
-@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
-@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
-@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
-@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
-def test_batch_size(nrows, ncols, n_parts, batch_size, client):
+@pytest.mark.parametrize(
+    "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
+)
+@pytest.mark.parametrize("ncols", [10, 30])
+@pytest.mark.parametrize(
+    "nclusters", [unit_param(5), quality_param(10), stress_param(15)]
+)
+@pytest.mark.parametrize(
+    "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
+)
+@pytest.mark.parametrize(
+    "n_parts",
+    [unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
+)
+@pytest.mark.parametrize(
+    "streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
+)
+def test_compare_skl(
+    nrows,
+    ncols,
+    nclusters,
+    n_parts,
+    n_neighbors,
+    streams_per_handle,
+    reverse_worker_order,
+    request,
+):
+    _test_compare_skl(
+        nrows,
+        ncols,
+        nclusters,
+        n_parts,
+        n_neighbors,
+        streams_per_handle,
+        reverse_worker_order,
+        "client",
+        request,
+    )
+
+
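+# UCX variant of test_compare_skl, run against the "ucx_client" fixture.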
+@pytest.mark.parametrize(
+    "nrows", [unit_param(300), quality_param(1e6), stress_param(5e8)]
+)
+@pytest.mark.parametrize("ncols", [10, 30])
+@pytest.mark.parametrize(
+    "nclusters", [unit_param(5), quality_param(10), stress_param(15)]
+)
+@pytest.mark.parametrize(
+    "n_neighbors", [unit_param(10), quality_param(4), stress_param(100)]
+)
+@pytest.mark.parametrize(
+    "n_parts",
+    [unit_param(1), unit_param(5), quality_param(7), stress_param(50)],
+)
+@pytest.mark.parametrize(
+    "streams_per_handle,reverse_worker_order", [(5, True), (10, False)]
+)
+@pytest.mark.ucx
+def test_compare_skl_ucx(
+    nrows,
+    ncols,
+    nclusters,
+    n_parts,
+    n_neighbors,
+    streams_per_handle,
+    reverse_worker_order,
+    request,
+):
+    _test_compare_skl(
+        nrows,
+        ncols,
+        nclusters,
+        n_parts,
+        n_neighbors,
+        streams_per_handle,
+        reverse_worker_order,
+        "ucx_client",
+        request,
+    )
+
+
+def _test_batch_size(nrows, ncols, n_parts, batch_size, dask_client, request):
+    client = request.getfixturevalue(dask_client)
 
     n_neighbors = 10
     n_clusters = 5
@@ -202,7 +265,25 @@ def test_batch_size(nrows, ncols, n_parts, batch_size, client):
     assert array_equal(y_hat, y)
 
 
-def test_return_distance(client):
+@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
+@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
+@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
+@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
+def test_batch_size(nrows, ncols, n_parts, batch_size, request):
+    _test_batch_size(nrows, ncols, n_parts, batch_size, "client", request)
+
+
+@pytest.mark.parametrize("nrows", [unit_param(1000), stress_param(1e5)])
+@pytest.mark.parametrize("ncols", [unit_param(10), stress_param(500)])
+@pytest.mark.parametrize("n_parts", [unit_param(10), stress_param(100)])
+@pytest.mark.parametrize("batch_size", [unit_param(100), stress_param(1e3)])
+@pytest.mark.ucx
+def test_batch_size_ucx(nrows, ncols, n_parts, batch_size, request):
+    _test_batch_size(nrows, ncols, n_parts, batch_size, "ucx_client", request)
+
+
+def _test_return_distance(dask_client, request):
+    client = request.getfixturevalue(dask_client)
 
     n_samples = 50
     n_feats = 50
@@ -233,7 +314,17 @@ def test_return_distance(client):
     assert len(ret) == 2
 
 
-def test_default_n_neighbors(client):
+def test_return_distance(request):
+    _test_return_distance("client", request)
+
+
+@pytest.mark.ucx
+def test_return_distance_ucx(request):
+    _test_return_distance("ucx_client", request)
+
+
+def _test_default_n_neighbors(dask_client, request):
+    client = request.getfixturevalue(dask_client)
 
     n_samples = 50
     n_feats = 50
@@ -269,7 +360,18 @@ def test_default_n_neighbors(client):
     assert ret.shape[1] == k
 
 
-def test_one_query_partition(client):
+def test_default_n_neighbors(request):
+    _test_default_n_neighbors("client", request)
+
+
+@pytest.mark.ucx
+def test_default_n_neighbors_ucx(request):
+    _test_default_n_neighbors("ucx_client", request)
+
+
+def _test_one_query_partition(dask_client, request):
+    request.getfixturevalue(dask_client)  # only the cluster is needed here
+
     from cuml.dask.neighbors import NearestNeighbors as daskNN
     from cuml.dask.datasets import make_blobs
 
@@ -280,3 +382,12 @@ def test_one_query_partition(client):
     cumlModel = daskNN(n_neighbors=4)
     cumlModel.fit(X_train)
     cumlModel.kneighbors(X_test)
+
+
+def test_one_query_partition(request):
+    _test_one_query_partition("client", request)
+
+
+@pytest.mark.ucx
+def test_one_query_partition_ucx(request):
+    _test_one_query_partition("ucx_client", request)