Improve batched serial trsm implementation and testing #2432
@yasahi-hpc I'm using your PR to test a change to our CI infrastructure (related to this |
Hi @cwpearson |
@yasahi-hpc looks like a legit failure in |
d7da4c6 to 5950448
Sorry for the mistake. It should be fine now. For some reason, I need to relax the tolerance for the Intel CPU build.
Some clean-ups required to simplify the code and future maintenance.
#pragma unroll
#endif
for (int j = 0; j < jend; ++j)
  B2[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1];
B2[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1];
B2[i * bs0 + j * bs1] -= (do_conj ? Kokkos::ArithTraits<ValueType>::conj(a21[i * as0]) * b1t[j * bs1] : a21[i * as0] * b1t[j * bs1]);
With this you can remove a lot of the code duplication introduced for the conjugate type.
After checking `Kokkos::ArithTraits`: since we define `conj` to be a no-op for non-complex floating-point types, you can just call it on both complex and non-complex numbers and it will do the right thing.
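To illustrate the suggestion, here is a minimal standalone sketch of the rank-1 update with a runtime `do_conj` flag. It does not use Kokkos: `conj_or_identity` is a hypothetical stand-in for `Kokkos::ArithTraits<ValueType>::conj` (a no-op on real types), and `rank1_update` with its loop bounds `iend`/`jend` is an assumed wrapper around the loop quoted above, not the actual kernel.

```cpp
#include <cassert>
#include <complex>

// Hypothetical stand-in for Kokkos::ArithTraits<T>::conj: identity for real types.
template <typename T>
T conj_or_identity(const T &x) {
  return x;
}

// More specialized overload: complex conjugate for std::complex.
template <typename T>
std::complex<T> conj_or_identity(const std::complex<T> &x) {
  return std::conj(x);
}

// B2 -= op(a21) * b1t, where op is conj when do_conj is true and identity otherwise.
// One code path serves both the Trans and ConjTrans cases.
template <typename ValueType>
void rank1_update(const bool do_conj, const int iend, const int jend, const ValueType *a21, const int as0,
                  const ValueType *b1t, const int bs1, ValueType *B2, const int bs0) {
  assert(iend >= 0 && jend >= 0);
  for (int i = 0; i < iend; ++i) {
    // Hoist the (possibly conjugated) coefficient out of the inner loop.
    const ValueType a = do_conj ? conj_or_identity(a21[i * as0]) : a21[i * as0];
    for (int j = 0; j < jend; ++j) B2[i * bs0 + j * bs1] -= a * b1t[j * bs1];
  }
}
```

Note that the flag is still needed for complex types when no conjugation is wanted; the no-op property only removes the need for a separate real-valued code path.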
@@ -83,8 +112,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
    const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
    const int as0, const int as1,
    const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
const bool use_unit_diag, const bool /* do_conj */, const int m, const int n, const ScalarType alpha,
It looks more like it is never used... in that case just omit the variable name in the signature.
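As a small illustration of the two options, the sketch below compares `[[maybe_unused]]` with an unnamed parameter. The functions `scale_named` and `scale_unnamed` are hypothetical and only mimic the shape of the signature under review. Both compile without an unused-parameter warning; the unnamed form additionally makes it impossible to read the argument by accident.

```cpp
// Hypothetical functions, not the KokkosKernels kernels.

// Option A: keep the name and silence the warning (C++17 attribute).
// Appropriate when the parameter is used only under some configurations.
inline int scale_named(const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m) {
  return use_unit_diag ? m : 2 * m;
}

// Option B: omit the name when the parameter is never used.
// The comment keeps the signature readable for callers.
inline int scale_unnamed(const bool use_unit_diag, const bool /* do_conj */, const int m) {
  return use_unit_diag ? m : 2 * m;
}
```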
if (!use_unit_diag) {
  const ValueType alpha11 = A[p * as0 + p * as1];
if (!use_unit_diag) {
  const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
const ValueType alpha11 = (do_conj ? Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]) : A[p * as0 + p * as1]);
Simplify the code.
  const ValueType alpha11 = A[p * as0 + p * as1];
if (do_conj) {
  if (!use_unit_diag) {
    const ValueType alpha11 = Kokkos::ArithTraits<ValueType>::conj(A[p * as0 + p * as1]);
Same as above
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < jend; ++j) B0[i * bs0 + j * bs1] -= a01[i * as0] * b1t[j * bs1];
for (int j = 0; j < jend; ++j)
  B0[i * bs0 + j * bs1] -= Kokkos::ArithTraits<ValueType>::conj(a01[i * as0]) * b1t[j * bs1];
Same as above
@@ -189,8 +240,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::i
template <>
template <typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
    const bool use_unit_diag, const int m, const int n, const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A,
    const int as0, const int as1,
    const bool use_unit_diag, [[maybe_unused]] const bool do_conj, const int m, const int n, const ScalarType alpha,
Same as above
Signed-off-by: Yuuichi Asahi <[email protected]>
5950448 to af95556
auto info = KokkosBatched::Impl::checkTrsmInput<Side::Right>(A, B);
if (info) return info;

return KokkosBatched::Impl::SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
I assume you are doing: T^tX=B <=> X^tT=B^t so you can solve with T on the right and no transpose?
If your `T` represents `A` in the code, yes, that is what we are doing here.
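The same transpose identity is what lets a right-side solve reuse a left-side kernel: transposing both sides of X U = B gives U^T X^T = B^T, and on strided storage a transpose is just a swap of the two strides. The sketch below shows this under stated assumptions. `trsm_left_lower` and `trsm_right_upper` are hypothetical, real-valued, non-unit-diagonal helpers written for illustration; they are not the `SerialTrsmInternal*` implementations from this PR.

```cpp
#include <cassert>

// Solve L * X = B in place (X overwrites B) by forward substitution.
// L is m x m lower triangular with a non-unit diagonal, B is m x n.
// as0/as1 and bs0/bs1 are the row/column strides of A and B.
inline void trsm_left_lower(const int m, const int n, const double *A, const int as0, const int as1, double *B,
                            const int bs0, const int bs1) {
  assert(m >= 0 && n >= 0);
  for (int p = 0; p < m; ++p) {
    for (int j = 0; j < n; ++j) {
      B[p * bs0 + j * bs1] /= A[p * as0 + p * as1];
      for (int i = p + 1; i < m; ++i) B[i * bs0 + j * bs1] -= A[i * as0 + p * as1] * B[p * bs0 + j * bs1];
    }
  }
}

// Solve X * U = B in place, with U n x n upper triangular and B m x n.
// Since X U = B <=> U^T X^T = B^T and U^T is lower triangular, it is enough
// to swap the dimensions and the strides and call the left-lower kernel.
inline void trsm_right_upper(const int m, const int n, const double *A, const int as0, const int as1, double *B,
                             const int bs0, const int bs1) {
  trsm_left_lower(n, m, A, as1, as0, B, bs1, bs0);
}
```

No data is copied or physically transposed; only the index arithmetic changes.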
Looks good to me. I think I saw all the combinations get tested, so this should be good.
This PR aims at improving the implementation and testing of serial Trsm.
`Impl` namespace.
`Left/Right`, `Upper/Lower`, `Non-Trans/Trans/ConjTrans`, and `Unit/Non-Unit`
`X` is a rank 1 `View`. Use `Trsv` for this case
`Left/Right`, `Upper/Lower`, `Non-Trans/Trans/ConjTrans`, and `Unit/Non-Unit`
As a TO DO task, we need to add a `ConjTrans` implementation of the blocked version, which requires a little more investigation.
Edited 25/Nov
`X` is a rank 1 `View`