Skip to content

Commit

Permalink
Add index change to matrix naive to take Eigen::Index by default. Thi…
Browse files Browse the repository at this point in the history
…s makes index type consistent across the entire library now.
  • Loading branch information
JamesYang007 committed Jun 10, 2024
1 parent 3b53b31 commit 9f2e450
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions adelie/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,13 +624,13 @@ def interaction(
pairs_seen.add((key, val))
if len(pairs) <= 0:
raise ValueError("No valid pairs exist. There must be at least one valid pair.")
pairs = np.array(pairs, dtype=np.int32)
pairs = np.array(pairs, dtype=int)

class _interaction(core_base, py_base):
def __init__(self):
self._mat = np.array(mat, copy=copy)
self._pairs = pairs
self._levels = np.array(levels, copy=True, dtype=np.int32)
self._levels = np.array(levels, copy=True, dtype=int)
core_base.__init__(self, self._mat, self._pairs, self._levels, n_threads)
py_base.__init__(self, n_threads=n_threads)

Expand Down Expand Up @@ -822,7 +822,7 @@ def one_hot(
class _one_hot(core_base, py_base):
def __init__(self):
self._mat = np.array(mat, copy=copy)
self._levels = np.array(levels, copy=True, dtype=np.int32)
self._levels = np.array(levels, copy=True, dtype=int)
core_base.__init__(self, self._mat, self._levels, n_threads)
py_base.__init__(self, n_threads=n_threads)

Expand Down Expand Up @@ -1318,7 +1318,7 @@ def subset(
class _subset(core_base, py_base):
def __init__(self):
self._mat = mat
self._indices = np.array(indices, copy=True, dtype=np.int32)
self._indices = np.array(indices, copy=True, dtype=int)
core_base.__init__(self, self._mat, self._indices, n_threads)
py_base.__init__(self, n_threads=n_threads)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace adelie_core {
namespace matrix {

template <class ValueType, class IndexType=int>
template <class ValueType, class IndexType=Eigen::Index>
class MatrixNaiveBase
{
protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class MatrixNaiveCConcatenate: public MatrixNaiveBase<ValueType>
const auto slice = _slice_map[j_curr];
auto& mat = *_mat_list[slice];
const auto index = _index_map[j_curr];
const int q_curr = std::min(mat.cols()-index, q-n_processed);
const int q_curr = std::min<int>(mat.cols()-index, q-n_processed);
mat.bmul(index, q_curr, v, weights, out.segment(n_processed, q_curr));
n_processed += q_curr;
}
Expand All @@ -163,7 +163,7 @@ class MatrixNaiveCConcatenate: public MatrixNaiveBase<ValueType>
const auto slice = _slice_map[j_curr];
auto& mat = *_mat_list[slice];
const auto index = _index_map[j_curr];
const int q_curr = std::min(mat.cols()-index, q-n_processed);
const int q_curr = std::min<int>(mat.cols()-index, q-n_processed);
mat.btmul(index, q_curr, v.segment(n_processed, q_curr), out);
n_processed += q_curr;
}
Expand Down

0 comments on commit 9f2e450

Please sign in to comment.