diff --git a/src/api.cpp b/src/api.cpp index a02ed551..8713eec1 100644 --- a/src/api.cpp +++ b/src/api.cpp @@ -315,26 +315,19 @@ void check_sparse_inputs(const std::vector &ref_sketches, } } -class SparseDist { - public: - SparseDist(float dist, long j) : _dist(dist), _j(j) {} - - float dist() const {return _dist; } - long j() const {return _j; } - - friend bool operator<(SparseDist const &a, SparseDist const &b) - { - return a._dist < b._dist; - } - friend bool operator==(SparseDist const &a, SparseDist const &b) - { - return a._dist == b._dist; - } - - private: - float _dist; - long _j; +// Struct that allows sorting by dist but also keeping index +struct SparseDist { + float dist; + long j; }; +bool operator<(SparseDist const &a, SparseDist const &b) +{ + return a.dist < b.dist; +} +bool operator==(SparseDist const &a, SparseDist const &b) +{ + return a.dist == b.dist; +} sparse_coo query_db_sparse(std::vector &ref_sketches, const std::vector &kmer_lengths, @@ -398,8 +391,9 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, } } - if (min_dists.size() < kNN || row_dist < min_dists.top().dist()) { - min_dists.push(SparseDist(row_dist, j)); + if (min_dists.size() < kNN || row_dist < min_dists.top().dist) { + SparseDist new_min = {row_dist, j}; + min_dists.push(new_min); if (min_dists.size() > kNN) { min_dists.pop(); } @@ -409,8 +403,8 @@ sparse_coo query_db_sparse(std::vector &ref_sketches, std::fill_n(i_vec.begin() + offset, kNN, i); for (int k = 0; k < kNN; ++k) { SparseDist entry = min_dists.top(); - j_vec[offset + k] = entry.j(); - dists[offset + k] = entry.dist(); + j_vec[offset + k] = entry.j; + dists[offset + k] = entry.dist; min_dists.pop(); }