Skip to content

Commit

Permalink
KokkosSparse_merge_matrix.hpp: fix comparison signedness
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Oct 16, 2023
1 parent b8eb1bc commit 01e9773
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions sparse/impl/KokkosSparse_merge_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ namespace KokkosSparse::Impl {
// a joint index into a and b
template <typename AIndex, typename BIndex>
struct MergeMatrixPosition {
using a_index_type = AIndex;
using b_index_type = BIndex;

AIndex ai;
BIndex bi;
};
Expand Down Expand Up @@ -145,9 +148,10 @@ class MergeMatrixDiagonal {
KOKKOS_INLINE_FUNCTION
bool operator()(const size_type di) const {
position_type pos = diag_to_a_b(di);
if (size_type(pos.ai) >= a_.size()) {

if (pos.ai >= typename position_type::a_index_type(a_.size())) {
return true; // on the +a side out of matrix bounds is 1
} else if (size_type(pos.bi) >= b_.size()) {
} else if (pos.bi >= typename position_type::b_index_type(b_.size())) {
return false; // on the +b side out of matrix bounds is 0
} else {
return KokkosKernels::Impl::safe_gt(a_(pos.ai), b_(pos.bi));
Expand All @@ -161,9 +165,9 @@ class MergeMatrixDiagonal {
*/
KOKKOS_INLINE_FUNCTION
size_type size() const noexcept {
if (d_ <= a_.size() && d_ <= b_.size()) {
if (d_ <= size_type(a_.size()) && d_ <= size_type(b_.size())) {
return d_;
} else if (d_ > a_.size() && d_ > b_.size()) {
} else if (d_ > size_type(a_.size()) && d_ > size_type(b_.size())) {
// TODO: this returns nonsense if d_ happens to be outside the merge
// matrix
return a_.size() + b_.size() - d_;
Expand All @@ -182,8 +186,8 @@ class MergeMatrixDiagonal {
KOKKOS_INLINE_FUNCTION
position_type diag_to_a_b(const size_type &di) const noexcept {
position_type res;
res.ai = d_ < a_.size() ? (d_ - 1) - di : a_.size() - 1 - di;
res.bi = d_ < a_.size() ? di : d_ + di - a_.size();
res.ai = d_ < size_type(a_.size()) ? (d_ - 1) - di : a_.size() - 1 - di;
res.bi = d_ < size_type(a_.size()) ? di : d_ + di - a_.size();
return res;
}

Expand Down

0 comments on commit 01e9773

Please sign in to comment.