Skip to content

Commit

Permalink
Merge pull request #1987 from e10harvey/issue1984
Browse files Browse the repository at this point in the history
Test and fix gemv stream interface
  • Loading branch information
lucbv authored Oct 11, 2023
2 parents 0aac17f + 2c103ed commit c173644
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 154 deletions.
19 changes: 11 additions & 8 deletions blas/impl/KokkosBlas2_gemv_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
namespace KokkosBlas {
namespace Impl {
// Trait struct reporting, through its `value` member, whether an ETI
// specialization of gemv exists for the given template arguments.
// This primary template sets `value` to false.
template <class XMV, class YMV, class ZMV>
template <class ExecutionSpace, class XMV, class YMV, class ZMV>
struct gemv_eti_spec_avail {
enum : bool { value = false };
};
Expand All @@ -44,6 +44,7 @@ struct gemv_eti_spec_avail {
#define KOKKOSBLAS2_GEMV_ETI_SPEC_AVAIL(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
template <> \
struct gemv_eti_spec_avail< \
EXEC_SPACE, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Expand All @@ -67,14 +68,14 @@ namespace Impl {
//

// Implementation of KokkosBlas::gemv.
template <class AViewType, class XViewType, class YViewType,
bool tpl_spec_avail =
gemv_tpl_spec_avail<AViewType, XViewType, YViewType>::value,
bool eti_spec_avail =
gemv_eti_spec_avail<AViewType, XViewType, YViewType>::value>
template <
class ExecutionSpace, class AViewType, class XViewType, class YViewType,
bool tpl_spec_avail = gemv_tpl_spec_avail<ExecutionSpace, AViewType,
XViewType, YViewType>::value,
bool eti_spec_avail = gemv_eti_spec_avail<ExecutionSpace, AViewType,
XViewType, YViewType>::value>
struct GEMV {
static void gemv(const typename AViewType::execution_space& space,
const char trans[],
static void gemv(const ExecutionSpace& space, const char trans[],
typename AViewType::const_value_type& alpha,
const AViewType& A, const XViewType& x,
typename YViewType::const_value_type& beta,
Expand Down Expand Up @@ -130,6 +131,7 @@ struct GEMV {

#define KOKKOSBLAS2_GEMV_ETI_SPEC_DECL(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
extern template struct GEMV< \
EXEC_SPACE, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Expand All @@ -142,6 +144,7 @@ struct GEMV {

#define KOKKOSBLAS2_GEMV_ETI_SPEC_INST(SCALAR, LAYOUT, EXEC_SPACE, MEM_SPACE) \
template struct GEMV< \
EXEC_SPACE, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<EXEC_SPACE, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Expand Down
28 changes: 15 additions & 13 deletions blas/src/KokkosBlas2_gemv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ namespace KokkosBlas {
/// \param x [in] Input vector, as a 1-D Kokkos::View
/// \param beta [in] Input coefficient of y
/// \param y [in/out] Output vector, as a nonconst 1-D Kokkos::View
template <class execution_space, class AViewType, class XViewType,
template <class ExecutionSpace, class AViewType, class XViewType,
class YViewType>
void gemv(const execution_space& space, const char trans[],
void gemv(const ExecutionSpace& space, const char trans[],
typename AViewType::const_value_type& alpha, const AViewType& A,
const XViewType& x, typename YViewType::const_value_type& beta,
const YViewType& y) {
static_assert(Kokkos::is_execution_space_v<execution_space>,
"KokkosBlas::gemv: execution_space must be a valid Kokkos "
static_assert(Kokkos::is_execution_space_v<ExecutionSpace>,
"KokkosBlas::gemv: ExecutionSpace must be a valid Kokkos "
"execution space.");
static_assert(Kokkos::is_view<AViewType>::value,
"KokkosBlas::gemv: AViewType must be a Kokkos::View.");
Expand All @@ -71,17 +71,17 @@ void gemv(const execution_space& space, const char trans[],
static_assert(static_cast<int>(YViewType::rank) == 1,
"KokkosBlas::gemv: YViewType must have rank 1.");
static_assert(
Kokkos::SpaceAccessibility<execution_space,
Kokkos::SpaceAccessibility<ExecutionSpace,
typename AViewType::memory_space>::accessible,
"KokkosBlas::gemv: AViewType must be accessible from execution_space");
"KokkosBlas::gemv: AViewType must be accessible from ExecutionSpace");
static_assert(
Kokkos::SpaceAccessibility<execution_space,
Kokkos::SpaceAccessibility<ExecutionSpace,
typename XViewType::memory_space>::accessible,
"KokkosBlas::gemv: XViewType must be accessible from execution_space");
"KokkosBlas::gemv: XViewType must be accessible from ExecutionSpace");
static_assert(
Kokkos::SpaceAccessibility<execution_space,
Kokkos::SpaceAccessibility<ExecutionSpace,
typename YViewType::memory_space>::accessible,
"KokkosBlas::gemv: YViewType must be accessible from execution_space");
"KokkosBlas::gemv: YViewType must be accessible from ExecutionSpace");

// Check compatibility of dimensions at run time.
if (trans[0] == 'N' || trans[0] == 'n') {
Expand Down Expand Up @@ -171,11 +171,13 @@ void gemv(const execution_space& space, const char trans[],

if (useFallback) {
const bool eti_spec_avail =
KokkosBlas::Impl::gemv_eti_spec_avail<AVT, XVT, YVT>::value;
typedef Impl::GEMV<AVT, XVT, YVT, false, eti_spec_avail> fallback_impl_type;
KokkosBlas::Impl::gemv_eti_spec_avail<ExecutionSpace, AVT, XVT,
YVT>::value;
typedef Impl::GEMV<ExecutionSpace, AVT, XVT, YVT, false, eti_spec_avail>
fallback_impl_type;
fallback_impl_type::gemv(space, trans, alpha, A, x, beta, y);
} else {
typedef Impl::GEMV<AVT, XVT, YVT> impl_type;
typedef Impl::GEMV<ExecutionSpace, AVT, XVT, YVT> impl_type;
impl_type::gemv(space, trans, alpha, A, x, beta, y);
}
}
Expand Down
10 changes: 7 additions & 3 deletions blas/tpls/KokkosBlas2_gemv_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
namespace KokkosBlas {
namespace Impl {
// Trait struct reporting, through its `value` member, whether a TPL
// specialization of gemv exists for the given template arguments.
// This primary template sets `value` to false.
template <class AT, class XT, class YT>
template <class ExecutionSpace, class AT, class XT, class YT>
struct gemv_tpl_spec_avail {
enum : bool { value = false };
};
Expand All @@ -32,6 +32,7 @@ struct gemv_tpl_spec_avail {
LAYOUTY, MEMSPACE) \
template <class ExecSpace> \
struct gemv_tpl_spec_avail< \
ExecSpace, \
Kokkos::View<const SCALAR**, LAYOUTA, \
Kokkos::Device<ExecSpace, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Expand Down Expand Up @@ -78,6 +79,7 @@ KOKKOSBLAS2_GEMV_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>,
LAYOUTY, MEMSPACE) \
template <class ExecSpace> \
struct gemv_tpl_spec_avail< \
ExecSpace, \
Kokkos::View<const SCALAR**, LAYOUTA, \
Kokkos::Device<ExecSpace, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Expand Down Expand Up @@ -126,8 +128,9 @@ KOKKOSBLAS2_GEMV_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<float>,
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS

#define KOKKOSBLAS2_GEMV_TPL_SPEC_AVAIL_ROCBLAS(SCALAR, LAYOUT) \
template <> \
template <class ExecSpace> \
struct gemv_tpl_spec_avail< \
ExecSpace, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::HIP, \
Kokkos::Experimental::HIPSpace>, \
Expand Down Expand Up @@ -164,8 +167,9 @@ KOKKOSBLAS2_GEMV_TPL_SPEC_AVAIL_ROCBLAS(Kokkos::complex<float>,
#ifdef KOKKOS_ENABLE_SYCL

#define KOKKOSBLAS2_GEMV_TPL_SPEC_AVAIL_ONEMKL(SCALAR, LAYOUT) \
template <> \
template <class ExecSpace> \
struct gemv_tpl_spec_avail< \
ExecSpace, \
Kokkos::View<const SCALAR**, LAYOUT, \
Kokkos::Device<Kokkos::Experimental::SYCL, \
Kokkos::Experimental::SYCLDeviceUSMSpace>, \
Expand Down
Loading

0 comments on commit c173644

Please sign in to comment.