Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add merge-based SpMV #1911

Merged
merged 10 commits into from
Oct 16, 2023
File renamed without changes.
22 changes: 6 additions & 16 deletions common/src/KokkosKernels_LowerBound.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ namespace Impl {

/*! \brief Single-thread sequential lower-bound search

\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::Impl::MergeMatrixDiagonal
\tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -96,9 +96,6 @@ lower_bound_sequential_thread(
using size_type = typename ViewLike::size_type;
static_assert(1 == ViewLike::rank,
"lower_bound_sequential_thread requires rank-1 views");
static_assert(is_iota_v<ViewLike> || Kokkos::is_view<ViewLike>::value,
"lower_bound_sequential_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");

size_type i = 0;
while (i < view.size() && pred(view(i), value)) {
Expand All @@ -109,8 +106,8 @@ lower_bound_sequential_thread(

/*! \brief Single-thread binary lower-bound search

\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::Impl::MergeMatrixDiagonal
\tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -127,9 +124,6 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_binary_thread(
using size_type = typename ViewLike::size_type;
static_assert(1 == ViewLike::rank,
"lower_bound_binary_thread requires rank-1 views");
static_assert(is_iota_v<ViewLike> || Kokkos::is_view<ViewLike>::value,
"lower_bound_binary_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");

size_type lo = 0;
size_type hi = view.size();
Expand All @@ -150,8 +144,8 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_binary_thread(

/*! \brief single-thread lower-bound search

\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::Impl::MergeMatrixDiagonal
\tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -168,10 +162,6 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_thread(
Pred pred = Pred()) {
static_assert(1 == ViewLike::rank,
"lower_bound_thread requires rank-1 views");
static_assert(KokkosKernels::Impl::is_iota_v<ViewLike> ||
Kokkos::is_view<ViewLike>::value,
"lower_bound_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");
/*
sequential search makes on average 0.5 * view.size memory accesses
binary search makes log2(view.size)+1 accesses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,23 @@
#include <type_traits>

#include "KokkosKernels_Iota.hpp"
#include "KokkosKernels_LowerBound.hpp"
#include "KokkosKernels_Predicates.hpp"
#include "KokkosKernels_SafeCompare.hpp"

/// \file KokkosSparse_MergeMatrix.hpp
/// \file KokkosSparse_merge_matrix.hpp

namespace KokkosSparse {
namespace Experimental {
namespace Impl {
namespace KokkosSparse::Impl {

/*! \struct MergeMatrixPosition
    \brief A pair of indices: one into a, one into b

    \tparam AIndex type of the index into a
    \tparam BIndex type of the index into b
*/
template <class AIndex, class BIndex>
struct MergeMatrixPosition {
  AIndex ai;  ///< index into a
  BIndex bi;  ///< index into b

  using a_index_type = AIndex;  ///< type of \c ai
  using b_index_type = BIndex;  ///< type of \c bi
};

/*! \class MergeMatrixDiagonal
\brief a view into the entries of the Merge Matrix along a diagonal
Expand Down Expand Up @@ -88,14 +98,7 @@ class MergeMatrixDiagonal {
using a_value_type = typename AView::non_const_value_type;
using b_value_type = typename BViewLike::non_const_value_type;

/*! \struct MatrixPosition
* \brief indices into the a_ and b_ views.
*/
struct MatrixPosition {
a_index_type ai;
b_index_type bi;
};
using position_type = MatrixPosition;
using position_type = MergeMatrixPosition<a_index_type, b_index_type>;

// implement bare minimum parts of the view interface
enum { rank = 1 };
Expand Down Expand Up @@ -145,9 +148,10 @@ class MergeMatrixDiagonal {
KOKKOS_INLINE_FUNCTION
bool operator()(const size_type di) const {
position_type pos = diag_to_a_b(di);
if (pos.ai >= a_.size()) {

if (pos.ai >= typename position_type::a_index_type(a_.size())) {
return true; // on the +a side out of matrix bounds is 1
} else if (pos.bi >= b_.size()) {
} else if (pos.bi >= typename position_type::b_index_type(b_.size())) {
return false; // on the +b side out of matrix bounds is 0
} else {
return KokkosKernels::Impl::safe_gt(a_(pos.ai), b_(pos.bi));
Expand All @@ -161,9 +165,9 @@ class MergeMatrixDiagonal {
*/
KOKKOS_INLINE_FUNCTION
size_type size() const noexcept {
if (d_ <= a_.size() && d_ <= b_.size()) {
if (d_ <= size_type(a_.size()) && d_ <= size_type(b_.size())) {
return d_;
} else if (d_ > a_.size() && d_ > b_.size()) {
} else if (d_ > size_type(a_.size()) && d_ > size_type(b_.size())) {
// TODO: this returns nonsense if d_ happens to be outside the merge
// matrix
return a_.size() + b_.size() - d_;
Expand All @@ -182,8 +186,8 @@ class MergeMatrixDiagonal {
KOKKOS_INLINE_FUNCTION
position_type diag_to_a_b(const size_type &di) const noexcept {
position_type res;
res.ai = d_ < a_.size() ? (d_ - 1) - di : a_.size() - 1 - di;
res.bi = d_ < a_.size() ? di : d_ + di - a_.size();
res.ai = d_ < size_type(a_.size()) ? (d_ - 1) - di : a_.size() - 1 - di;
res.bi = d_ < size_type(a_.size()) ? di : d_ + di - a_.size();
return res;
}

Expand All @@ -192,8 +196,106 @@ class MergeMatrixDiagonal {
size_type d_; ///< diagonal
};

} // namespace Impl
} // namespace Experimental
} // namespace KokkosSparse
/*! \brief Locate the first entry that is not 1 on diagonal \c diag of the
    merge matrix of \c a and \c b

    Implemented as a lower-bound search (KokkosKernels::lower_bound_thread)
    over a MergeMatrixDiagonal, using "equals true" as the predicate.

    \tparam AView type of \c a; must provide \c data() and \c size()
    \tparam BViewLike type of \c b; either a KokkosKernels::Impl::Iota or a
    type that provides \c data() and \c size()
    \param a the a entries of the merge matrix
    \param b the b entries of the merge matrix
    \param diag the diagonal to search along
    \return the position reported by MergeMatrixDiagonal::position for the
    index the search found
*/
template <typename AView, typename BViewLike>
KOKKOS_INLINE_FUNCTION
    typename MergeMatrixDiagonal<AView, BViewLike>::position_type
    diagonal_search(
        const AView &a, const BViewLike &b,
        typename MergeMatrixDiagonal<AView, BViewLike>::size_type diag) {
  constexpr bool b_is_iota = KokkosKernels::Impl::is_iota<BViewLike>::value;

  // unmanaged counterparts of the a and b view types
  using a_unmanaged_type =
      Kokkos::View<typename AView::value_type *, typename AView::device_type,
                   Kokkos::MemoryUnmanaged>;
  using b_unmanaged_type =
      Kokkos::View<typename BViewLike::value_type *,
                   typename BViewLike::device_type, Kokkos::MemoryUnmanaged>;

  // an Iota is handed to the diagonal as-is; any other b is re-wrapped as an
  // unmanaged view
  using b_storage_type =
      std::conditional_t<b_is_iota, BViewLike, b_unmanaged_type>;
  using diagonal_type = MergeMatrixDiagonal<a_unmanaged_type, b_storage_type>;

  a_unmanaged_type a_unmanaged(a.data(), a.size());

  diagonal_type diagonal;
  if constexpr (b_is_iota) {
    diagonal = diagonal_type(a_unmanaged, b, diag);
  } else {
    b_storage_type b_unmanaged(b.data(), b.size());
    diagonal = diagonal_type(a_unmanaged, b_unmanaged, diag);
  }

  // lower_bound returns the index of the first element for which
  // pred(element, value) does not hold. The "view" searched here is the
  // diagonal of the merge matrix, the value is true, and the predicate is
  // equality, so the result is the first diagonal entry that is not true.
  const typename diagonal_type::size_type first_not_one =
      KokkosKernels::lower_bound_thread(diagonal, true,
                                        KokkosKernels::Equal<bool>());
  return diagonal.position(first_not_one);
}

/*! \brief Locate the first entry that is not 1 on diagonal \c diag of the
    merge matrix of \c a and \c b, searching through a team handle

    Same search as the overload without a handle, except that the lower-bound
    search is done by KokkosKernels::lower_bound_team, which receives
    \c handle.

    \tparam TeamMember type of \c handle
    \tparam AView type of \c a; must provide \c data() and \c size()
    \tparam BViewLike type of \c b; either a KokkosKernels::Impl::Iota or a
    type that provides \c data() and \c size()
    \param handle forwarded to KokkosKernels::lower_bound_team
    \param a the a entries of the merge matrix
    \param b the b entries of the merge matrix
    \param diag the diagonal to search along
    \return the position reported by MergeMatrixDiagonal::position for the
    index the search found
*/
template <typename TeamMember, typename AView, typename BViewLike>
KOKKOS_INLINE_FUNCTION
    typename MergeMatrixDiagonal<AView, BViewLike>::position_type
    diagonal_search(
        const TeamMember &handle, const AView &a, const BViewLike &b,
        typename MergeMatrixDiagonal<AView, BViewLike>::size_type diag) {
  constexpr bool b_is_iota = KokkosKernels::Impl::is_iota<BViewLike>::value;

  // unmanaged counterparts of the a and b view types
  using a_unmanaged_type =
      Kokkos::View<typename AView::value_type *, typename AView::device_type,
                   Kokkos::MemoryUnmanaged>;
  using b_unmanaged_type =
      Kokkos::View<typename BViewLike::value_type *,
                   typename BViewLike::device_type, Kokkos::MemoryUnmanaged>;

  // an Iota is handed to the diagonal as-is; any other b is re-wrapped as an
  // unmanaged view
  using b_storage_type =
      std::conditional_t<b_is_iota, BViewLike, b_unmanaged_type>;
  using diagonal_type = MergeMatrixDiagonal<a_unmanaged_type, b_storage_type>;

  a_unmanaged_type a_unmanaged(a.data(), a.size());

  diagonal_type diagonal;
  if constexpr (b_is_iota) {
    diagonal = diagonal_type(a_unmanaged, b, diag);
  } else {
    b_storage_type b_unmanaged(b.data(), b.size());
    diagonal = diagonal_type(a_unmanaged, b_unmanaged, diag);
  }

  // lower_bound returns the index of the first element for which
  // pred(element, value) does not hold. The "view" searched here is the
  // diagonal of the merge matrix, the value is true, and the predicate is
  // equality, so the result is the first diagonal entry that is not true.
  const typename diagonal_type::size_type first_not_one =
      KokkosKernels::lower_bound_team(handle, diagonal, true,
                                      KokkosKernels::Equal<bool>());
  return diagonal.position(first_not_one);
}

/*! \brief Diagonal search where the b side of the merge matrix is a
    KokkosKernels::Impl::Iota built from \c totalWork

    Constructs an Iota from \c totalWork and forwards to the two-view
    diagonal_search overload.
    NOTE(review): the Iota presumably stands for the sequence
    0, 1, ..., totalWork - 1; confirm against KokkosKernels_Iota.hpp.

    \tparam View type of \c a
    \param a the a entries of the merge matrix
    \param totalWork argument the Iota is constructed with
    \param diag the diagonal to search along
    \return A MergeMatrixDiagonal::position_type
*/
template <typename View>
KOKKOS_INLINE_FUNCTION auto diagonal_search(
    const View &a, typename View::non_const_value_type totalWork,
    typename View::size_type diag) {
  using iota_type =
      KokkosKernels::Impl::Iota<typename View::non_const_value_type,
                                typename View::size_type>;

  iota_type b_as_iota(totalWork);
  return diagonal_search(a, b_as_iota, diag);
}

} // namespace KokkosSparse::Impl

#endif // KOKKOSSPARSE_MERGEMATRIX_HPP
35 changes: 27 additions & 8 deletions sparse/impl/KokkosSparse_spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
#ifndef KOKKOSSPARSE_IMPL_SPMV_DEF_HPP_
#define KOKKOSSPARSE_IMPL_SPMV_DEF_HPP_

#include <sstream>

#include "KokkosKernels_Controls.hpp"
#include "Kokkos_InnerProductSpaceTraits.hpp"
#include "KokkosBlas1_scal.hpp"
#include "KokkosKernels_ExecSpaceUtils.hpp"
#include "KokkosSparse_CrsMatrix.hpp"
#include "KokkosSparse_spmv_impl_omp.hpp"
#include "KokkosSparse_spmv_impl_merge.hpp"
#include "KokkosKernels_Error.hpp"

namespace KokkosSparse {
namespace Impl {

// Value of the "algorithm" controls parameter that selects the merge-based
// SpMV implementation (SpmvMergeHierarchical); spmv_beta compares
// controls.getParameter("algorithm") against this string.
constexpr const char* KOKKOSSPARSE_ALG_MERGE = "merge";
cwpearson marked this conversation as resolved.
Show resolved Hide resolved

// This TransposeFunctor is functional, but not necessarily performant.
template <class execution_space, class AMatrix, class XVector, class YVector,
bool conjugate>
Expand Down Expand Up @@ -629,20 +634,32 @@ static void spmv_beta(const execution_space& exec,
typename YVector::const_value_type& beta,
const YVector& y) {
if (mode[0] == NoTranspose[0]) {
spmv_beta_no_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
false>(exec, controls, alpha, A, x, beta, y);
if (controls.getParameter("algorithm") == KOKKOSSPARSE_ALG_MERGE) {
SpmvMergeHierarchical<execution_space, AMatrix, XVector, YVector>::spmv(
exec, mode, alpha, A, x, beta, y);
} else {
spmv_beta_no_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
false>(exec, controls, alpha, A, x, beta, y);
}
} else if (mode[0] == Conjugate[0]) {
spmv_beta_no_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
true>(exec, controls, alpha, A, x, beta, y);
if (controls.getParameter("algorithm") == KOKKOSSPARSE_ALG_MERGE) {
SpmvMergeHierarchical<execution_space, AMatrix, XVector, YVector>::spmv(
exec, mode, alpha, A, x, beta, y);
} else {
spmv_beta_no_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
true>(exec, controls, alpha, A, x, beta, y);
}
} else if (mode[0] == Transpose[0]) {
spmv_beta_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
false>(exec, alpha, A, x, beta, y);
} else if (mode[0] == ConjugateTranspose[0]) {
spmv_beta_transpose<execution_space, AMatrix, XVector, YVector, dobeta,
true>(exec, alpha, A, x, beta, y);
} else {
KokkosKernels::Impl::throw_runtime_exception(
"Invalid Transpose Mode for KokkosSparse::spmv()");
std::stringstream ss;
ss << __FILE__ << ":" << __LINE__ << " Invalid transpose mode " << mode
<< " for KokkosSparse::spmv()";
KokkosKernels::Impl::throw_runtime_exception(ss.str());
}
}

Expand Down Expand Up @@ -1460,8 +1477,10 @@ static void spmv_alpha_beta_mv(
doalpha, dobeta, true>(exec, alpha, A, x, beta,
y);
} else {
KokkosKernels::Impl::throw_runtime_exception(
"Invalid Transpose Mode for KokkosSparse::spmv()");
std::stringstream ss;
ss << __FILE__ << ":" << __LINE__ << " Invalid transpose mode " << mode
<< " for KokkosSparse::spmv()";
KokkosKernels::Impl::throw_runtime_exception(ss.str());
}
}

Expand Down
Loading