Skip to content

Commit

Permalink
fix: remove unnecessary specialization for pttrs internal
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Aug 27, 2024
1 parent 705b6ec commit f5e406d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 44 deletions.
10 changes: 4 additions & 6 deletions batched/dense/impl/KokkosBatched_Pttrs_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ template <typename DViewType, typename EViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int checkPttrsInput(
[[maybe_unused]] const DViewType &d, [[maybe_unused]] const EViewType &e,
[[maybe_unused]] const BViewType &b) {
static_assert(Kokkos::is_view<DViewType>::value,
static_assert(Kokkos::is_view_v<DViewType>,
"KokkosBatched::pttrs: DViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<EViewType>::value,
static_assert(Kokkos::is_view_v<EViewType>,
"KokkosBatched::pttrs: EViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<BViewType>::value,
static_assert(Kokkos::is_view_v<BViewType>,
"KokkosBatched::pttrs: BViewType is not a Kokkos::View.");

static_assert(DViewType::rank == 1,
Expand Down Expand Up @@ -82,7 +82,6 @@ struct SerialPttrs<ArgUplo, Algo::Pttrs::Unblocked> {

using ScalarType = typename DViewType::non_const_value_type;
int n = d.extent(0);
int ldb = b.extent(0);

if (n == 1) {
const ScalarType alpha = 1.0 / d(0);
Expand All @@ -92,8 +91,7 @@ struct SerialPttrs<ArgUplo, Algo::Pttrs::Unblocked> {
// Solve A * X = B using the factorization A = L*D*L**T,
// overwriting each right hand side vector with its solution.
return SerialPttrsInternal<ArgUplo, Algo::Pttrs::Unblocked>::invoke(
n, d.data(), d.stride(0), e.data(), e.stride(0), b.data(), b.stride(0),
ldb);
n, d.data(), d.stride(0), e.data(), e.stride(0), b.data(), b.stride(0));
}
};
} // namespace KokkosBatched
Expand Down
56 changes: 18 additions & 38 deletions batched/dense/impl/KokkosBatched_Pttrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,25 @@ struct SerialPttrsInternal {
KOKKOS_INLINE_FUNCTION static int invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const ValueType *KOKKOS_RESTRICT e, const int es0,
ValueType *KOKKOS_RESTRICT b, const int bs0, const int ldb);
ValueType *KOKKOS_RESTRICT b, const int bs0);

template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0,
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0,
const int ldb);
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0);
};

///
/// Real matrix
///

template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int
SerialPttrsInternal<Uplo::Lower, Algo::Pttrs::Unblocked>::invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const ValueType *KOKKOS_RESTRICT e, const int es0,
ValueType *KOKKOS_RESTRICT b, const int bs0, const int ldb) {
// Solve A * X = B using the factorization L * D * L**T
for (int i = 1; i < n; i++) {
b[i * bs0] -= e[(i - 1) * es0] * b[(i - 1) * bs0];
}

b[(n - 1) * bs0] /= d[(n - 1) * ds0];

for (int i = n - 2; i >= 0; i--) {
b[i * bs0] = b[i * bs0] / d[i * ds0] - b[(i + 1) * bs0] * e[i * es0];
}

return 0;
}

template <>
template <typename ArgUplo, typename AlgoType>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int
SerialPttrsInternal<Uplo::Upper, Algo::Pttrs::Unblocked>::invoke(
KOKKOS_INLINE_FUNCTION int SerialPttrsInternal<ArgUplo, AlgoType>::invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const ValueType *KOKKOS_RESTRICT e, const int es0,
ValueType *KOKKOS_RESTRICT b, const int bs0, const int ldb) {
ValueType *KOKKOS_RESTRICT b, const int bs0) {
// Solve A * X = B using the factorization A = L * D * L**T
// (for real matrices this is the same as U**T * D * U, so Uplo is irrelevant)
for (int i = 1; i < n; i++) {
b[i * bs0] -= e[(i - 1) * es0] * b[(i - 1) * bs0];
Expand All @@ -88,37 +65,39 @@ SerialPttrsInternal<Uplo::Upper, Algo::Pttrs::Unblocked>::invoke(
/// Complex matrix
///

template <>
template <typename ArgUplo, typename AlgoType>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int
SerialPttrsInternal<Uplo::Lower, Algo::Pttrs::Unblocked>::invoke(
KOKKOS_INLINE_FUNCTION int SerialPttrsInternal<ArgUplo, AlgoType>::invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0,
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0,
const int ldb) {
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0) {
// Solve A * X = B using the factorization A = L * D * L**H (Uplo::Lower)
// or A = U**H * D * U (Uplo::Upper); ArgUplo selects which sweep conjugates e
for (int i = 1; i < n; i++) {
b[i * bs0] -= e[(i - 1) * es0] * b[(i - 1) * bs0];
auto tmp_e = std::is_same_v<ArgUplo, Uplo::Upper>
? Kokkos::conj(e[(i - 1) * es0])
: e[(i - 1) * es0];
b[i * bs0] -= tmp_e * b[(i - 1) * bs0];
}

b[(n - 1) * bs0] /= d[(n - 1) * ds0];

for (int i = n - 2; i >= 0; i--) {
b[i * bs0] =
b[i * bs0] / d[i * ds0] - b[(i + 1) * bs0] * Kokkos::conj(e[i * es0]);
auto tmp_e = std::is_same_v<ArgUplo, Uplo::Lower> ? Kokkos::conj(e[i * es0])
: e[i * es0];
b[i * bs0] = b[i * bs0] / d[i * ds0] - b[(i + 1) * bs0] * tmp_e;
}

return 0;
}

/*
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int
SerialPttrsInternal<Uplo::Upper, Algo::Pttrs::Unblocked>::invoke(
const int n, const ValueType *KOKKOS_RESTRICT d, const int ds0,
const Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0,
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0,
const int ldb) {
Kokkos::complex<ValueType> *KOKKOS_RESTRICT b, const int bs0) {
// Solve A * X = B using the factorization A = U**H * D * U
for (int i = 1; i < n; i++) {
b[i * bs0] -= Kokkos::conj(e[(i - 1) * es0]) * b[(i - 1) * bs0];
Expand All @@ -132,6 +111,7 @@ SerialPttrsInternal<Uplo::Upper, Algo::Pttrs::Unblocked>::invoke(
return 0;
}
*/

} // namespace KokkosBatched

Expand Down

0 comments on commit f5e406d

Please sign in to comment.