Skip to content

Commit

Permalink
feat: support creating partition tables with different window widths (PROOF-893) (#154)
Browse files: browse the repository at this point in the history

* add stub for new partition table function

* rework partition table

* rework partition table code

* rework partition table

* rework partition table

* rework partition table

* add tests

* reformat
  • Loading branch information
rnburn authored Jul 16, 2024
1 parent db722fc commit 7a83858
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ make_in_memory_partition_table_accessor_impl(basct::cspan<T> generators,
generators = generators_data;
}
memmg::managed_array<U> sums{partition_table_size_v * num_partitions, alloc};
compute_partition_table<U, T>(sums, generators);
compute_partition_table<U, T>(sums, 16, generators);
return std::make_unique<in_memory_partition_table_accessor<U>>(std::move(sums));
}

Expand Down
43 changes: 27 additions & 16 deletions sxt/multiexp/pippenger2/partition_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,21 @@ template <class U, bascrv::element T>
static_cast<U>(e);
T{u};
}
CUDA_CALLABLE void compute_partition_table_slice(U* __restrict__ sums,
CUDA_CALLABLE void compute_partition_table_slice(U* __restrict__ sums, unsigned window_width,
const T* __restrict__ generators) noexcept {
assert(0u < window_width && window_width <= 32u);

sums[0] = static_cast<U>(T::identity());

// single entry sums
for (unsigned i = 0; i < 16; ++i) {
sums[1 << i] = static_cast<U>(generators[i]);
for (unsigned i = 0; i < window_width; ++i) {
sums[1u << i] = static_cast<U>(generators[i]);
}

// multi-entry sums
for (unsigned k = 2; k <= 16; ++k) {
unsigned partition = std::numeric_limits<uint16_t>::max() >> (16u - k);
auto partition_last = partition << (16u - k);
for (unsigned k = 2; k <= window_width; ++k) {
unsigned partition = std::numeric_limits<uint32_t>::max() >> (32u - k);
auto partition_last = partition << (window_width - k);

// iterate over all possible permutations with k bits set to 1
// until we reach partition_last
Expand All @@ -70,31 +72,40 @@ CUDA_CALLABLE void compute_partition_table_slice(U* __restrict__ sums,
// compute_partition_table
//--------------------------------------------------------------------------------------------------
/**
* Compute table of sums used for Pippenger's partition step with a width of 16. Each slice of the
* table contains all possible sums of a group of 16 generators.
* Compute table of sums used for Pippenger's partition step with a given window width. Each
* slice of the table contains all possible sums of a group of `window_width` generators.
*/
template <class U, bascrv::element T>
requires requires(const U& u, const T& e) {
static_cast<U>(e);
T{u};
}
void compute_partition_table(basct::span<U> sums, basct::cspan<T> generators) noexcept {
void compute_partition_table(basct::span<U> sums, unsigned window_width,
basct::cspan<T> generators) noexcept {
// Number of entries in one slice of the table: 2^window_width.
// NOTE(review): only a lower bound on window_width is asserted below; `1u << window_width`
// is undefined once window_width reaches the bit width of unsigned (typically 32), even
// though the slice helper's own assert admits window_width == 32 -- confirm callers stay below.
auto table_size = 1u << window_width;
// Preconditions: the generators divide evenly into groups of window_width, and sums holds
// exactly table_size entries per group. Presumably checked in debug builds only, judging by
// the macro name -- verify against SXT_DEBUG_ASSERT's definition.
SXT_DEBUG_ASSERT(
// clang-format off
sums.size() == partition_table_size_v * generators.size() / 16u &&
generators.size() % 16 == 0
0u < window_width &&
sums.size() == table_size * generators.size() / window_width &&
generators.size() % window_width == 0
// clang-format on
);
// n is the number of generator groups, i.e. the number of table slices to fill.
auto n = generators.size() / 16u;
auto n = generators.size() / window_width;
// Fill slice i of sums from the i-th group of generators.
for (unsigned i = 0; i < n; ++i) {
auto sums_slice = sums.subspan(i * partition_table_size_v, partition_table_size_v);
auto generators_slice = generators.subspan(i * 16u, 16u);
compute_partition_table_slice(sums_slice.data(), generators_slice.data());
auto sums_slice = sums.subspan(i * table_size, table_size);
auto generators_slice = generators.subspan(i * window_width, window_width);
compute_partition_table_slice(sums_slice.data(), window_width, generators_slice.data());
}
}

// Overload for the case where sums and generators share the element type T; forwards to the
// two-type overload, passing a window width of 16.
template <bascrv::element T>
void compute_partition_table(basct::span<T> sums, basct::cspan<T> generators) noexcept {
compute_partition_table<T, T>(sums, generators);
compute_partition_table<T, T>(sums, 16u, generators);
}

/**
 * Overload for the case where the table entries and the generators share the element type T.
 * Forwards to the two-type overload with the caller-supplied window width.
 */
template <bascrv::element T>
void compute_partition_table(basct::span<T> sums, unsigned window_width,
                             basct::cspan<T> generators) noexcept {
  return compute_partition_table<T, T>(sums, window_width, generators);
}
} // namespace sxt::mtxpp2
26 changes: 25 additions & 1 deletion sxt/multiexp/pippenger2/partition_table.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ TEST_CASE("we can compute a slice of the partition table") {
std::vector<E> sums(1u << 16);
std::vector<E> generators = {1u, 2u, 3u, 4u, 5u, 6u, 7u, 8u,
9u, 10u, 11u, 12u, 13u, 14u, 15u, 16u};
compute_partition_table_slice(sums.data(), generators.data());
compute_partition_table_slice(sums.data(), 16u, generators.data());
for (unsigned i = 0; i < sums.size(); ++i) {
auto expected = E::identity();
basbt::for_each_bit(reinterpret_cast<uint8_t*>(&i), sizeof(i), [&](unsigned index) noexcept {
Expand All @@ -54,3 +54,27 @@ TEST_CASE("we can compute the full partition table") {
REQUIRE(sums[1] == generators[0]);
REQUIRE(sums[partition_table_size_v + 1] == generators[16]);
}

// With a window width of 1 every generator gets its own two-entry slice: {0, generator}.
TEST_CASE("we can compute a slice of the partition table with a width of 1") {
  using E = bascrv::element97;

  std::vector<E> generators = {1u, 2u,  3u,  4u,  5u,  6u,  7u,  8u,
                               9u, 10u, 11u, 12u, 13u, 14u, 15u, 16u};

  // two table entries per generator
  std::vector<E> sums(2 * generators.size());
  compute_partition_table<E>(sums, 1u, generators);

  for (unsigned slice = 0; slice < generators.size(); ++slice) {
    const auto base = 2 * slice;
    REQUIRE(sums[base] == 0u);
    REQUIRE(sums[base + 1] == generators[slice]);
  }
}

// With a window width of 2 each pair of generators yields a four-entry slice:
// {0, first, second, first + second}.
TEST_CASE("we can compute a slice of the partition table with a width of 2") {
  using E = bascrv::element97;

  std::vector<E> generators = {1u, 2u, 3u, 4u};

  // four entries for each of the generators.size() / 2 slices
  std::vector<E> sums(4 * generators.size() / 2);
  compute_partition_table<E>(sums, 2u, generators);

  // sums of both generators in the first and second pair, respectively
  auto first_pair_sum = generators[0].value + generators[1].value;
  auto second_pair_sum = generators[2].value + generators[3].value;

  std::vector<E> expected = {
      0, generators[0], generators[1], first_pair_sum,
      0, generators[2], generators[3], second_pair_sum,
  };
  REQUIRE(sums == expected);
}

0 comments on commit 7a83858

Please sign in to comment.