Skip to content

Commit

Permalink
Merge pull request #1949 from jczhang07/2023-08-18/feature-tpl-dot
Browse files Browse the repository at this point in the history
Add TPL support for KokkosBlas::dot
  • Loading branch information
lucbv authored Dec 6, 2023
2 parents 4bfde66 + d2e7524 commit f4fd2e5
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 313 deletions.
8 changes: 4 additions & 4 deletions CheckHostBlasReturnComplex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ FUNCTION(CHECK_HOST_BLAS_RETURN_COMPLEX VARNAME)
extern \"C\" {
void F77_BLAS_MANGLE(zdotc,ZDOTC)(
std::complex<double>* result, const int* n,
const std::complex<double> x[], const int* incx,
std::complex<double>* result, const int* n,
const std::complex<double> x[], const int* incx,
const std::complex<double> y[], const int* incy);
}
Expand All @@ -49,9 +49,9 @@ int main() {
CHECK_CXX_SOURCE_RUNS("${SOURCE}" KK_BLAS_RESULT_AS_POINTER_ARG)

IF(${KK_BLAS_RESULT_AS_POINTER_ARG})
SET(VARNAME OFF)
SET(${VARNAME} OFF PARENT_SCOPE)
ELSE()
SET(VARNAME ON)
SET(${VARNAME} ON PARENT_SCOPE)
ENDIF()

ENDFUNCTION()
44 changes: 28 additions & 16 deletions blas/src/KokkosBlas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,37 @@ dot(const execution_space& space, const XVector& x, const YVector& y) {
Kokkos::View<result_type, default_layout, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

result_type result{};
RVector_Result R = RVector_Result(&result);
XVector_Internal X = x;
YVector_Internal Y = y;

// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type is
// 32-bit precision). Impl::Dot needs to support both cases, and it's easier
// to do this with overloading than by extending the ETI to deal with two
// different scalar types.
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space, R,
X, Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
bool useFallback = false;
if (useFallback) {
// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type
// is 32-bit precision). Impl::Dot needs to support both cases, and it's
// easier to do this with overloading than by extending the ETI to deal with
// two different scalar types.
result_type result{};
RVector_Result R = RVector_Result(&result);
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space,
R, X,
Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
} else {
dot_type result{};
RVector_Internal R = RVector_Internal(&result);
Impl::Dot<execution_space, RVector_Internal, XVector_Internal,
YVector_Internal>::dot(space, R, X, Y);
space.fence();
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
}
}

/// \brief Return the dot product of the two vectors x and y.
Expand Down
42 changes: 27 additions & 15 deletions blas/tpls/KokkosBlas1_dot_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,22 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(double, Kokkos::LayoutLeft,
Kokkos::HostSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(float, Kokkos::LayoutLeft,
Kokkos::HostSpace)

// TODO: we ran into difficulties getting FindTPLMKL.cmake to set the BLAS
// library properly, so the test in CheckHostBlasReturnComplex.cmake could not
// be compiled and run to give a correct answer for
// KK_BLAS_RESULT_AS_POINTER_ARG. This resulted in a segfault in dot() with MKL
// and complex scalars, so the complex specializations are temporarily disabled
// for MKL until FindTPLMKL.cmake is fixed.
#if !defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<double>, Kokkos::LayoutLeft,
Kokkos::HostSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
Kokkos::HostSpace)
#endif

#endif

// cuBLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
// double
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(SCALAR, LAYOUT, EXECSPACE, \
MEMSPACE) \
#define KOKKOSBLAS1_DOT_TPL_SPEC(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct dot_tpl_spec_avail< \
EXECSPACE, \
Expand All @@ -77,19 +81,27 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
enum : bool { value = true }; \
};

KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(double, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(float, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(float, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(double, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<float>, LAYOUT, EXECSPACE, \
MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<double>, LAYOUT, EXECSPACE, MEMSPACE)

#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::HIP,
Kokkos::HIPSpace)
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::SYCL,
Kokkos::Experimental::SYCLDeviceUSMSpace)
#endif
} // namespace Impl
} // namespace KokkosBlas
#endif
Loading

0 comments on commit f4fd2e5

Please sign in to comment.