diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 4ed9529a36..780f6f8581 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -268,6 +268,10 @@ if(RAFT_COMPILE_LIBRARY)
     src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu
     src/raft_runtime/random/rmat_rectangular_generator_int_double.cu
     src/raft_runtime/random/rmat_rectangular_generator_int_float.cu
+    src/raft_runtime/solver/lanczos_solver_int64_double.cu
+    src/raft_runtime/solver/lanczos_solver_int64_float.cu
+    src/raft_runtime/solver/lanczos_solver_int_double.cu
+    src/raft_runtime/solver/lanczos_solver_int_float.cu
   )
   set_target_properties(
     raft_objs
diff --git a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
index c10c0de426..97ac7c45f4 100644
--- a/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
+++ b/cpp/include/raft/sparse/linalg/detail/cusparse_utils.hpp
@@ -30,7 +30,25 @@ namespace linalg {
 namespace detail {
 
 /**
- * @brief create a cuSparse dense descriptor
+ * @brief create a cuSparse dense descriptor for a vector
+ * @tparam ValueType Data type of vector_view (float/double)
+ * @tparam IndexType Type of vector_view
+ * @param[in] vector_view input raft::device_vector_view
+ * @returns dense vector descriptor to be used by cuSparse API
+ */
+template <typename ValueType, typename IndexType>
+cusparseDnVecDescr_t create_descriptor(raft::device_vector_view<ValueType, IndexType> vector_view)
+{
+  cusparseDnVecDescr_t descr;
+  RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednvec(
+    &descr,
+    vector_view.extent(0),
+    const_cast<std::remove_const_t<ValueType>*>(vector_view.data_handle())));
+  return descr;
+}
+
+/**
+ * @brief create a cuSparse dense descriptor for a matrix
  * @tparam ValueType Data type of dense_view (float/double)
  * @tparam IndexType Type of dense_view
  * @tparam LayoutPolicy layout of dense_view
diff --git a/cpp/include/raft/sparse/solver/detail/lanczos.cuh b/cpp/include/raft/sparse/solver/detail/lanczos.cuh
index 9ecb4b729f..02a77a0d99 100644
--- a/cpp/include/raft/sparse/solver/detail/lanczos.cuh
+++ b/cpp/include/raft/sparse/solver/detail/lanczos.cuh
@@ -19,10 +19,43 @@
 // for cmath:
 #define _USE_MATH_DEFINES
 
+#include <raft/core/detail/macros.hpp>
+#include <raft/core/device_csr_matrix.hpp>
+#include <raft/core/device_mdspan.hpp>
+#include <raft/core/host_mdarray.hpp>
+#include <raft/core/host_mdspan.hpp>
+#include <raft/core/logger-macros.hpp>
+#include <raft/core/mdspan_types.hpp>
 #include <raft/core/resource/cublas_handle.hpp>
 #include <raft/core/resource/cuda_stream.hpp>
 #include <raft/core/resources.hpp>
+#include <raft/linalg/add.cuh>
+#include <raft/linalg/axpy.cuh>
+#include <raft/linalg/binary_op.cuh>
+#include <raft/linalg/detail/add.cuh>
 #include <raft/linalg/detail/cublas_wrappers.hpp>
+#include <raft/linalg/detail/gemv.hpp>
+#include <raft/linalg/dot.cuh>
+#include <raft/linalg/eig.cuh>
+#include <raft/linalg/gemm.hpp>
+#include <raft/linalg/gemv.cuh>
+#include <raft/linalg/init.cuh>
+#include <raft/linalg/map.cuh>
+#include <raft/linalg/multiply.cuh>
+#include <raft/linalg/norm.cuh>
+#include <raft/linalg/norm_types.hpp>
+#include <raft/linalg/normalize.cuh>
+#include <raft/linalg/svd.cuh>
+#include <raft/linalg/transpose.cuh>
+#include <raft/linalg/unary_op.cuh>
+#include <raft/matrix/diagonal.cuh>
+#include <raft/matrix/matrix.cuh>
+#include <raft/matrix/slice.cuh>
+#include <raft/matrix/triangular.cuh>
+#include <raft/random/rng.cuh>
+#include <raft/sparse/detail/cusparse_wrappers.h>
+#include <raft/sparse/linalg/detail/cusparse_utils.hpp>
+#include <raft/sparse/solver/lanczos_types.hpp>
 #include <raft/spectral/detail/lapack.hpp>
 #include <raft/spectral/detail/warn_dbg.hpp>
 #include <raft/spectral/matrix_wrappers.hpp>
@@ -30,9 +63,17 @@
 
 #include <cuda.h>
 
+#include <cublasLt.h>
 #include <curand.h>
+#include <cusparse.h>
+#include <sys/types.h>
 
+#include <algorithm>
 #include <cmath>
+#include <cstdint>
+#include <optional>
+#include <type_traits>
+#include <utility>
 #include <vector>
 
 namespace raft::sparse::solver::detail {
@@ -1396,4 +1437,674 @@ int computeLargestEigenvectors(
   return status;
 }
 
+template <typename T>
+RAFT_KERNEL kernel_triangular_populate(T* M, const T* beta, int n)
+{
+  int row = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row < n) {
+    // Upper diagonal: M[row + 1, row] in column-major
+    if (row < n - 1) { M[(row + 1) * n + row] = beta[row]; }
+
+    // Lower diagonal: M[row - 1, row] in column-major
+    if (row > 0) { M[(row - 1) * n + row] = beta[row - 1]; }
+  }
+}
+
+template <typename T>
+RAFT_KERNEL kernel_triangular_beta_k(T* t, const T* beta_k, int k, int n)
+{
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (tid < k) {
+    // Update the k-th column: t[i, k] -> t[k * n + i] in column-major
+    t[tid * n + k] = beta_k[tid];
+
+    // Update the k-th row: t[k, j] -> t[j * n + k] in column-major
+    t[k * n + tid] = beta_k[tid];
+  }
+}
+
+template <typename T>
+RAFT_KERNEL kernel_normalize(const T* u, const T* beta, int j, int n, T* v, T* V, int size)
+{
+  int i = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (i < size) {
+    if (beta[j] == 0) {
+      v[i] = u[i] / 1;
+    } else {
+      v[i] = u[i] / beta[j];
+    }
+    V[i + (j + 1) * n] = v[i];
+  }
+}
+
+template <typename T>
+RAFT_KERNEL kernel_clamp_down(T* value, T threshold)
+{
+  *value = (fabs(*value) < threshold) ? 0 : *value;
+}
+
+template <typename T>
+RAFT_KERNEL kernel_clamp_down_vector(T* vec, T threshold, int size)
+{
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < size) { vec[idx] = (fabs(vec[idx]) < threshold) ? 0 : vec[idx]; }
+}
+
+template <typename IndexTypeT, typename ValueTypeT>
+void lanczos_solve_ritz(
+  raft::resources const& handle,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::row_major> alpha,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::row_major> beta,
+  std::optional<raft::device_vector_view<ValueTypeT, uint32_t>> beta_k,
+  IndexTypeT k,
+  int which,
+  int ncv,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors,
+  raft::device_vector_view<ValueTypeT> eigenvalues)
+{
+  auto stream = resource::get_cuda_stream(handle);
+
+  ValueTypeT zero = 0;
+  auto triangular_matrix =
+    raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, ncv);
+  raft::matrix::fill(handle, triangular_matrix.view(), zero);
+
+  raft::device_vector_view<const ValueTypeT, uint32_t> alphaVec =
+    raft::make_device_vector_view<const ValueTypeT, uint32_t>(alpha.data_handle(), ncv);
+  raft::matrix::set_diagonal(handle, alphaVec, triangular_matrix.view());
+
+  // raft::matrix::initializeDiagonalMatrix(
+  //   alpha.data_handle(), triangular_matrix.data_handle(), ncv, ncv, stream);
+
+  int blockSize = 256;
+  int numBlocks = (ncv + blockSize - 1) / blockSize;
+  kernel_triangular_populate<ValueTypeT>
+    <<<blockSize, numBlocks, 0, stream>>>(triangular_matrix.data_handle(), beta.data_handle(), ncv);
+
+  if (beta_k) {
+    int threadsPerBlock = 256;
+    int blocksPerGrid   = (k + threadsPerBlock - 1) / threadsPerBlock;
+    kernel_triangular_beta_k<ValueTypeT><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+      triangular_matrix.data_handle(), beta_k.value().data_handle(), (int)k, ncv);
+  }
+
+  auto triangular_matrix_view =
+    raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::col_major>(
+      triangular_matrix.data_handle(), ncv, ncv);
+
+  raft::linalg::eig_dc(handle, triangular_matrix_view, eigenvectors, eigenvalues);
+}
+
+template <typename IndexTypeT, typename ValueTypeT>
+void lanczos_aux(raft::resources const& handle,
+                 raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
+                 raft::device_matrix_view<ValueTypeT, uint32_t, raft::row_major> V,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> u,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> alpha,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> beta,
+                 int start_idx,
+                 int end_idx,
+                 int ncv,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> v,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> uu,
+                 raft::device_matrix_view<ValueTypeT, uint32_t> vv)
+{
+  auto stream = resource::get_cuda_stream(handle);
+
+  IndexTypeT n  = A.structure_view().get_n_rows();
+  auto v_vector = raft::make_device_vector_view<const ValueTypeT>(v.data_handle(), n);
+  auto u_vector = raft::make_device_vector_view<const ValueTypeT>(u.data_handle(), n);
+
+  raft::copy(
+    v.data_handle(), V.data_handle() + start_idx * V.stride(0), n, stream);  // V(start_idx, 0)
+
+  auto cusparse_h                 = resource::get_cusparse_handle(handle);
+  cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A);
+
+  cusparseDnVecDescr_t cusparse_v = raft::sparse::linalg::detail::create_descriptor(v_vector);
+  cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector);
+
+  ValueTypeT one  = 1;
+  ValueTypeT zero = 0;
+  size_t bufferSize;
+  raft::sparse::detail::cusparsespmv_buffersize(cusparse_h,
+                                                CUSPARSE_OPERATION_NON_TRANSPOSE,
+                                                &one,
+                                                cusparse_A,
+                                                cusparse_v,
+                                                &zero,
+                                                cusparse_u,
+                                                CUSPARSE_SPMV_ALG_DEFAULT,
+                                                &bufferSize,
+                                                stream);
+  auto cusparse_spmv_buffer = raft::make_device_vector<ValueTypeT>(handle, bufferSize);
+
+  for (int i = start_idx; i < end_idx; i++) {
+    raft::sparse::detail::cusparsespmv(cusparse_h,
+                                       CUSPARSE_OPERATION_NON_TRANSPOSE,
+                                       &one,
+                                       cusparse_A,
+                                       cusparse_v,
+                                       &zero,
+                                       cusparse_u,
+                                       CUSPARSE_SPMV_ALG_DEFAULT,
+                                       cusparse_spmv_buffer.data_handle(),
+                                       stream);
+
+    auto alpha_i =
+      raft::make_device_scalar_view(alpha.data_handle() + i * alpha.stride(1));  // alpha(0, i)
+    raft::linalg::dot(handle, v_vector, u_vector, alpha_i);
+
+    raft::matrix::fill(handle, vv, zero);
+
+    auto cublas_h = resource::get_cublas_handle(handle);
+
+    ValueTypeT alpha_i_host = 0;
+    ValueTypeT b            = 0;
+    ValueTypeT mone         = -1;
+
+    raft::copy<ValueTypeT>(
+      &b, beta.data_handle() + ((i - 1 + ncv) % ncv) * beta.stride(1), 1, stream);
+    raft::copy<ValueTypeT>(
+      &alpha_i_host, alpha.data_handle() + i * alpha.stride(1), 1, stream);  // alpha(0, i)
+
+    raft::linalg::axpy(handle, n, &alpha_i_host, v.data_handle(), 1, vv.data_handle(), 1, stream);
+    raft::linalg::axpy(handle,
+                       n,
+                       &b,
+                       V.data_handle() + (((i - 1 + ncv) % ncv) * V.stride(0)),
+                       1,
+                       vv.data_handle(),
+                       1,
+                       stream);
+    raft::linalg::axpy(handle, n, &mone, vv.data_handle(), 1, u.data_handle(), 1, stream);
+
+    raft::linalg::gemv(handle,
+                       CUBLAS_OP_T,
+                       n,
+                       i + 1,
+                       &one,
+                       V.data_handle(),
+                       n,
+                       u.data_handle(),
+                       1,
+                       &zero,
+                       uu.data_handle(),
+                       1,
+                       stream);
+
+    raft::linalg::gemv(handle,
+                       CUBLAS_OP_N,
+                       n,
+                       i + 1,
+                       &mone,
+                       V.data_handle(),
+                       n,
+                       uu.data_handle(),
+                       1,
+                       &one,
+                       u.data_handle(),
+                       1,
+                       stream);
+
+    auto uu_i = raft::make_device_scalar_view(uu.data_handle() + uu.stride(1) * i);  // uu(0, i)
+    raft::linalg::add(handle, make_const_mdspan(alpha_i), make_const_mdspan(uu_i), alpha_i);
+
+    kernel_clamp_down<<<1, 1, 0, stream>>>(alpha_i.data_handle(), static_cast<ValueTypeT>(1e-9));
+
+    auto output = raft::make_device_vector_view<ValueTypeT, uint32_t>(
+      beta.data_handle() + beta.stride(1) * i, 1);
+    auto input = raft::make_device_matrix_view<const ValueTypeT, uint32_t>(u.data_handle(), 1, n);
+    raft::linalg::norm(handle,
+                       input,
+                       output,
+                       raft::linalg::L2Norm,
+                       raft::linalg::Apply::ALONG_ROWS,
+                       raft::sqrt_op());
+
+    int blockSize = 256;
+    int numBlocks = (n + blockSize - 1) / blockSize;
+
+    kernel_clamp_down_vector<<<numBlocks, blockSize, 0, stream>>>(
+      u.data_handle(), static_cast<ValueTypeT>(1e-7), n);
+
+    kernel_clamp_down<<<1, 1, 0, stream>>>(beta.data_handle() + beta.stride(1) * i,
+                                           static_cast<ValueTypeT>(1e-6));
+
+    if (i >= end_idx - 1) { break; }
+
+    int threadsPerBlock = 256;
+    int blocksPerGrid   = (n + threadsPerBlock - 1) / threadsPerBlock;
+
+    kernel_normalize<ValueTypeT><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+      u.data_handle(), beta.data_handle(), i, n, v.data_handle(), V.data_handle(), n);
+  }
+}
+
+template <typename IndexTypeT, typename ValueTypeT>
+auto lanczos_smallest(
+  raft::resources const& handle,
+  raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
+  int nEigVecs,
+  int maxIter,
+  int restartIter,
+  ValueTypeT tol,
+  ValueTypeT* eigVals_dev,
+  ValueTypeT* eigVecs_dev,
+  ValueTypeT* v0,
+  uint64_t seed) -> int
+{
+  int n       = A.structure_view().get_n_rows();
+  int ncv     = restartIter;
+  auto stream = resource::get_cuda_stream(handle);
+
+  auto V = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, ncv, n);
+  auto V_0_view =
+    raft::make_device_matrix_view<ValueTypeT, uint32_t>(V.data_handle(), 1, n);  // First Row V[0]
+  auto v0_view = raft::make_device_matrix_view<const ValueTypeT, uint32_t>(v0, 1, n);
+
+  auto u        = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, n);
+  auto u_vector = raft::make_device_vector_view<ValueTypeT, uint32_t>(u.data_handle(), n);
+  raft::copy(u.data_handle(), v0, n, stream);
+
+  auto cublas_h = resource::get_cublas_handle(handle);
+  auto v0nrm    = raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
+  raft::linalg::norm(handle,
+                     v0_view,
+                     v0nrm.view(),
+                     raft::linalg::L2Norm,
+                     raft::linalg::Apply::ALONG_ROWS,
+                     raft::sqrt_op());
+
+  auto v0_vector_const = raft::make_device_vector_view<const ValueTypeT, uint32_t>(v0, n);
+
+  raft::linalg::unary_op(
+    handle, v0_vector_const, V_0_view, [device_scalar = v0nrm.data_handle()] __device__(auto y) {
+      return y / *device_scalar;
+    });
+
+  auto alpha      = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, ncv);
+  auto beta       = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, ncv);
+  ValueTypeT zero = 0;
+  raft::matrix::fill(handle, alpha.view(), zero);
+  raft::matrix::fill(handle, beta.view(), zero);
+
+  auto v      = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, n);
+  auto aux_uu = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, ncv);
+  auto vv     = raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, 1, n);
+
+  lanczos_aux(handle,
+              A,
+              V.view(),
+              u.view(),
+              alpha.view(),
+              beta.view(),
+              0,
+              ncv,
+              ncv,
+              v.view(),
+              aux_uu.view(),
+              vv.view());
+
+  auto eigenvectors =
+    raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, ncv);
+  auto eigenvalues = raft::make_device_vector<ValueTypeT, uint32_t>(handle, ncv);
+
+  lanczos_solve_ritz<IndexTypeT, ValueTypeT>(handle,
+                                             alpha.view(),
+                                             beta.view(),
+                                             std::nullopt,
+                                             nEigVecs,
+                                             0,
+                                             ncv,
+                                             eigenvectors.view(),
+                                             eigenvalues.view());
+
+  auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
+    eigenvectors.data_handle(), ncv, nEigVecs);
+  auto eigenvalues_k =
+    raft::make_device_vector_view<ValueTypeT, uint32_t>(eigenvalues.data_handle(), nEigVecs);
+
+  auto ritz_eigenvectors =
+    raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(eigVecs_dev, n, nEigVecs);
+
+  auto V_T =
+    raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(V.data_handle(), n, ncv);
+  raft::linalg::gemm<ValueTypeT, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
+    handle, V_T, eigenvectors_k, ritz_eigenvectors);
+
+  auto s = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
+
+  auto eigenvectors_k_slice =
+    raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
+      eigenvectors.data_handle(), ncv, nEigVecs);
+  auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
+    s.data_handle(), 1, nEigVecs);
+
+  raft::matrix::slice_coordinates<IndexTypeT> coords(ncv - 1, 0, ncv, nEigVecs);
+  raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords);
+
+  auto beta_k = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
+  raft::matrix::fill(handle, beta_k.view(), zero);
+  auto beta_scalar = raft::make_device_scalar_view<const ValueTypeT>(beta.data_handle() +
+                                                                     (ncv - 1) * beta.stride(1));
+
+  raft::linalg::axpy(handle, beta_scalar, raft::make_const_mdspan(s.view()), beta_k.view());
+
+  ValueTypeT res = 0;
+
+  raft::device_vector<ValueTypeT, uint32_t> output =
+    raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
+  raft::device_matrix_view<const ValueTypeT> input =
+    raft::make_device_matrix_view<const ValueTypeT>(beta_k.data_handle(), 1, nEigVecs);
+  raft::linalg::norm(handle,
+                     input,
+                     output.view(),
+                     raft::linalg::L2Norm,
+                     raft::linalg::Apply::ALONG_ROWS,
+                     raft::sqrt_op());
+  raft::copy(&res, output.data_handle(), 1, stream);
+  resource::sync_stream(handle, stream);
+
+  auto uu  = raft::make_device_matrix<ValueTypeT>(handle, 0, nEigVecs);
+  int iter = ncv;
+  while (res > tol && iter < maxIter) {
+    auto beta_view = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
+      beta.data_handle(), 1, nEigVecs);
+    raft::matrix::fill(handle, beta_view, zero);
+
+    raft::copy(alpha.data_handle(), eigenvalues_k.data_handle(), nEigVecs, stream);
+
+    auto x_T =
+      raft::make_device_matrix_view<ValueTypeT>(ritz_eigenvectors.data_handle(), nEigVecs, n);
+
+    raft::copy(V.data_handle(), x_T.data_handle(), nEigVecs * n, stream);
+
+    ValueTypeT one  = 1;
+    ValueTypeT mone = -1;
+
+    // Using raft::linalg::gemv leads to Reason=7:CUBLAS_STATUS_INVALID_VALUE (issue raft#2484)
+    raft::linalg::detail::cublasgemv(cublas_h,
+                                     CUBLAS_OP_T,
+                                     nEigVecs,
+                                     n,
+                                     &one,
+                                     V.data_handle(),
+                                     nEigVecs,
+                                     u.data_handle(),
+                                     1,
+                                     &zero,
+                                     uu.data_handle(),
+                                     1,
+                                     stream);
+
+    raft::linalg::detail::cublasgemv(cublas_h,
+                                     CUBLAS_OP_N,
+                                     nEigVecs,
+                                     n,
+                                     &mone,
+                                     V.data_handle(),
+                                     nEigVecs,
+                                     uu.data_handle(),
+                                     1,
+                                     &one,
+                                     u.data_handle(),
+                                     1,
+                                     stream);
+
+    auto V_0_view =
+      raft::make_device_matrix_view<ValueTypeT>(V.data_handle() + (nEigVecs * n), 1, n);
+    auto V_0_view_vector =
+      raft::make_device_vector_view<ValueTypeT, uint32_t>(V_0_view.data_handle(), n);
+    auto unrm = raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
+    raft::linalg::norm(handle,
+                       raft::make_const_mdspan(u.view()),
+                       unrm.view(),
+                       raft::linalg::L2Norm,
+                       raft::linalg::Apply::ALONG_ROWS,
+                       raft::sqrt_op());
+
+    raft::linalg::unary_op(
+      handle,
+      raft::make_const_mdspan(u_vector),
+      V_0_view,
+      [device_scalar = unrm.data_handle()] __device__(auto y) { return y / *device_scalar; });
+
+    auto cusparse_h                 = resource::get_cusparse_handle(handle);
+    cusparseSpMatDescr_t cusparse_A = raft::sparse::linalg::detail::create_descriptor(A);
+
+    cusparseDnVecDescr_t cusparse_v =
+      raft::sparse::linalg::detail::create_descriptor(V_0_view_vector);
+    cusparseDnVecDescr_t cusparse_u = raft::sparse::linalg::detail::create_descriptor(u_vector);
+
+    ValueTypeT zero = 0;
+    size_t bufferSize;
+    raft::sparse::detail::cusparsespmv_buffersize(cusparse_h,
+                                                  CUSPARSE_OPERATION_NON_TRANSPOSE,
+                                                  &one,
+                                                  cusparse_A,
+                                                  cusparse_v,
+                                                  &zero,
+                                                  cusparse_u,
+                                                  CUSPARSE_SPMV_ALG_DEFAULT,
+                                                  &bufferSize,
+                                                  stream);
+    auto cusparse_spmv_buffer = raft::make_device_vector<ValueTypeT>(handle, bufferSize);
+
+    raft::sparse::detail::cusparsespmv(cusparse_h,
+                                       CUSPARSE_OPERATION_NON_TRANSPOSE,
+                                       &one,
+                                       cusparse_A,
+                                       cusparse_v,
+                                       &zero,
+                                       cusparse_u,
+                                       CUSPARSE_SPMV_ALG_DEFAULT,
+                                       cusparse_spmv_buffer.data_handle(),
+                                       stream);
+
+    auto alpha_k = raft::make_device_scalar_view<ValueTypeT>(alpha.data_handle() + nEigVecs);
+
+    raft::linalg::dot(
+      handle, make_const_mdspan(V_0_view_vector), make_const_mdspan(u_vector), alpha_k);
+
+    raft::linalg::binary_op(handle,
+                            make_const_mdspan(u_vector),
+                            make_const_mdspan(V_0_view_vector),
+                            u_vector,
+                            [device_scalar_ptr = alpha_k.data_handle()] __device__(
+                              ValueTypeT u_element, ValueTypeT V_0_element) {
+                              return u_element - (*device_scalar_ptr) * V_0_element;
+                            });
+
+    auto temp = raft::make_device_vector<ValueTypeT, uint32_t>(handle, n);
+
+    auto V_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
+      V.data_handle(), nEigVecs, n);
+    auto V_k_T =
+      raft::make_device_matrix<ValueTypeT, uint32_t, raft::row_major>(handle, n, nEigVecs);
+
+    raft::linalg::transpose(handle, V_k, V_k_T.view());
+
+    ValueTypeT three = 3;
+    ValueTypeT two   = 2;
+
+    std::vector<ValueTypeT> M   = {1, 2, 3, 4, 5, 6};
+    std::vector<ValueTypeT> vec = {1, 1};
+
+    auto M_dev   = raft::make_device_matrix<ValueTypeT>(handle, 2, 3);
+    auto vec_dev = raft::make_device_vector<ValueTypeT>(handle, 2);
+    auto out     = raft::make_device_vector<ValueTypeT>(handle, 3);
+    raft::copy(M_dev.data_handle(), M.data(), 6, stream);
+    raft::copy(vec_dev.data_handle(), vec.data(), 2, stream);
+
+    raft::linalg::gemv(handle,
+                       CUBLAS_OP_N,
+                       three,
+                       two,
+                       &one,
+                       M_dev.data_handle(),
+                       three,
+                       vec_dev.data_handle(),
+                       1,
+                       &zero,
+                       out.data_handle(),
+                       1,
+                       stream);
+
+    raft::linalg::gemv(handle,
+                       CUBLAS_OP_N,
+                       n,
+                       nEigVecs,
+                       &one,
+                       V_k.data_handle(),
+                       n,
+                       beta_k.data_handle(),
+                       1,
+                       &zero,
+                       temp.data_handle(),
+                       1,
+                       stream);
+
+    auto one_scalar = raft::make_device_scalar<ValueTypeT>(handle, 1);
+    raft::linalg::binary_op(handle,
+                            make_const_mdspan(u_vector),
+                            make_const_mdspan(temp.view()),
+                            u_vector,
+                            [device_scalar_ptr = one_scalar.data_handle()] __device__(
+                              ValueTypeT u_element, ValueTypeT temp_element) {
+                              return u_element - (*device_scalar_ptr) * temp_element;
+                            });
+
+    auto output1 = raft::make_device_vector_view<ValueTypeT, uint32_t>(
+      beta.data_handle() + beta.stride(1) * nEigVecs, 1);
+    raft::linalg::norm(handle,
+                       raft::make_const_mdspan(u.view()),
+                       output1,
+                       raft::linalg::L2Norm,
+                       raft::linalg::Apply::ALONG_ROWS,
+                       raft::sqrt_op());
+
+    auto V_kplus1 =
+      raft::make_device_vector_view<ValueTypeT>(V.data_handle() + V.stride(0) * (nEigVecs + 1), n);
+
+    raft::linalg::unary_op(
+      handle,
+      make_const_mdspan(u_vector),
+      V_kplus1,
+      [device_scalar = (beta.data_handle() + beta.stride(1) * nEigVecs)] __device__(auto y) {
+        return y / *device_scalar;
+      });
+
+    lanczos_aux(handle,
+                A,
+                V.view(),
+                u.view(),
+                alpha.view(),
+                beta.view(),
+                nEigVecs + 1,
+                ncv,
+                ncv,
+                v.view(),
+                aux_uu.view(),
+                vv.view());
+    iter += ncv - nEigVecs;
+    lanczos_solve_ritz<IndexTypeT, ValueTypeT>(handle,
+                                               alpha.view(),
+                                               beta.view(),
+                                               beta_k.view(),
+                                               nEigVecs,
+                                               0,
+                                               ncv,
+                                               eigenvectors.view(),
+                                               eigenvalues.view());
+    auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
+      eigenvectors.data_handle(), ncv, nEigVecs);
+
+    auto ritz_eigenvectors = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
+      eigVecs_dev, n, nEigVecs);
+
+    auto V_T =
+      raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(V.data_handle(), n, ncv);
+    raft::linalg::gemm<ValueTypeT, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
+      handle, V_T, eigenvectors_k, ritz_eigenvectors);
+
+    auto eigenvectors_k_slice =
+      raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
+        eigenvectors.data_handle(), ncv, nEigVecs);
+    auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
+      s.data_handle(), 1, nEigVecs);
+
+    raft::matrix::slice_coordinates<IndexTypeT> coords(ncv - 1, 0, ncv, nEigVecs);
+    raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords);
+
+    raft::matrix::fill(handle, beta_k.view(), zero);
+
+    auto beta_scalar = raft::make_device_scalar_view<const ValueTypeT>(
+      beta.data_handle() + beta.stride(1) * (ncv - 1));  // &((beta.view())(0, ncv - 1))
+
+    raft::linalg::axpy(handle, beta_scalar, raft::make_const_mdspan(s.view()), beta_k.view());
+
+    raft::device_vector<ValueTypeT, uint32_t> output2 =
+      raft::make_device_vector<ValueTypeT, uint32_t>(handle, 1);
+    raft::device_matrix_view<const ValueTypeT> input2 =
+      raft::make_device_matrix_view<const ValueTypeT>(beta_k.data_handle(), 1, nEigVecs);
+    raft::linalg::norm(handle,
+                       input2,
+                       output2.view(),
+                       raft::linalg::L2Norm,
+                       raft::linalg::Apply::ALONG_ROWS,
+                       raft::sqrt_op());
+    raft::copy(&res, output2.data_handle(), 1, stream);
+    resource::sync_stream(handle, stream);
+    RAFT_LOG_TRACE("Iteration %f: residual (tolerance) %d", iter, res);
+  }
+
+  raft::copy(eigVals_dev, eigenvalues_k.data_handle(), nEigVecs, stream);
+  raft::copy(eigVecs_dev, ritz_eigenvectors.data_handle(), n * nEigVecs, stream);
+
+  return 0;
+}
+
+template <typename IndexTypeT, typename ValueTypeT>
+auto lanczos_compute_smallest_eigenvectors(
+  raft::resources const& handle,
+  lanczos_solver_config<ValueTypeT> const& config,
+  raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
+  std::optional<raft::device_vector_view<ValueTypeT, uint32_t>> v0,
+  raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
+{
+  if (v0.has_value()) {
+    return lanczos_smallest(handle,
+                            A,
+                            config.n_components,
+                            config.max_iterations,
+                            config.ncv,
+                            config.tolerance,
+                            eigenvalues.data_handle(),
+                            eigenvectors.data_handle(),
+                            v0->data_handle(),
+                            config.seed);
+  } else {
+    // Handle the optional v0 initial Lanczos vector if nullopt is used
+    auto n       = A.structure_view().get_n_rows();
+    auto temp_v0 = raft::make_device_vector<ValueTypeT, uint32_t>(handle, n);
+    raft::random::RngState rng_state(config.seed);
+    raft::random::uniform(handle, rng_state, temp_v0.view(), ValueTypeT{0.0}, ValueTypeT{1.0});
+    return lanczos_smallest(handle,
+                            A,
+                            config.n_components,
+                            config.max_iterations,
+                            config.ncv,
+                            config.tolerance,
+                            eigenvalues.data_handle(),
+                            eigenvectors.data_handle(),
+                            temp_v0.data_handle(),
+                            config.seed);
+  }
+}
+
 }  // namespace raft::sparse::solver::detail
diff --git a/cpp/include/raft/sparse/solver/lanczos.cuh b/cpp/include/raft/sparse/solver/lanczos.cuh
index 1aa56d6ba2..fed31e6a9c 100644
--- a/cpp/include/raft/sparse/solver/lanczos.cuh
+++ b/cpp/include/raft/sparse/solver/lanczos.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
 #pragma once
 
 #include <raft/sparse/solver/detail/lanczos.cuh>
+#include <raft/sparse/solver/lanczos_types.hpp>
 #include <raft/spectral/matrix_wrappers.hpp>
 
 namespace raft::sparse::solver {
@@ -27,6 +28,78 @@ namespace raft::sparse::solver {
 // Eigensolver
 // =========================================================
 
+/**
+ *  @brief Find the smallest eigenpairs using lanczos solver
+ *  @tparam index_type_t the type of data used for indexing.
+ *  @tparam value_type_t the type of data used for weights, distances.
+ *  @param handle the raft handle.
+ *  @param config lanczos config used to set hyperparameters
+ *  @param A Sparse matrix in CSR format.
+ *  @param v0 Optional Initial lanczos vector
+ *  @param eigenvalues output eigenvalues
+ *  @param eigenvectors output eigenvectors
+ *  @todo Add largest eigenvalues computation (issue #2483)
+ *  @return Zero if successful. Otherwise non-zero.
+ */
+template <typename IndexTypeT, typename ValueTypeT>
+auto lanczos_compute_smallest_eigenvectors(
+  raft::resources const& handle,
+  lanczos_solver_config<ValueTypeT> const& config,
+  raft::device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT> A,
+  std::optional<raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major>> v0,
+  raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
+{
+  return detail::lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
+    handle, config, A, v0, eigenvalues, eigenvectors);
+}
+
+/**
+ *  @brief Find the smallest eigenpairs using lanczos solver
+ *  @tparam index_type_t the type of data used for indexing.
+ *  @tparam value_type_t the type of data used for weights, distances.
+ *  @param handle the raft handle.
+ *  @param config lanczos config used to set hyperparameters
+ *  @param rows Vector view of the rows of the sparse matrix.
+ *  @param cols Vector view of the cols of the sparse matrix.
+ *  @param vals Vector view of the vals of the sparse matrix.
+ *  @param v0 Optional Initial lanczos vector
+ *  @param eigenvalues output eigenvalues
+ *  @param eigenvectors output eigenvectors
+ *  @todo Add largest eigenvalues computation (issue #2483)
+ *  @return Zero if successful. Otherwise non-zero.
+ */
+template <typename IndexTypeT, typename ValueTypeT>
+auto lanczos_compute_smallest_eigenvectors(
+  raft::resources const& handle,
+  lanczos_solver_config<ValueTypeT> const& config,
+  raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> rows,
+  raft::device_vector_view<IndexTypeT, uint32_t, raft::row_major> cols,
+  raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major> vals,
+  std::optional<raft::device_vector_view<ValueTypeT, uint32_t, raft::row_major>> v0,
+  raft::device_vector_view<ValueTypeT, uint32_t, raft::col_major> eigenvalues,
+  raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors) -> int
+{
+  IndexTypeT ncols = rows.extent(0) - 1;
+  IndexTypeT nrows = rows.extent(0) - 1;
+  IndexTypeT nnz   = cols.extent(0);
+
+  auto csr_structure =
+    raft::make_device_compressed_structure_view<IndexTypeT, IndexTypeT, IndexTypeT>(
+      const_cast<IndexTypeT*>(rows.data_handle()),
+      const_cast<IndexTypeT*>(cols.data_handle()),
+      ncols,
+      nrows,
+      nnz);
+
+  auto csr_matrix =
+    raft::make_device_csr_matrix_view<ValueTypeT, IndexTypeT, IndexTypeT, IndexTypeT>(
+      const_cast<ValueTypeT*>(vals.data_handle()), csr_structure);
+
+  return lanczos_compute_smallest_eigenvectors<IndexTypeT, ValueTypeT>(
+    handle, config, csr_matrix, v0, eigenvalues, eigenvectors);
+}
+
 /**
  *  @brief  Compute smallest eigenvectors of symmetric matrix
  *    Computes eigenvalues and eigenvectors that are least
diff --git a/cpp/include/raft/sparse/solver/lanczos_types.hpp b/cpp/include/raft/sparse/solver/lanczos_types.hpp
new file mode 100644
index 0000000000..edd5548079
--- /dev/null
+++ b/cpp/include/raft/sparse/solver/lanczos_types.hpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace raft::sparse::solver {
+
+template <typename ValueTypeT>
+struct lanczos_solver_config {
+  /** The number of eigenvalues and eigenvectors to compute. Must be 1 <= k < n.*/
+  int n_components;
+  /** Maximum number of iteration. */
+  int max_iterations;
+  /** The number of Lanczos vectors generated. Must be k + 1 < ncv < n. */
+  int ncv;
+  /** Tolerance for residuals ``||Ax - wx||`` */
+  ValueTypeT tolerance;
+  /** random seed */
+  uint64_t seed;
+};
+
+}  // namespace raft::sparse::solver
diff --git a/cpp/include/raft/spectral/eigen_solvers.cuh b/cpp/include/raft/spectral/eigen_solvers.cuh
index 4774d8b8ae..d98e90532e 100644
--- a/cpp/include/raft/spectral/eigen_solvers.cuh
+++ b/cpp/include/raft/spectral/eigen_solvers.cuh
@@ -18,6 +18,7 @@
 
 #pragma once
 
+#include <raft/core/mdspan.hpp>
 #include <raft/sparse/solver/lanczos.cuh>
 #include <raft/spectral/matrix_wrappers.hpp>
 
@@ -57,18 +58,32 @@ struct lanczos_solver_t {
   {
     RAFT_EXPECTS(eigVals != nullptr, "Null eigVals buffer.");
     RAFT_EXPECTS(eigVecs != nullptr, "Null eigVecs buffer.");
-    index_type_t iters{};
-    sparse::solver::computeSmallestEigenvectors(handle,
-                                                A,
-                                                config_.n_eigVecs,
-                                                config_.maxIter,
-                                                config_.restartIter,
-                                                config_.tol,
-                                                config_.reorthogonalize,
-                                                iters,
-                                                eigVals,
-                                                eigVecs,
-                                                config_.seed);
+    index_type_t iters{0};  // TODO: return total number of iter
+    auto lanczos_config = raft::sparse::solver::lanczos_solver_config<value_type_t>{
+      config_.n_eigVecs, config_.maxIter, config_.restartIter, config_.tol, config_.seed};
+    auto csr_structure =
+      raft::make_device_compressed_structure_view<index_type_t, index_type_t, index_type_t>(
+        const_cast<index_type_t*>(A.row_offsets_),
+        const_cast<index_type_t*>(A.col_indices_),
+        A.nrows_,
+        A.ncols_,
+        A.nnz_);
+
+    auto csr_matrix =
+      raft::make_device_csr_matrix_view<value_type_t, index_type_t, index_type_t, index_type_t>(
+        const_cast<value_type_t*>(A.values_), csr_structure);
+    std::optional<raft::device_vector_view<value_type_t, uint32_t, raft::row_major>> v0_opt;
+
+    sparse::solver::lanczos_compute_smallest_eigenvectors(
+      handle,
+      lanczos_config,
+      csr_matrix,
+      v0_opt,
+      raft::make_device_vector_view<value_type_t, uint32_t, raft::col_major>(eigVals,
+                                                                             config_.n_eigVecs),
+      raft::make_device_matrix_view<value_type_t, uint32_t, raft::col_major>(
+        eigVecs, A.nrows_, config_.n_eigVecs));
+
     return iters;
   }
 
diff --git a/cpp/include/raft_runtime/solver/lanczos.hpp b/cpp/include/raft_runtime/solver/lanczos.hpp
new file mode 100644
index 0000000000..6c9d901bf1
--- /dev/null
+++ b/cpp/include/raft_runtime/solver/lanczos.hpp
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <raft/core/device_mdspan.hpp>
+#include <raft/core/resources.hpp>
+#include <raft/sparse/solver/lanczos_types.hpp>
+
+#include <cstdint>
+
+namespace raft::runtime::solver {
+
+/**
+ * @defgroup lanczos_runtime lanczos Runtime API
+ * @{
+ */
+
+#define FUNC_DECL(IndexType, ValueType)                                               \
+  void lanczos_solver(                                                                \
+    const raft::resources& handle,                                                    \
+    raft::sparse::solver::lanczos_solver_config<ValueType> config,                    \
+    raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows,              \
+    raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols,              \
+    raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals,              \
+    std::optional<raft::device_vector_view<ValueType, uint32_t, raft::row_major>> v0, \
+    raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues,       \
+    raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors)
+
+FUNC_DECL(int, float);
+FUNC_DECL(int64_t, float);
+FUNC_DECL(int, double);
+FUNC_DECL(int64_t, double);
+
+#undef FUNC_DECL
+
+/** @} */  // end group lanczos_runtime
+
+}  // namespace raft::runtime::solver
diff --git a/cpp/src/raft_runtime/solver/lanczos_solver.cuh b/cpp/src/raft_runtime/solver/lanczos_solver.cuh
new file mode 100644
index 0000000000..0c851ef13a
--- /dev/null
+++ b/cpp/src/raft_runtime/solver/lanczos_solver.cuh
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <raft/sparse/solver/lanczos.cuh>
+
+#define FUNC_DEF(IndexType, ValueType)                                                 \
+  void lanczos_solver(                                                                 \
+    const raft::resources& handle,                                                     \
+    raft::sparse::solver::lanczos_solver_config<ValueType> config,                     \
+    raft::device_vector_view<IndexType, uint32_t, raft::row_major> rows,               \
+    raft::device_vector_view<IndexType, uint32_t, raft::row_major> cols,               \
+    raft::device_vector_view<ValueType, uint32_t, raft::row_major> vals,               \
+    std::optional<raft::device_vector_view<ValueType, uint32_t, raft::row_major>> v0,  \
+    raft::device_vector_view<ValueType, uint32_t, raft::col_major> eigenvalues,        \
+    raft::device_matrix_view<ValueType, uint32_t, raft::col_major> eigenvectors)       \
+  {                                                                                    \
+    raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>( \
+      handle, config, rows, cols, vals, v0, eigenvalues, eigenvectors);                \
+  }
diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu
new file mode 100644
index 0000000000..f772a8a0d1
--- /dev/null
+++ b/cpp/src/raft_runtime/solver/lanczos_solver_int64_double.cu
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lanczos_solver.cuh"
+
+namespace raft::runtime::solver {
+
+FUNC_DEF(int64_t, double);
+
+}  // namespace raft::runtime::solver
diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu
new file mode 100644
index 0000000000..efaf3be565
--- /dev/null
+++ b/cpp/src/raft_runtime/solver/lanczos_solver_int64_float.cu
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lanczos_solver.cuh"
+
+namespace raft::runtime::solver {
+
+FUNC_DEF(int64_t, float);
+
+}  // namespace raft::runtime::solver
diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu
new file mode 100644
index 0000000000..9bbc00e78a
--- /dev/null
+++ b/cpp/src/raft_runtime/solver/lanczos_solver_int_double.cu
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lanczos_solver.cuh"
+
+namespace raft::runtime::solver {
+
+FUNC_DEF(int, double);
+
+}  // namespace raft::runtime::solver
diff --git a/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu b/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu
new file mode 100644
index 0000000000..316a9fb7e1
--- /dev/null
+++ b/cpp/src/raft_runtime/solver/lanczos_solver_int_float.cu
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lanczos_solver.cuh"
+
+namespace raft::runtime::solver {
+
+FUNC_DEF(int, float);
+
+}  // namespace raft::runtime::solver
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index a387e9ce09..621ee6c160 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -232,8 +232,8 @@ if(BUILD_TESTS)
   )
 
   ConfigureTest(
-    NAME SOLVERS_TEST PATH linalg/eigen_solvers.cu lap/lap.cu sparse/mst.cu LIB
-    EXPLICIT_INSTANTIATE_ONLY
+    NAME SOLVERS_TEST PATH linalg/eigen_solvers.cu lap/lap.cu sparse/mst.cu
+    sparse/solver/lanczos.cu LIB EXPLICIT_INSTANTIATE_ONLY
   )
 
   ConfigureTest(
diff --git a/cpp/test/sparse/solver/lanczos.cu b/cpp/test/sparse/solver/lanczos.cu
new file mode 100644
index 0000000000..74611a1fd8
--- /dev/null
+++ b/cpp/test/sparse/solver/lanczos.cu
@@ -0,0 +1,445 @@
+/*
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "../../test_utils.cuh"
+
+#include <raft/core/device_mdarray.hpp>
+#include <raft/core/device_mdspan.hpp>
+#include <raft/core/mdspan.hpp>
+#include <raft/core/mdspan_types.hpp>
+#include <raft/core/resources.hpp>
+#include <raft/matrix/init.cuh>
+#include <raft/random/rmat_rectangular_generator.cuh>
+#include <raft/random/rng.cuh>
+#include <raft/random/rng_state.hpp>
+#include <raft/sparse/convert/csr.cuh>
+#include <raft/sparse/coo.hpp>
+#include <raft/sparse/linalg/degree.cuh>
+#include <raft/sparse/linalg/symmetrize.cuh>
+#include <raft/sparse/op/reduce.cuh>
+#include <raft/sparse/op/sort.cuh>
+#include <raft/sparse/solver/lanczos_types.hpp>
+#include <raft/spectral/eigen_solvers.cuh>
+#include <raft/spectral/matrix_wrappers.hpp>
+#include <raft/util/cudart_utils.hpp>
+
+#include <driver_types.h>
+
+#include <gtest/gtest.h>
+#include <sys/types.h>
+#include <test_utils.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+
+namespace raft::sparse {
+
+template <typename IndexType, typename ValueType>
+struct lanczos_inputs {
+  int n_components;
+  int restartiter;
+  int maxiter;
+  int conv_n_iters;
+  float conv_eps;
+  float tol;
+  uint64_t seed;
+  std::vector<IndexType> rows;  // indptr
+  std::vector<IndexType> cols;  // indices
+  std::vector<ValueType> vals;  // data
+  std::vector<ValueType> expected_eigenvalues;
+};
+
+template <typename IndexType, typename ValueType>
+struct rmat_lanczos_inputs {
+  int n_components;
+  int restartiter;
+  int maxiter;
+  int conv_n_iters;
+  float conv_eps;
+  float tol;
+  uint64_t seed;
+  int r_scale;
+  int c_scale;
+  float sparsity;
+  std::vector<ValueType> expected_eigenvalues;
+};
+
+template <typename IndexType, typename ValueType>
+class rmat_lanczos_tests
+  : public ::testing::TestWithParam<rmat_lanczos_inputs<IndexType, ValueType>> {
+ public:
+  rmat_lanczos_tests()
+    : params(::testing::TestWithParam<rmat_lanczos_inputs<IndexType, ValueType>>::GetParam()),
+      stream(resource::get_cuda_stream(handle)),
+      rng(params.seed),
+      expected_eigenvalues(raft::make_device_vector<ValueType, uint32_t, raft::col_major>(
+        handle, params.n_components)),
+      r_scale(params.r_scale),
+      c_scale(params.c_scale),
+      sparsity(params.sparsity)
+  {
+  }
+
+ protected:
+  void SetUp() override
+  {
+    raft::copy(expected_eigenvalues.data_handle(),
+               params.expected_eigenvalues.data(),
+               params.n_components,
+               stream);
+  }
+
+  void TearDown() override {}
+
+  void Run()
+  {
+    uint64_t n_edges   = sparsity * ((long long)(1 << r_scale) * (long long)(1 << c_scale));
+    uint64_t n_nodes   = 1 << std::max(r_scale, c_scale);
+    uint64_t theta_len = std::max(r_scale, c_scale) * 4;
+
+    auto theta = raft::make_device_vector<ValueType, uint32_t, raft::row_major>(handle, theta_len);
+    raft::random::uniform<ValueType>(handle, rng, theta.view(), 0, 1);
+
+    auto out =
+      raft::make_device_matrix<IndexType, uint32_t, raft::row_major>(handle, n_edges * 2, 2);
+    auto out_src = raft::make_device_vector<IndexType, uint32_t, raft::row_major>(handle, n_edges);
+    auto out_dst = raft::make_device_vector<IndexType, uint32_t, raft::row_major>(handle, n_edges);
+
+    raft::random::RngState rng1{params.seed};
+
+    raft::random::rmat_rectangular_gen<IndexType, ValueType>(handle,
+                                                             rng1,
+                                                             make_const_mdspan(theta.view()),
+                                                             out.view(),
+                                                             out_src.view(),
+                                                             out_dst.view(),
+                                                             r_scale,
+                                                             c_scale);
+
+    raft::device_vector<ValueType, uint32_t, raft::row_major> out_data =
+      raft::make_device_vector<ValueType, uint32_t, raft::row_major>(handle, n_edges);
+    raft::matrix::fill<ValueType>(handle, out_data.view(), 1.0);
+    raft::sparse::COO<ValueType, IndexType> coo(stream);
+
+    raft::sparse::op::coo_sort<ValueType, int>(n_nodes,
+                                               n_nodes,
+                                               n_edges,
+                                               out_src.data_handle(),
+                                               out_dst.data_handle(),
+                                               out_data.data_handle(),
+                                               stream);
+    raft::sparse::op::max_duplicates<IndexType, ValueType>(handle,
+                                                           coo,
+                                                           out_src.data_handle(),
+                                                           out_dst.data_handle(),
+                                                           out_data.data_handle(),
+                                                           n_edges,
+                                                           n_nodes,
+                                                           n_nodes);
+
+    raft::sparse::COO<ValueType, IndexType> symmetric_coo(stream);
+    raft::sparse::linalg::symmetrize(
+      handle, coo.rows(), coo.cols(), coo.vals(), coo.n_rows, coo.n_cols, coo.nnz, symmetric_coo);
+
+    raft::device_vector<IndexType, uint32_t, raft::row_major> row_indices =
+      raft::make_device_vector<IndexType, uint32_t, raft::row_major>(handle,
+                                                                     symmetric_coo.n_rows + 1);
+    raft::sparse::convert::sorted_coo_to_csr(symmetric_coo.rows(),
+                                             symmetric_coo.nnz,
+                                             row_indices.data_handle(),
+                                             symmetric_coo.n_rows + 1,
+                                             stream);
+
+    int n_components = params.n_components;
+
+    raft::device_vector<ValueType, uint32_t, raft::row_major> v0 =
+      raft::make_device_vector<ValueType, uint32_t, raft::row_major>(handle, symmetric_coo.n_rows);
+
+    raft::random::uniform<ValueType>(handle, rng, v0.view(), 0, 1);
+    std::tuple<IndexType, ValueType, IndexType> stats;
+
+    raft::device_vector<ValueType, uint32_t, raft::col_major> eigenvalues =
+      raft::make_device_vector<ValueType, uint32_t, raft::col_major>(handle, n_components);
+    raft::device_matrix<ValueType, uint32_t, raft::col_major> eigenvectors =
+      raft::make_device_matrix<ValueType, uint32_t, raft::col_major>(
+        handle, symmetric_coo.n_rows, n_components);
+
+    raft::spectral::matrix::sparse_matrix_t<IndexType, ValueType> const csr_m{
+      handle,
+      row_indices.data_handle(),
+      symmetric_coo.cols(),
+      symmetric_coo.vals(),
+      symmetric_coo.n_rows,
+      symmetric_coo.nnz};
+    raft::sparse::solver::lanczos_solver_config<ValueType> config{
+      n_components, params.maxiter, params.restartiter, params.tol, rng.seed};
+
+    auto csr_structure =
+      raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
+        const_cast<IndexType*>(row_indices.data_handle()),
+        const_cast<IndexType*>(symmetric_coo.cols()),
+        symmetric_coo.n_rows,
+        symmetric_coo.n_rows,
+        symmetric_coo.nnz);
+
+    auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
+      const_cast<ValueType*>(symmetric_coo.vals()), csr_structure);
+
+    std::get<0>(stats) =
+      raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
+        handle,
+        config,
+        csr_matrix,
+        std::make_optional(v0.view()),
+        eigenvalues.view(),
+        eigenvectors.view());
+
+    ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
+                                             expected_eigenvalues.data_handle(),
+                                             n_components,
+                                             raft::CompareApprox<ValueType>(1e-5),
+                                             stream));
+  }
+
+ protected:
+  rmat_lanczos_inputs<IndexType, ValueType> params;
+  raft::resources handle;
+  cudaStream_t stream;
+  raft::random::RngState rng;
+  int r_scale;
+  int c_scale;
+  float sparsity;
+  raft::device_vector<ValueType, uint32_t, raft::col_major> expected_eigenvalues;
+};
+
+template <typename IndexType, typename ValueType>
+class lanczos_tests : public ::testing::TestWithParam<lanczos_inputs<IndexType, ValueType>> {
+ public:
+  lanczos_tests()
+    : params(::testing::TestWithParam<lanczos_inputs<IndexType, ValueType>>::GetParam()),
+      stream(resource::get_cuda_stream(handle)),
+      n(params.rows.size() - 1),
+      nnz(params.vals.size()),
+      rng(params.seed),
+      rows(raft::make_device_vector<IndexType, uint32_t, raft::row_major>(handle, n + 1)),
+      cols(raft::make_device_vector<IndexType, uint32_t, raft::row_major>(handle, nnz)),
+      vals(raft::make_device_vector<ValueType, uint32_t, raft::row_major>(handle, nnz)),
+      v0(raft::make_device_vector<ValueType, uint32_t, raft::row_major>(handle, n)),
+      eigenvalues(raft::make_device_vector<ValueType, uint32_t, raft::col_major>(
+        handle, params.n_components)),
+      eigenvectors(raft::make_device_matrix<ValueType, uint32_t, raft::col_major>(
+        handle, n, params.n_components)),
+      expected_eigenvalues(
+        raft::make_device_vector<ValueType, uint32_t, raft::col_major>(handle, params.n_components))
+  {
+  }
+
+ protected:
+  void SetUp() override
+  {
+    raft::copy(rows.data_handle(), params.rows.data(), n + 1, stream);
+    raft::copy(cols.data_handle(), params.cols.data(), nnz, stream);
+    raft::copy(vals.data_handle(), params.vals.data(), nnz, stream);
+    raft::copy(expected_eigenvalues.data_handle(),
+               params.expected_eigenvalues.data(),
+               params.n_components,
+               stream);
+  }
+
+  void TearDown() override {}
+
+  void Run()
+  {
+    raft::random::uniform<ValueType>(handle, rng, v0.view(), 0, 1);
+    std::tuple<IndexType, ValueType, IndexType> stats;
+
+    raft::sparse::solver::lanczos_solver_config<ValueType> config{
+      params.n_components, params.maxiter, params.restartiter, params.tol, rng.seed};
+    auto csr_structure =
+      raft::make_device_compressed_structure_view<IndexType, IndexType, IndexType>(
+        const_cast<IndexType*>(rows.data_handle()),
+        const_cast<IndexType*>(cols.data_handle()),
+        n,
+        n,
+        nnz);
+
+    auto csr_matrix = raft::make_device_csr_matrix_view<ValueType, IndexType, IndexType, IndexType>(
+      const_cast<ValueType*>(vals.data_handle()), csr_structure);
+
+    std::get<0>(stats) =
+      raft::sparse::solver::lanczos_compute_smallest_eigenvectors<IndexType, ValueType>(
+        handle,
+        config,
+        csr_matrix,
+        std::make_optional(v0.view()),
+        eigenvalues.view(),
+        eigenvectors.view());
+
+    ASSERT_TRUE(raft::devArrMatch<ValueType>(eigenvalues.data_handle(),
+                                             expected_eigenvalues.data_handle(),
+                                             params.n_components,
+                                             raft::CompareApprox<ValueType>(1e-5),
+                                             stream));
+  }
+
+ protected:
+  lanczos_inputs<IndexType, ValueType> params;
+  raft::resources handle;
+  cudaStream_t stream;
+  int n;
+  int nnz;
+  raft::random::RngState rng;
+  raft::device_vector<IndexType, uint32_t, raft::row_major> rows;
+  raft::device_vector<IndexType, uint32_t, raft::row_major> cols;
+  raft::device_vector<ValueType, uint32_t, raft::row_major> vals;
+  raft::device_vector<ValueType, uint32_t, raft::row_major> v0;
+  raft::device_vector<ValueType, uint32_t, raft::col_major> eigenvalues;
+  raft::device_matrix<ValueType, uint32_t, raft::col_major> eigenvectors;
+  raft::device_vector<ValueType, uint32_t, raft::col_major> expected_eigenvalues;
+};
+
+// TODO: Find a way to generate and validate test data without hardcoding them (issue #2485)
+const std::vector<lanczos_inputs<int, float>> inputsf = {
+  {2,
+   34,
+   10000,
+   0,
+   0,
+   1e-15,
+   42,
+   {0,   0,   0,   0,   3,   5,   6,   8,   9,   11,  16,  16,  18,  20,  23,  24,  27,
+    30,  31,  33,  37,  37,  39,  41,  43,  44,  46,  46,  47,  49,  50,  50,  51,  53,
+    57,  58,  59,  66,  67,  68,  69,  71,  72,  75,  78,  83,  86,  90,  93,  94,  96,
+    98,  99,  101, 101, 104, 106, 108, 109, 109, 109, 109, 111, 113, 118, 120, 121, 123,
+    124, 128, 132, 134, 136, 138, 139, 141, 145, 148, 151, 152, 154, 155, 157, 160, 164,
+    167, 170, 170, 170, 173, 178, 179, 182, 184, 186, 191, 192, 196, 198, 198, 198},
+   {44, 68, 74, 16, 36, 85, 34, 75, 61, 51, 83, 15, 33, 55, 69, 71, 18, 84, 70, 95, 71, 83,
+    97, 83, 9,  36, 54, 4,  42, 46, 52, 11, 89, 31, 37, 74, 96, 36, 88, 56, 64, 68, 94, 82,
+    35, 90, 50, 82, 85, 83, 19, 47, 94, 9,  44, 56, 79, 6,  25, 4,  15, 21, 52, 75, 79, 92,
+    19, 72, 94, 94, 96, 80, 16, 54, 89, 46, 48, 63, 3,  33, 67, 73, 77, 46, 47, 75, 16, 43,
+    45, 81, 32, 45, 68, 43, 55, 63, 27, 89, 8,  17, 36, 15, 42, 96, 9,  49, 22, 33, 77, 7,
+    75, 78, 88, 43, 49, 66, 76, 91, 22, 82, 69, 63, 84, 44, 3,  23, 47, 81, 9,  65, 76, 92,
+    12, 96, 9,  13, 38, 93, 44, 3,  19, 6,  36, 45, 61, 63, 69, 89, 44, 57, 94, 62, 33, 36,
+    41, 46, 68, 24, 28, 64, 8,  13, 14, 29, 11, 66, 88, 5,  28, 93, 21, 62, 84, 18, 42, 50,
+    76, 91, 25, 63, 89, 97, 36, 69, 72, 85, 23, 32, 39, 40, 77, 12, 19, 40, 54, 70, 13, 91},
+   {0.4734894, 0.1402491, 0.7686475, 0.0416142, 0.2559651, 0.9360436, 0.7486080, 0.5206724,
+    0.0374126, 0.8082515, 0.5993828, 0.4866583, 0.8907925, 0.9251201, 0.8566143, 0.9528994,
+    0.4557763, 0.4907070, 0.4158074, 0.8311127, 0.9026024, 0.3103237, 0.5876446, 0.7585195,
+    0.4866583, 0.4493615, 0.5909155, 0.0416142, 0.0963910, 0.6722401, 0.3468698, 0.4557763,
+    0.1445242, 0.7720124, 0.9923756, 0.1227579, 0.7194629, 0.8916773, 0.4320931, 0.5840980,
+    0.0216121, 0.3709223, 0.1705930, 0.8297898, 0.2409706, 0.9585592, 0.3171389, 0.0228039,
+    0.4350971, 0.4939908, 0.7720124, 0.2722416, 0.1792683, 0.8907925, 0.1085757, 0.8745620,
+    0.3298612, 0.7486080, 0.2409706, 0.2559651, 0.4493615, 0.8916773, 0.5540361, 0.5150571,
+    0.9160119, 0.1767728, 0.9923756, 0.5717281, 0.1077409, 0.9368132, 0.6273088, 0.6616613,
+    0.0963910, 0.9378265, 0.3059566, 0.3159291, 0.0449106, 0.9085807, 0.4734894, 0.1085757,
+    0.2909013, 0.7787509, 0.7168902, 0.9691764, 0.2669757, 0.4389115, 0.6722401, 0.3159291,
+    0.9691764, 0.7467896, 0.2722416, 0.2669757, 0.1532843, 0.0449106, 0.2023634, 0.8934466,
+    0.3171389, 0.6594226, 0.8082515, 0.3468698, 0.5540361, 0.5909155, 0.9378265, 0.2909178,
+    0.9251201, 0.2023634, 0.5840980, 0.8745620, 0.2624605, 0.0374126, 0.1034030, 0.3736577,
+    0.3315690, 0.9085807, 0.8934466, 0.5548525, 0.2302140, 0.7827352, 0.0216121, 0.8262919,
+    0.1646078, 0.5548525, 0.2658700, 0.2909013, 0.1402491, 0.3709223, 0.1532843, 0.5792196,
+    0.8566143, 0.1646078, 0.0827300, 0.5810611, 0.4158074, 0.5188584, 0.9528994, 0.9026024,
+    0.5717281, 0.7269946, 0.7787509, 0.7686475, 0.1227579, 0.5206724, 0.5150571, 0.4389115,
+    0.1034030, 0.2302140, 0.0827300, 0.8961608, 0.7168902, 0.2624605, 0.4823034, 0.3736577,
+    0.3298612, 0.9160119, 0.6616613, 0.7467896, 0.5792196, 0.8297898, 0.0228039, 0.8262919,
+    0.5993828, 0.3103237, 0.7585195, 0.4939908, 0.4907070, 0.2658700, 0.0844443, 0.9360436,
+    0.4350971, 0.6997072, 0.4320931, 0.3315690, 0.0844443, 0.1445242, 0.3059566, 0.6594226,
+    0.8961608, 0.6498466, 0.9585592, 0.7827352, 0.6498466, 0.2812338, 0.1767728, 0.5810611,
+    0.7269946, 0.6997072, 0.1705930, 0.1792683, 0.1077409, 0.9368132, 0.4823034, 0.8311127,
+    0.7194629, 0.6273088, 0.2909178, 0.5188584, 0.5876446, 0.2812338},
+   {-2.0369630, -1.7673520}}};
+
+const std::vector<lanczos_inputs<int, double>> inputsd = {
+  {2,
+   34,
+   10000,
+   0,
+   0,
+   1e-15,
+   42,
+   {0,   0,   0,   0,   3,   5,   6,   8,   9,   11,  16,  16,  18,  20,  23,  24,  27,
+    30,  31,  33,  37,  37,  39,  41,  43,  44,  46,  46,  47,  49,  50,  50,  51,  53,
+    57,  58,  59,  66,  67,  68,  69,  71,  72,  75,  78,  83,  86,  90,  93,  94,  96,
+    98,  99,  101, 101, 104, 106, 108, 109, 109, 109, 109, 111, 113, 118, 120, 121, 123,
+    124, 128, 132, 134, 136, 138, 139, 141, 145, 148, 151, 152, 154, 155, 157, 160, 164,
+    167, 170, 170, 170, 173, 178, 179, 182, 184, 186, 191, 192, 196, 198, 198, 198},
+   {44, 68, 74, 16, 36, 85, 34, 75, 61, 51, 83, 15, 33, 55, 69, 71, 18, 84, 70, 95, 71, 83,
+    97, 83, 9,  36, 54, 4,  42, 46, 52, 11, 89, 31, 37, 74, 96, 36, 88, 56, 64, 68, 94, 82,
+    35, 90, 50, 82, 85, 83, 19, 47, 94, 9,  44, 56, 79, 6,  25, 4,  15, 21, 52, 75, 79, 92,
+    19, 72, 94, 94, 96, 80, 16, 54, 89, 46, 48, 63, 3,  33, 67, 73, 77, 46, 47, 75, 16, 43,
+    45, 81, 32, 45, 68, 43, 55, 63, 27, 89, 8,  17, 36, 15, 42, 96, 9,  49, 22, 33, 77, 7,
+    75, 78, 88, 43, 49, 66, 76, 91, 22, 82, 69, 63, 84, 44, 3,  23, 47, 81, 9,  65, 76, 92,
+    12, 96, 9,  13, 38, 93, 44, 3,  19, 6,  36, 45, 61, 63, 69, 89, 44, 57, 94, 62, 33, 36,
+    41, 46, 68, 24, 28, 64, 8,  13, 14, 29, 11, 66, 88, 5,  28, 93, 21, 62, 84, 18, 42, 50,
+    76, 91, 25, 63, 89, 97, 36, 69, 72, 85, 23, 32, 39, 40, 77, 12, 19, 40, 54, 70, 13, 91},
+   {0.4734894, 0.1402491, 0.7686475, 0.0416142, 0.2559651, 0.9360436, 0.7486080, 0.5206724,
+    0.0374126, 0.8082515, 0.5993828, 0.4866583, 0.8907925, 0.9251201, 0.8566143, 0.9528994,
+    0.4557763, 0.4907070, 0.4158074, 0.8311127, 0.9026024, 0.3103237, 0.5876446, 0.7585195,
+    0.4866583, 0.4493615, 0.5909155, 0.0416142, 0.0963910, 0.6722401, 0.3468698, 0.4557763,
+    0.1445242, 0.7720124, 0.9923756, 0.1227579, 0.7194629, 0.8916773, 0.4320931, 0.5840980,
+    0.0216121, 0.3709223, 0.1705930, 0.8297898, 0.2409706, 0.9585592, 0.3171389, 0.0228039,
+    0.4350971, 0.4939908, 0.7720124, 0.2722416, 0.1792683, 0.8907925, 0.1085757, 0.8745620,
+    0.3298612, 0.7486080, 0.2409706, 0.2559651, 0.4493615, 0.8916773, 0.5540361, 0.5150571,
+    0.9160119, 0.1767728, 0.9923756, 0.5717281, 0.1077409, 0.9368132, 0.6273088, 0.6616613,
+    0.0963910, 0.9378265, 0.3059566, 0.3159291, 0.0449106, 0.9085807, 0.4734894, 0.1085757,
+    0.2909013, 0.7787509, 0.7168902, 0.9691764, 0.2669757, 0.4389115, 0.6722401, 0.3159291,
+    0.9691764, 0.7467896, 0.2722416, 0.2669757, 0.1532843, 0.0449106, 0.2023634, 0.8934466,
+    0.3171389, 0.6594226, 0.8082515, 0.3468698, 0.5540361, 0.5909155, 0.9378265, 0.2909178,
+    0.9251201, 0.2023634, 0.5840980, 0.8745620, 0.2624605, 0.0374126, 0.1034030, 0.3736577,
+    0.3315690, 0.9085807, 0.8934466, 0.5548525, 0.2302140, 0.7827352, 0.0216121, 0.8262919,
+    0.1646078, 0.5548525, 0.2658700, 0.2909013, 0.1402491, 0.3709223, 0.1532843, 0.5792196,
+    0.8566143, 0.1646078, 0.0827300, 0.5810611, 0.4158074, 0.5188584, 0.9528994, 0.9026024,
+    0.5717281, 0.7269946, 0.7787509, 0.7686475, 0.1227579, 0.5206724, 0.5150571, 0.4389115,
+    0.1034030, 0.2302140, 0.0827300, 0.8961608, 0.7168902, 0.2624605, 0.4823034, 0.3736577,
+    0.3298612, 0.9160119, 0.6616613, 0.7467896, 0.5792196, 0.8297898, 0.0228039, 0.8262919,
+    0.5993828, 0.3103237, 0.7585195, 0.4939908, 0.4907070, 0.2658700, 0.0844443, 0.9360436,
+    0.4350971, 0.6997072, 0.4320931, 0.3315690, 0.0844443, 0.1445242, 0.3059566, 0.6594226,
+    0.8961608, 0.6498466, 0.9585592, 0.7827352, 0.6498466, 0.2812338, 0.1767728, 0.5810611,
+    0.7269946, 0.6997072, 0.1705930, 0.1792683, 0.1077409, 0.9368132, 0.4823034, 0.8311127,
+    0.7194629, 0.6273088, 0.2909178, 0.5188584, 0.5876446, 0.2812338},
+   {-2.0369630, -1.7673520}}};
+
+const std::vector<rmat_lanczos_inputs<int, float>> rmat_inputsf = {
+  {50, 100, 10000, 0, 0, 1e-9, 42, 12, 12, 1, {-122.526794, -74.00686,  -59.698284,  -54.68617,
+                                               -49.686813,  -34.02644,  -32.130703,  -31.26906,
+                                               -30.32097,   -22.946098, -20.497862,  -20.23817,
+                                               -19.269697,  -18.42496,  -17.675667,  -17.013401,
+                                               -16.734581,  -15.820215, -15.73925,   -15.448187,
+                                               -15.044634,  -14.692028, -14.127425,  -13.967386,
+                                               -13.6237755, -13.469393, -13.181225,  -12.777589,
+                                               -12.623185,  -12.55508,  -12.2874565, -12.053391,
+                                               -11.677346,  -11.558279, -11.163732,  -10.922034,
+                                               -10.7936945, -10.558049, -10.205776,  -10.005316,
+                                               -9.559181,   -9.491834,  -9.242631,   -8.883637,
+                                               -8.765364,   -8.688508,  -8.458255,   -8.385196,
+                                               -8.217982,   -8.0442095}}};
+
+using LanczosTestF = lanczos_tests<int, float>;
+TEST_P(LanczosTestF, Result) { Run(); }
+
+using LanczosTestD = lanczos_tests<int, double>;
+TEST_P(LanczosTestD, Result) { Run(); }
+
+using RmatLanczosTestF = rmat_lanczos_tests<int, float>;
+TEST_P(RmatLanczosTestF, Result) { Run(); }
+
+INSTANTIATE_TEST_CASE_P(LanczosTests, LanczosTestF, ::testing::ValuesIn(inputsf));
+INSTANTIATE_TEST_CASE_P(LanczosTests, LanczosTestD, ::testing::ValuesIn(inputsd));
+INSTANTIATE_TEST_CASE_P(LanczosTests, RmatLanczosTestF, ::testing::ValuesIn(rmat_inputsf));
+
+}  // namespace raft::sparse
diff --git a/docs/source/pylibraft_api.rst b/docs/source/pylibraft_api.rst
index aaa359e646..ad7d2873d7 100644
--- a/docs/source/pylibraft_api.rst
+++ b/docs/source/pylibraft_api.rst
@@ -9,3 +9,4 @@ Python API
 
    pylibraft_api/common.rst
    pylibraft_api/random.rst
+   pylibraft_api/sparse.rst
diff --git a/docs/source/pylibraft_api/sparse.rst b/docs/source/pylibraft_api/sparse.rst
new file mode 100644
index 0000000000..b2c3f7a2b1
--- /dev/null
+++ b/docs/source/pylibraft_api/sparse.rst
@@ -0,0 +1,11 @@
+Sparse
+======
+
+This page provides pylibraft class references for the publicly-exposed elements of the `pylibraft.sparse.linalg.eigsh` package.
+
+
+.. role:: py(code)
+   :language: python
+   :class: highlight
+
+.. autofunction:: pylibraft.sparse.linalg.eigsh
\ No newline at end of file
diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt
index 9bde613720..758c1e4711 100644
--- a/python/pylibraft/CMakeLists.txt
+++ b/python/pylibraft/CMakeLists.txt
@@ -87,6 +87,7 @@ rapids_cython_init()
 
 add_subdirectory(pylibraft/common)
 add_subdirectory(pylibraft/random)
+add_subdirectory(pylibraft/sparse)
 
 if(DEFINED cython_lib_dir)
   rapids_cython_add_rpath_entries(TARGET raft PATHS "${cython_lib_dir}")
diff --git a/python/pylibraft/pylibraft/sparse/CMakeLists.txt b/python/pylibraft/pylibraft/sparse/CMakeLists.txt
new file mode 100644
index 0000000000..3779fd2715
--- /dev/null
+++ b/python/pylibraft/pylibraft/sparse/CMakeLists.txt
@@ -0,0 +1,15 @@
+# =============================================================================
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing permissions and limitations under
+# the License.
+# =============================================================================
+
+add_subdirectory(linalg)
diff --git a/python/pylibraft/pylibraft/sparse/__init__.py b/python/pylibraft/pylibraft/sparse/__init__.py
new file mode 100644
index 0000000000..c77def5bb0
--- /dev/null
+++ b/python/pylibraft/pylibraft/sparse/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pylibraft.sparse import linalg
+
+__all__ = ["linalg"]
diff --git a/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt b/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt
new file mode 100644
index 0000000000..ef16981644
--- /dev/null
+++ b/python/pylibraft/pylibraft/sparse/linalg/CMakeLists.txt
@@ -0,0 +1,27 @@
+# =============================================================================
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing permissions and limitations under
+# the License.
+# =============================================================================
+
+# Set the list of Cython files to build
+set(cython_sources lanczos.pyx)
+
+# TODO: should finally be replaced with 'compiled' library to be more generic, when that is
+# available
+set(linked_libraries raft::raft raft::compiled)
+
+# Build all of the Cython targets
+rapids_cython_create_modules(
+  CXX
+  SOURCE_FILES "${cython_sources}"
+  LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX sparse_
+)
diff --git a/python/pylibraft/pylibraft/sparse/linalg/__init__.pxd b/python/pylibraft/pylibraft/sparse/linalg/__init__.pxd
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/pylibraft/pylibraft/sparse/linalg/__init__.py b/python/pylibraft/pylibraft/sparse/linalg/__init__.py
new file mode 100644
index 0000000000..04a8106496
--- /dev/null
+++ b/python/pylibraft/pylibraft/sparse/linalg/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .lanczos import eigsh
+
+__all__ = ["eigsh"]
diff --git a/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.pxd b/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.pxd
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.py b/python/pylibraft/pylibraft/sparse/linalg/cpp/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx b/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx
new file mode 100644
index 0000000000..dc2a84b428
--- /dev/null
+++ b/python/pylibraft/pylibraft/sparse/linalg/lanczos.pyx
@@ -0,0 +1,277 @@
+#
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# cython: profile=False
+# distutils: language = c++
+# cython: embedsignature = True
+# cython: language_level = 3
+
+import cupy as cp
+import numpy as np
+
+from cython.operator cimport dereference as deref
+from libc.stdint cimport int64_t, uint32_t, uint64_t, uintptr_t
+
+from pylibraft.common import Handle, cai_wrapper, device_ndarray
+from pylibraft.common.handle import auto_sync_handle
+
+from libcpp cimport bool
+
+from pylibraft.common.cpp.mdspan cimport (
+    col_major,
+    device_matrix_view,
+    device_vector_view,
+    make_device_matrix_view,
+    make_device_vector_view,
+    row_major,
+)
+from pylibraft.common.cpp.optional cimport optional
+from pylibraft.common.handle cimport device_resources
+from pylibraft.random.cpp.rng_state cimport RngState
+
+
+cdef extern from "raft/sparse/solver/lanczos_types.hpp" \
+        namespace "raft::sparse::solver" nogil:
+
+    cdef cppclass lanczos_solver_config[ValueTypeT]:
+        int n_components
+        int max_iterations
+        int ncv
+        ValueTypeT tolerance
+        uint64_t seed
+
+cdef lanczos_solver_config[float] config_float
+cdef lanczos_solver_config[double] config_double
+
+cdef extern from "raft_runtime/solver/lanczos.hpp" \
+        namespace "raft::runtime::solver" nogil:
+
+    cdef void lanczos_solver(
+        const device_resources &handle,
+        lanczos_solver_config[double] config,
+        device_vector_view[int64_t, uint32_t] rows,
+        device_vector_view[int64_t, uint32_t] cols,
+        device_vector_view[double, uint32_t] vals,
+        optional[device_vector_view[double, uint32_t]] v0,
+        device_vector_view[double, uint32_t] eigenvalues,
+        device_matrix_view[double, uint32_t, col_major] eigenvectors) except +
+
+    cdef void lanczos_solver(
+        const device_resources &handle,
+        lanczos_solver_config[float] config,
+        device_vector_view[int64_t, uint32_t] rows,
+        device_vector_view[int64_t, uint32_t] cols,
+        device_vector_view[float, uint32_t] vals,
+        optional[device_vector_view[float, uint32_t]] v0,
+        device_vector_view[float, uint32_t] eigenvalues,
+        device_matrix_view[float, uint32_t, col_major] eigenvectors) except +
+
+    cdef void lanczos_solver(
+        const device_resources &handle,
+        lanczos_solver_config[double] config,
+        device_vector_view[int, uint32_t] rows,
+        device_vector_view[int, uint32_t] cols,
+        device_vector_view[double, uint32_t] vals,
+        optional[device_vector_view[double, uint32_t]] v0,
+        device_vector_view[double, uint32_t] eigenvalues,
+        device_matrix_view[double, uint32_t, col_major] eigenvectors) except +
+
+    cdef void lanczos_solver(
+        const device_resources &handle,
+        lanczos_solver_config[float] config,
+        device_vector_view[int, uint32_t] rows,
+        device_vector_view[int, uint32_t] cols,
+        device_vector_view[float, uint32_t] vals,
+        optional[device_vector_view[float, uint32_t]] v0,
+        device_vector_view[float, uint32_t] eigenvalues,
+        device_matrix_view[float, uint32_t, col_major] eigenvectors) except +
+
+
+@auto_sync_handle
+def eigsh(A, k=6, v0=None, ncv=None, maxiter=None,
+          tol=0, seed=None, handle=None):
+    """
+    Find ``k`` eigenvalues and eigenvectors of the real symmetric square
+    matrix or complex Hermitian matrix ``A``.
+
+    Solves ``Ax = wx``, the standard eigenvalue problem for ``w`` eigenvalues
+    with corresponding eigenvectors ``x``.
+
+    Args:
+        a (spmatrix): A symmetric square sparse CSR matrix with
+            dimension ``(n, n)``. ``a`` must be of type
+            :class:`cupyx.scipy.sparse._csr.csr_matrix`
+        k (int): The number of eigenvalues and eigenvectors to compute. Must be
+            ``1 <= k < n``.
+        v0 (ndarray): Starting vector for iteration. If ``None``, a random
+            unit vector is used.
+        ncv (int): The number of Lanczos vectors generated. Must be
+            ``k + 1 < ncv < n``. If ``None``, default value is used.
+        maxiter (int): Maximum number of Lanczos update iterations.
+            If ``None``, default value is used.
+        tol (float): Tolerance for residuals ``||Ax - wx||``. If ``0``, machine
+            precision is used.
+
+    Returns:
+        tuple:
+            It returns ``w`` and ``x``
+            where ``w`` is eigenvalues and ``x`` is eigenvectors.
+
+    .. seealso::
+        :func:`scipy.sparse.linalg.eigsh`
+        :func:`cupyx.scipy.sparse.linalg.eigsh`
+
+    .. note::
+        This function uses the thick-restart Lanczos methods
+        (https://sdm.lbl.gov/~kewu/ps/trlan.html).
+
+    """
+
+    if A is None:
+        raise Exception("'A' cannot be None!")
+
+    rows = A.indptr
+    cols = A.indices
+    vals = A.data
+
+    rows = cai_wrapper(rows)
+    cols = cai_wrapper(cols)
+    vals = cai_wrapper(vals)
+
+    IndexType = rows.dtype
+    ValueType = vals.dtype
+
+    N = A.shape[0]
+    n = N
+    nnz = A.nnz
+
+    rows_ptr = <uintptr_t>rows.data
+    cols_ptr = <uintptr_t>cols.data
+    vals_ptr = <uintptr_t>vals.data
+    cdef optional[device_vector_view[double, uint32_t]] d_v0
+    cdef optional[device_vector_view[float, uint32_t]] f_v0
+
+    if ncv is None:
+        ncv = min(n, max(2*k + 1, 20))
+    else:
+        ncv = min(max(ncv, k + 2), n - 1)
+
+    seed = seed if seed is not None else 42
+    if maxiter is None:
+        maxiter = 10 * n
+    if tol == 0:
+        tol = np.finfo(ValueType).eps
+
+    eigenvectors = device_ndarray.empty((N, k), dtype=ValueType, order='F')
+    eigenvalues = device_ndarray.empty((k), dtype=ValueType, order='F')
+
+    eigenvectors_cai = cai_wrapper(eigenvectors)
+    eigenvalues_cai = cai_wrapper(eigenvalues)
+
+    eigenvectors_ptr = <uintptr_t>eigenvectors_cai.data
+    eigenvalues_ptr = <uintptr_t>eigenvalues_cai.data
+
+    handle = handle if handle is not None else Handle()
+    cdef device_resources *h = <device_resources*><size_t>handle.getHandle()
+
+    if IndexType == np.int32 and ValueType == np.float32:
+        config_float.n_components = k
+        config_float.max_iterations = maxiter
+        config_float.ncv = ncv
+        config_float.tolerance = tol
+        config_float.seed = seed
+        if v0 is not None:
+            v0 = cai_wrapper(v0)
+            v0_ptr = <uintptr_t>v0.data
+            f_v0 = make_device_vector_view(<float *>v0_ptr, <uint32_t> N)
+        lanczos_solver(
+            deref(h),
+            <lanczos_solver_config[float]> config_float,
+            make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
+            make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
+            make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
+            f_v0,
+            make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
+            make_device_matrix_view[float, uint32_t, col_major](
+                <float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
+        )
+    elif IndexType == np.int64 and ValueType == np.float32:
+        config_float.n_components = k
+        config_float.max_iterations = maxiter
+        config_float.ncv = ncv
+        config_float.tolerance = tol
+        config_float.seed = seed
+        if v0 is not None:
+            v0 = cai_wrapper(v0)
+            v0_ptr = <uintptr_t>v0.data
+            f_v0 = make_device_vector_view(<float *>v0_ptr, <uint32_t> N)
+        lanczos_solver(
+            deref(h),
+            <lanczos_solver_config[float]> config_float,
+            make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
+            make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
+            make_device_vector_view(<float *>vals_ptr, <uint32_t> nnz),
+            f_v0,
+            make_device_vector_view(<float *>eigenvalues_ptr, <uint32_t> k),
+            make_device_matrix_view[float, uint32_t, col_major](
+                <float *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
+        )
+    elif IndexType == np.int32 and ValueType == np.float64:
+        config_double.n_components = k
+        config_double.max_iterations = maxiter
+        config_double.ncv = ncv
+        config_double.tolerance = tol
+        config_double.seed = seed
+        if v0 is not None:
+            v0 = cai_wrapper(v0)
+            v0_ptr = <uintptr_t>v0.data
+            d_v0 = make_device_vector_view(<double *>v0_ptr, <uint32_t> N)
+        lanczos_solver(
+            deref(h),
+            <lanczos_solver_config[double]> config_double,
+            make_device_vector_view(<int *>rows_ptr, <uint32_t> (N + 1)),
+            make_device_vector_view(<int *>cols_ptr, <uint32_t> nnz),
+            make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
+            d_v0,
+            make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
+            make_device_matrix_view[double, uint32_t, col_major](
+                <double *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
+        )
+    elif IndexType == np.int64 and ValueType == np.float64:
+        config_double.n_components = k
+        config_double.max_iterations = maxiter
+        config_double.ncv = ncv
+        config_double.tolerance = tol
+        config_double.seed = seed
+        if v0 is not None:
+            v0 = cai_wrapper(v0)
+            v0_ptr = <uintptr_t>v0.data
+            d_v0 = make_device_vector_view(<double *>v0_ptr, <uint32_t> N)
+        lanczos_solver(
+            deref(h),
+            <lanczos_solver_config[double]> config_double,
+            make_device_vector_view(<int64_t *>rows_ptr, <uint32_t> (N + 1)),
+            make_device_vector_view(<int64_t *>cols_ptr, <uint32_t> nnz),
+            make_device_vector_view(<double *>vals_ptr, <uint32_t> nnz),
+            d_v0,
+            make_device_vector_view(<double *>eigenvalues_ptr, <uint32_t> k),
+            make_device_matrix_view[double, uint32_t, col_major](
+                <double *>eigenvectors_ptr, <uint32_t> N, <uint32_t> k),
+        )
+    else:
+        raise ValueError("dtype IndexType=%s and ValueType=%s not supported" %
+                         (IndexType, ValueType))
+
+    return (cp.asarray(eigenvalues), cp.asarray(eigenvectors))
diff --git a/python/pylibraft/pylibraft/test/test_sparse.py b/python/pylibraft/pylibraft/test/test_sparse.py
new file mode 100644
index 0000000000..10b261d322
--- /dev/null
+++ b/python/pylibraft/pylibraft/test/test_sparse.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cupy
+import cupyx.scipy.sparse.linalg  # NOQA
+import numpy
+import pytest
+from cupyx.scipy import sparse
+
+from pylibraft.sparse.linalg import eigsh
+
+
+def shaped_random(
+    shape, xp=cupy, dtype=numpy.float32, scale=10, seed=0, order="C"
+):
+    """
+    Returns an array filled with random values.
+
+    Args
+    ----
+        shape(tuple): Shape of returned ndarray.
+        xp(numpy or cupy): Array module to use.
+        dtype(dtype): Dtype of returned ndarray.
+        scale(float): Scaling factor of elements.
+        seed(int): Random seed.
+
+    Returns
+    -------
+        numpy.ndarray or cupy.ndarray: The array with
+        given shape, array module,
+
+    If ``dtype`` is ``numpy.bool_``, the elements are
+    independently drawn from ``True`` and ``False``
+    with same probabilities.
+    Otherwise, the array is filled with samples
+    independently and identically drawn
+    from uniform distribution over :math:`[0, scale)`
+    with specified dtype.
+    """
+    numpy.random.seed(seed)
+    dtype = numpy.dtype(dtype)
+    if dtype == "?":
+        a = numpy.random.randint(2, size=shape)
+    elif dtype.kind == "c":
+        a = numpy.random.rand(*shape) + 1j * numpy.random.rand(*shape)
+        a *= scale
+    else:
+        a = numpy.random.rand(*shape) * scale
+    return xp.asarray(a, dtype=dtype, order=order)
+
+
+class TestEigsh:
+    n = 30
+    density = 0.33
+    tol = {numpy.float32: 1e-5, numpy.complex64: 1e-5, "default": 1e-12}
+    res_tol = {"f": 1e-5, "d": 1e-12}
+    return_eigenvectors = True
+
+    def _make_matrix(self, dtype, xp):
+        shape = (self.n, self.n)
+        a = shaped_random(shape, xp, dtype=dtype)
+        mask = shaped_random(shape, xp, dtype="f", scale=1)
+        a[mask > self.density] = 0
+        a = a * a.conj().T
+        return a
+
+    def _test_eigsh(self, a, k, xp, sp):
+        expected_ret = sp.linalg.eigsh(
+            a, k=k, return_eigenvectors=self.return_eigenvectors
+        )
+        actual_ret = eigsh(a, k=k)
+        if self.return_eigenvectors:
+            w, x = actual_ret
+            exp_w, _ = expected_ret
+            # Check the residuals to see if eigenvectors are correct.
+            ax_xw = a @ x - xp.multiply(x, w.reshape(1, k))
+            res = xp.linalg.norm(ax_xw) / xp.linalg.norm(w)
+            tol = self.res_tol[numpy.dtype(a.dtype).char.lower()]
+            assert res < tol
+        else:
+            w = actual_ret
+            exp_w = expected_ret
+        w = xp.sort(w)
+        cupy.allclose(w, exp_w, rtol=tol, atol=tol)
+
+    @pytest.mark.parametrize("format", ["csr"])  # , 'csc', 'coo'])
+    @pytest.mark.parametrize("k", [3, 6, 12])
+    @pytest.mark.parametrize("dtype", ["f", "d"])
+    def test_sparse(self, format, k, dtype, xp=cupy, sp=sparse):
+        if format == "csc":
+            pytest.xfail("may be buggy")  # trans=True
+
+        a = self._make_matrix(dtype, xp)
+        a = sp.coo_matrix(a).asformat(format)
+        return self._test_eigsh(a, k, xp, sp)
+
+    def test_invalid(self):
+        xp, sp = cupy, sparse
+        a = xp.diag(xp.ones((self.n,), dtype="f"))
+        with pytest.raises(ValueError):
+            sp.linalg.eigsh(xp.ones((2, 1), dtype="f"))
+        with pytest.raises(ValueError):
+            sp.linalg.eigsh(a, k=0)
+        a = xp.diag(xp.ones((self.n,), dtype="f"))
+        with pytest.raises(ValueError):
+            sp.linalg.eigsh(xp.ones((1,), dtype="f"))
+        with pytest.raises(TypeError):
+            sp.linalg.eigsh(xp.ones((2, 2), dtype="i"))
+        with pytest.raises(ValueError):
+            sp.linalg.eigsh(a, k=self.n)
+
+    def test_starting_vector(self):
+        # Make symmetric matrix
+        aux = self._make_matrix("f", cupy)
+        aux = sparse.coo_matrix(aux).asformat("csr")
+        matrix = (aux + aux.T) / 2.0
+
+        # Find reference eigenvector
+        ew, ev = eigsh(matrix, k=1)
+        v = ev[:, 0]
+
+        # Obtain non-converged eigenvector from random initial guess.
+        ew_aux, ev_aux = eigsh(matrix, k=1, ncv=1, maxiter=0)
+        v_aux = cupy.copysign(ev_aux[:, 0], v)
+
+        # Obtain eigenvector using known eigenvector as initial guess.
+        ew_v0, ev_v0 = eigsh(matrix, k=1, v0=v.copy(), ncv=1, maxiter=0)
+        v_v0 = cupy.copysign(ev_v0[:, 0], v)
+
+        assert cupy.linalg.norm(v - v_v0) < cupy.linalg.norm(v - v_aux)