Skip to content

Commit

Permalink
simplify do_conj logic in trsm serial internal
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Dec 3, 2024
1 parent dbd8e60 commit af95556
Showing 1 changed file with 26 additions and 69 deletions.
95 changes: 26 additions & 69 deletions batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,64 +55,41 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::i
else {
if (alpha != one) KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);

if (do_conj) {
for (int p = 0; p < m; ++p) {
const int iend = m - p - 1, jend = n;
for (int p = 0; p < m; ++p) {
const int iend = m - p - 1, jend = n;

const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr;
const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr;

ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr;
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr;

if (!use_unit_diag) {
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
if (!use_unit_diag) {
const ValueType alpha11 =
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}

for (int i = 0; i < iend; ++i)

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j)
B2[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1];
for (int j = 0; j < jend; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}
} else {
for (int p = 0; p < m; ++p) {
const int iend = m - p - 1, jend = n;

const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr;

ValueType *KOKKOS_RESTRICT b1t = B + p * bs0, *KOKKOS_RESTRICT B2 = iend ? B + (p + 1) * bs0 : nullptr;

if (!use_unit_diag) {
const ValueType alpha11 = A[p * as0 + p * as1];
for (int i = 0; i < iend; ++i)

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}

for (int i = 0; i < iend; ++i)

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B2[i * bs0 + j * bs1] -= a21[i * as0] * b1t[j * bs1];
}
for (int j = 0; j < jend; ++j)
B2[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1]
: a21[i * as0] * b1t[j * bs1]);
}
}

return 0;
}

template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
constexpr int mbAlgo = Algo::Trsm::Blocked::mb();
Expand Down Expand Up @@ -191,46 +168,26 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
const ValueType *KOKKOS_RESTRICT a01 = A + p * as1;
ValueType *KOKKOS_RESTRICT b1t = B + p * bs0;

if (do_conj) {
if (!use_unit_diag) {
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}

if (p > 0) { // Note: A workaround to produce correct results for
// complex<double> with Intel-18.2.199
for (int i = 0; i < iend; ++i)

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j)
B0[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1];
}

} else {
if (!use_unit_diag) {
const ValueType alpha11 = A[p * as0 + p * as1];
if (!use_unit_diag) {
const ValueType alpha11 =
(do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}
for (int j = 0; j < n; ++j) b1t[j * bs1] = b1t[j * bs1] / alpha11;
}

if (p > 0) { // Note: A workaround to produce correct results for
// complex<double> with Intel-18.2.199
for (int i = 0; i < iend; ++i)
if (p > 0) { // Note: A workaround to produce correct results for
// complex<double> with Intel-18.2.199
for (int i = 0; i < iend; ++i)

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= a01[i * as0] * b1t[j * bs1];
}
for (int j = 0; j < jend; ++j)
B0[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1]
: a01[i * as0] * b1t[j * bs1]);
}
}
}
Expand All @@ -240,7 +197,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
const bool use_unit_diag, const bool /*do_conj*/, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT B, const int bs0, const int bs1) {
const ScalarType one(1.0), zero(0.0), minus_one(-1.0);
Expand Down

0 comments on commit af95556

Please sign in to comment.