Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TPL support for KokkosBlas::dot #1949

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions CheckHostBlasReturnComplex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ FUNCTION(CHECK_HOST_BLAS_RETURN_COMPLEX VARNAME)

extern \"C\" {
void F77_BLAS_MANGLE(zdotc,ZDOTC)(
std::complex<double>* result, const int* n,
const std::complex<double> x[], const int* incx,
std::complex<double>* result, const int* n,
const std::complex<double> x[], const int* incx,
const std::complex<double> y[], const int* incy);
}

Expand All @@ -49,9 +49,9 @@ int main() {
CHECK_CXX_SOURCE_RUNS("${SOURCE}" KK_BLAS_RESULT_AS_POINTER_ARG)

IF(${KK_BLAS_RESULT_AS_POINTER_ARG})
SET(VARNAME OFF)
SET(${VARNAME} OFF PARENT_SCOPE)
ELSE()
SET(VARNAME ON)
SET(${VARNAME} ON PARENT_SCOPE)
ENDIF()

ENDFUNCTION()
44 changes: 28 additions & 16 deletions blas/src/KokkosBlas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,37 @@ dot(const execution_space& space, const XVector& x, const YVector& y) {
Kokkos::View<result_type, default_layout, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

result_type result{};
RVector_Result R = RVector_Result(&result);
XVector_Internal X = x;
YVector_Internal Y = y;

// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type is
// 32-bit precision). Impl::Dot needs to support both cases, and it's easier
// to do this with overloading than by extending the ETI to deal with two
// different scalar types.
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space, R,
X, Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
bool useFallback = false;
if (useFallback) {
// Even though RVector is the template parameter, Dot::dot has an overload
// that accepts RVector_Internal (with the special accumulator, if dot_type
// is 32-bit precision). Impl::Dot needs to support both cases, and it's
// easier to do this with overloading than by extending the ETI to deal with
// two different scalar types.
result_type result{};
RVector_Result R = RVector_Result(&result);
Impl::DotSpecialAccumulator<execution_space, RVector_Internal,
XVector_Internal, YVector_Internal>::dot(space,
R, X,
Y);
space.fence();
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
} else {
dot_type result{};
RVector_Internal R = RVector_Internal(&result);
Impl::Dot<execution_space, RVector_Internal, XVector_Internal,
YVector_Internal>::dot(space, R, X, Y);
space.fence();
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(
result);
}
}

/// \brief Return the dot product of the two vectors x and y.
Expand Down
42 changes: 27 additions & 15 deletions blas/tpls/KokkosBlas1_dot_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,22 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(double, Kokkos::LayoutLeft,
Kokkos::HostSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(float, Kokkos::LayoutLeft,
Kokkos::HostSpace)

// TODO: we met difficuties in FindTPLMKL.cmake to set the BLAS library properly
// such that the test in CheckHostBlasReturnComplex.cmake could not be
// compiled and run to give a correct answer on KK_BLAS_RESULT_AS_POINTER_ARG.
// This resulted in segfault in dot() with MKL and complex.
// So we just temporarily disable it until FindTPLMKL.cmake is fixed.
#if !defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<double>, Kokkos::LayoutLeft,
Kokkos::HostSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
Kokkos::HostSpace)
#endif

#endif

// cuBLAS
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
// double
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(SCALAR, LAYOUT, EXECSPACE, \
MEMSPACE) \
#define KOKKOSBLAS1_DOT_TPL_SPEC(SCALAR, LAYOUT, EXECSPACE, MEMSPACE) \
template <> \
struct dot_tpl_spec_avail< \
EXECSPACE, \
Expand All @@ -77,19 +81,27 @@ KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_BLAS(Kokkos::complex<float>, Kokkos::LayoutLeft,
enum : bool { value = true }; \
};

KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(double, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(float, Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL_CUBLAS(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#define KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(float, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(double, LAYOUT, EXECSPACE, MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<float>, LAYOUT, EXECSPACE, \
MEMSPACE) \
KOKKOSBLAS1_DOT_TPL_SPEC(Kokkos::complex<double>, LAYOUT, EXECSPACE, MEMSPACE)

#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Cuda,
Kokkos::CudaSpace)
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCBLAS
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::HIP,
Kokkos::HIPSpace)
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL)
KOKKOSBLAS1_DOT_TPL_SPEC_AVAIL(Kokkos::LayoutLeft, Kokkos::Experimental::SYCL,
Kokkos::Experimental::SYCLDeviceUSMSpace)
#endif
} // namespace Impl
} // namespace KokkosBlas
#endif
Loading