From 69b9d421a40f89d5c9bff0f062cfc13977844294 Mon Sep 17 00:00:00 2001 From: aamijar Date: Fri, 31 May 2024 22:15:44 +0000 Subject: [PATCH] use mdspan --- cpp/src/tsne/tsne_runner.cuh | 30 +++++++++++++++++++++++------- cpp/test/sg/tsne_test.cu | 7 ++----- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/cpp/src/tsne/tsne_runner.cuh b/cpp/src/tsne/tsne_runner.cuh index c8fffddd4c..643a78fa7f 100644 --- a/cpp/src/tsne/tsne_runner.cuh +++ b/cpp/src/tsne/tsne_runner.cuh @@ -25,10 +25,16 @@ #include #include +#include #include +#include +#include +#include +#include #include #include #include +#include #include #include @@ -124,21 +130,31 @@ class TSNE_runner { noise_vars.data(), prms, stream); - handle.sync_stream(stream); rmm::device_uvector mean_result(dim, stream); rmm::device_uvector std_result(dim, stream); std::vector h_std_result(dim); - float multiplier = 1e-4; + const float multiplier = 1e-4; - raft::stats::mean(mean_result.data(), Y, dim, n, false, false, stream); - raft::stats::stddev(std_result.data(), Y, mean_result.data(), dim, n, true, false, stream); + auto Y_view = raft::make_device_matrix_view(Y, n, dim); + auto Y_view_const = raft::make_device_matrix_view(Y, n, dim); + + auto mean_result_view = raft::make_device_vector_view(mean_result.data(), dim); + auto mean_result_view_const = + raft::make_device_vector_view(mean_result.data(), dim); + + auto std_result_view = raft::make_device_vector_view(std_result.data(), dim); + + auto h_multiplier_view_const = raft::make_host_scalar_view(&multiplier); + auto h_std_result_view_const = raft::make_host_scalar_view(&h_std_result[0]); + + raft::stats::mean(handle_, Y_view_const, mean_result_view, false); + raft::stats::stddev(handle_, Y_view_const, mean_result_view_const, std_result_view, false); raft::update_host(h_std_result.data(), std_result.data(), dim, stream); - handle.sync_stream(stream); - raft::linalg::divideScalar(Y, Y, h_std_result[0], n * dim, stream); - raft::linalg::multiplyScalar(Y, Y, multiplier, n * dim, stream); + raft::linalg::divide_scalar(handle_, Y_view_const, Y_view, h_std_result_view_const); + raft::linalg::multiply_scalar(handle_, Y_view_const, Y_view, h_multiplier_view_const); } } } diff --git a/cpp/test/sg/tsne_test.cu b/cpp/test/sg/tsne_test.cu index f97d994312..66bcc31c0f 100644 --- a/cpp/test/sg/tsne_test.cu +++ b/cpp/test/sg/tsne_test.cu @@ -131,10 +131,8 @@ class TSNETest : public ::testing::TestWithParam { raft::update_device(X_d.data(), dataset.data(), n * p, stream); rmm::device_uvector Xtranspose(n * p, stream); - - raft::update_device(Xtranspose.data(), X_d.data(), n * p, stream); + raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream); raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), p, n, stream); - handle.sync_stream(stream); rmm::device_uvector Y_d(n * model_params.dim, stream); rmm::device_uvector input_indices(0, stream); @@ -191,9 +189,8 @@ class TSNETest : public ::testing::TestWithParam { handle.sync_stream(stream); free(embeddings_h); - raft::update_device(Xtranspose.data(), X_d.data(), n * p, stream); + raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream); raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), n, p, stream); - handle.sync_stream(stream); // Produce trustworthiness score results.trustworthiness =