Skip to content

Commit

Permalink
Implement batched serial pttrs (#2277)
Browse files Browse the repository at this point in the history
* Implement batched serial pttrs

* Add tests for pttrs

* Add tag for pttrs

* fix: remove unnecessary specialization for pttrs internal

* format

* format

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Aug 28, 2024
1 parent d4c2511 commit 69811cc
Show file tree
Hide file tree
Showing 8 changed files with 753 additions and 0 deletions.
91 changes: 91 additions & 0 deletions batched/dense/impl/KokkosBatched_Pttrs_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRS_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_PTTRS_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include <KokkosBlas1_scal.hpp>
#include "KokkosBatched_Pttrs_Serial_Internal.hpp"

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Validates the arguments handed to SerialPttrs::invoke.
///
/// View-ness and rank of the three arguments are enforced at compile time.
/// The extents are only inspected when KOKKOSKERNELS_DEBUG_LEVEL > 0; in any
/// other configuration the function unconditionally returns 0.
///
/// \tparam DViewType: type of d, must be a rank-1 Kokkos::View
/// \tparam EViewType: type of e, must be a rank-1 Kokkos::View
/// \tparam BViewType: type of b, must be a rank-1 Kokkos::View
///
/// \param d [in]: diagonal entries
/// \param e [in]: off-diagonal entries; e.extent(0) must equal d.extent(0) - 1
/// \param b [in]: right-hand side; b.extent(0) must be at least
///                max(1, d.extent(0))
///
/// \return 0 if the inputs are accepted, 1 if one of the extent checks fails
///         (a diagnostic is printed in that case).
template <typename DViewType, typename EViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int checkPttrsInput([[maybe_unused]] const DViewType &d,
                                                  [[maybe_unused]] const EViewType &e,
                                                  [[maybe_unused]] const BViewType &b) {
  static_assert(Kokkos::is_view_v<DViewType>, "KokkosBatched::pttrs: DViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<EViewType>, "KokkosBatched::pttrs: EViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<BViewType>, "KokkosBatched::pttrs: BViewType is not a Kokkos::View.");

  static_assert(DViewType::rank == 1, "KokkosBatched::pttrs: DViewType must have rank 1.");
  static_assert(EViewType::rank == 1, "KokkosBatched::pttrs: EViewType must have rank 1.");
  static_assert(BViewType::rank == 1, "KokkosBatched::pttrs: BViewType must have rank 1.");

#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int nd  = d.extent(0);
  const int ne  = e.extent(0);
  const int ldb = b.extent(0);

  if (ne + 1 != nd) {
    Kokkos::printf(
        "KokkosBatched::pttrs: Dimensions of d and e do not match: d: %d, e: "
        "%d \n"
        "e.extent(0) must be equal to d.extent(0) - 1\n",
        nd, ne);
    return 1;
  }

  if (ldb < Kokkos::max(1, nd)) {
    // Argument order follows the format string: extent of d first, then b.
    Kokkos::printf(
        "KokkosBatched::pttrs: Dimensions of d and b do not match: d: %d, b: "
        "%d \n"
        "b.extent(0) must be larger or equal to d.extent(0) \n",
        nd, ldb);
    return 1;
  }
#endif
  return 0;
}

/// \brief Unblocked serial pttrs: solves one tridiagonal system in place using
///        the factors stored in d and e, overwriting b with the result.
template <typename ArgUplo>
struct SerialPttrs<ArgUplo, Algo::Pttrs::Unblocked> {
  /// \param d [in]: rank-1 view with the diagonal entries of the factorization
  /// \param e [in]: rank-1 view with the off-diagonal entries of the
  ///                factorization
  /// \param b [inout]: rank-1 view with the right-hand side on entry
  ///
  /// \return 0 when d is empty; otherwise the non-zero code from
  ///         checkPttrsInput if validation fails, or the return value of the
  ///         kernel that performed the solve.
  template <typename DViewType, typename EViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d, const EViewType &e, const BViewType &b) {
    // An empty system needs no work; this return happens before validation.
    if (d.extent(0) == 0) return 0;

    if (const auto status = checkPttrsInput(d, e, b); status) return status;

    const int n = d.extent(0);

    if (n != 1) {
      // General case: hand the raw pointers and strides to the internal
      // kernel, which works on the first n entries of b.
      return SerialPttrsInternal<ArgUplo, Algo::Pttrs::Unblocked>::invoke(n, d.data(), d.stride(0), e.data(),
                                                                          e.stride(0), b.data(), b.stride(0));
    }

    // 1x1 system: e is not read; SerialScale is applied to the view b as a
    // whole with the reciprocal of the single diagonal entry.
    using DiagValueType              = typename DViewType::non_const_value_type;
    const DiagValueType inv_diagonal = 1.0 / d(0);
    return KokkosBlas::SerialScale::invoke(inv_diagonal, b);
  }
};
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PTTRS_SERIAL_IMPL_HPP_
88 changes: 88 additions & 0 deletions batched/dense/impl/KokkosBatched_Pttrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PTTRS_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Pointer/stride level kernels behind SerialPttrs.
///
/// Both overloads are defined later in this header. They read the entries of
/// d and e, overwrite the entries of b in place, and return 0.
///
/// \tparam ArgUplo: only consulted by the complex overload, where it selects
///                  which of the two sweeps conjugates e
/// \tparam AlgoType: algorithm tag; not referenced by either definition
template <typename ArgUplo, typename AlgoType>
struct SerialPttrsInternal {
/// Overload for the case where d, e and b all share the same element type.
///
/// \param n [in]: number of entries of d and b that are processed
/// \param d [in]: pointer to the diagonal entries, element i at d[i * ds0]
/// \param ds0 [in]: stride between consecutive entries of d
/// \param e [in]: pointer to the off-diagonal entries, element i at e[i * es0]
/// \param es0 [in]: stride between consecutive entries of e
/// \param b [inout]: pointer to the right-hand side, element i at b[i * bs0]
/// \param bs0 [in]: stride between consecutive entries of b
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const ValueType *KOKKOS_RESTRICT e, const int es0,
ValueType *KOKKOS_RESTRICT b, const int bs0);

/// Overload for Kokkos::complex e and b; d keeps the underlying
/// (non-complex) ValueType. Parameters have the same meaning as above.
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0,
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0);
};

///
/// Real matrix
///

/// \brief Solves A * x = b in place for a real system, given the L * D * L**T
///        factors (d holds D, e holds the off-diagonal of L).
///
/// \param n [in]: order of the system; n <= 0 is a no-op
/// \param d [in]: diagonal entries, element i at d[i * ds0]
/// \param ds0 [in]: stride of d
/// \param e [in]: off-diagonal entries, element i at e[i * es0]
/// \param es0 [in]: stride of e
/// \param b [inout]: right-hand side on entry, solution on exit,
///                   element i at b[i * bs0]
/// \param bs0 [in]: stride of b
///
/// \return always 0
template <typename ArgUplo, typename AlgoType>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPttrsInternal<ArgUplo, AlgoType>::invoke(
    const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0, const ValueType *KOKKOS_RESTRICT e, const int es0,
    ValueType *KOKKOS_RESTRICT b, const int bs0) {
  // Quick return: without it, b[(n - 1) * bs0] below would index out of
  // bounds for an empty system.
  if (n <= 0) return 0;

  // Solve A * X = B using the factorization L * D * L**T
  // Forward sweep: solve L * y = b.
  for (int i = 1; i < n; i++) {
    b[i * bs0] -= e[(i - 1) * es0] * b[(i - 1) * bs0];
  }

  // Backward sweep: solve D * L**T * x = y, starting from the last row.
  b[(n - 1) * bs0] /= d[(n - 1) * ds0];

  for (int i = n - 2; i >= 0; i--) {
    b[i * bs0] = b[i * bs0] / d[i * ds0] - b[(i + 1) * bs0] * e[i * es0];
  }

  return 0;
}

///
/// Complex matrix
///

/// \brief Solves A * x = b in place for a complex Hermitian system, given the
///        U**H * D * U or L * D * L**H factors.
///
/// For Uplo::Upper the forward sweep uses conj(e) and the backward sweep uses
/// e; for Uplo::Lower it is the other way round. Any other ArgUplo uses e
/// unconjugated in both sweeps.
///
/// \param n [in]: order of the system; n <= 0 is a no-op
/// \param d [in]: real diagonal entries, element i at d[i * ds0]
/// \param ds0 [in]: stride of d
/// \param e [in]: complex off-diagonal entries, element i at e[i * es0]
/// \param es0 [in]: stride of e
/// \param b [inout]: right-hand side on entry, solution on exit,
///                   element i at b[i * bs0]
/// \param bs0 [in]: stride of b
///
/// \return always 0
template <typename ArgUplo, typename AlgoType>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPttrsInternal<ArgUplo, AlgoType>::invoke(
    const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0, const Kokkos::complex<ValueType> *KOKKOS_RESTRICT e,
    const int es0, Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0) {
  // Quick return: without it, b[(n - 1) * bs0] below would index out of
  // bounds for an empty system.
  if (n <= 0) return 0;

  // The triangle selection is a compile-time property; evaluate it once
  // instead of inside each loop iteration.
  constexpr bool is_upper = std::is_same_v<ArgUplo, Uplo::Upper>;
  constexpr bool is_lower = std::is_same_v<ArgUplo, Uplo::Lower>;

  // Solve A * X = B using the factorization L * D * L**H
  // Forward sweep.
  for (int i = 1; i < n; i++) {
    auto tmp_e = is_upper ? Kokkos::conj(e[(i - 1) * es0]) : e[(i - 1) * es0];
    b[i * bs0] -= tmp_e * b[(i - 1) * bs0];
  }

  // Backward sweep, starting from the last row.
  b[(n - 1) * bs0] /= d[(n - 1) * ds0];

  for (int i = n - 2; i >= 0; i--) {
    auto tmp_e = is_lower ? Kokkos::conj(e[i * es0]) : e[i * es0];
    b[i * bs0] = b[i * bs0] / d[i * ds0] - b[(i + 1) * bs0] * tmp_e;
  }

  return 0;
}

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PTTRS_SERIAL_INTERNAL_HPP_
54 changes: 54 additions & 0 deletions batched/dense/src/KokkosBatched_Pttrs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRS_HPP_
#define KOKKOSBATCHED_PTTRS_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Pttrs:
/// Solve A_l x_l = b_l for all l = 0, ..., N
/// using the factorization A = U**H * D * U or A = L * D * L**H computed by
/// Pttrf. A single call to invoke handles one system, described by rank-1
/// views.
///
/// \tparam ArgUplo: Uplo tag (Uplo::Upper or Uplo::Lower) stating which of the
/// two factorizations the entries of e belong to
/// \tparam ArgAlgo: Algorithm tag; a specialization for
/// Algo::Pttrs::Unblocked is provided
///
/// \tparam DViewType: Input type for the diagonal matrix D, needs to be a 1D
/// view
/// \tparam EViewType: Input type for the upper/lower off-diagonal elements,
/// needs to be a 1D view
/// \tparam BViewType: Input type for the right-hand side and the solution,
/// needs to be a 1D view
///
/// \param d [in]: n diagonal elements of the diagonal matrix D
/// \param e [in]: n-1 off-diagonal elements of the bidiagonal factor
/// (upper or lower, according to ArgUplo)
/// \param b [inout]: right-hand side and the solution, a rank 1 view
///
/// \return 0 on success; a non-zero value if the input checks of the
/// implementation reject the arguments
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgUplo, typename ArgAlgo>
struct SerialPttrs {
template <typename DViewType, typename EViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d, const EViewType &e, const BViewType &b);
};

} // namespace KokkosBatched

#include "KokkosBatched_Pttrs_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_PTTRS_HPP_
3 changes: 3 additions & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
#include "Test_Batched_SerialPttrf.hpp"
#include "Test_Batched_SerialPttrf_Real.hpp"
#include "Test_Batched_SerialPttrf_Complex.hpp"
#include "Test_Batched_SerialPttrs.hpp"
#include "Test_Batched_SerialPttrs_Real.hpp"
#include "Test_Batched_SerialPttrs_Complex.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading

0 comments on commit 69811cc

Please sign in to comment.