Skip to content

Commit

Permalink
feat: use chunked reduction combination algorithm for varying size mu…
Browse files Browse the repository at this point in the history
…ltiexponentiations (PROOF-923) (#213)

* rework mx

* rework mx

* rework mx

* rework mx

* rework mx

* rework mx

* rework mx

* rework mx

* rework mx

* reformat
  • Loading branch information
rnburn authored Jan 10, 2025
1 parent 9157dff commit 736d4b4
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 262 deletions.
28 changes: 0 additions & 28 deletions sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,6 @@ sxt_cc_component(
],
)

# Declares the `combination` component of the pippenger2 package.
# `test_deps` and `deps` are forwarded to the project's sxt_cc_component macro
# (presumably test-only vs. library dependencies -- confirm against the macro definition).
sxt_cc_component(
name = "combination",
test_deps = [
"//sxt/base/curve:example_element",
"//sxt/base/device:stream",
"//sxt/base/device:synchronization",
"//sxt/base/test:unit_test",
"//sxt/execution/schedule:scheduler",
"//sxt/memory/resource:managed_device_resource",
],
deps = [
"//sxt/algorithm/iteration:for_each",
"//sxt/base/container:span",
"//sxt/base/container:span_utility",
"//sxt/base/curve:element",
"//sxt/base/device:memory_utility",
"//sxt/base/error:assert",
"//sxt/base/macro:cuda_callable",
"//sxt/base/type:raw_stream",
"//sxt/execution/async:coroutine",
"//sxt/execution/device:synchronization",
"//sxt/memory/management:managed_array",
"//sxt/memory/resource:async_device_resource",
],
)

sxt_cc_component(
name = "combine_reduce",
test_deps = [
Expand Down Expand Up @@ -225,7 +199,6 @@ sxt_cc_component(
"//sxt/ristretto/random:element",
],
deps = [
":combination",
":combine_reduce",
":partition_product",
":partition_table_accessor",
Expand Down Expand Up @@ -260,7 +233,6 @@ sxt_cc_component(
"//sxt/ristretto/random:element",
],
deps = [
":combination",
":combine_reduce",
":partition_table_accessor",
":reduce",
Expand Down
17 changes: 0 additions & 17 deletions sxt/multiexp/pippenger2/combination.cc

This file was deleted.

74 changes: 0 additions & 74 deletions sxt/multiexp/pippenger2/combination.h

This file was deleted.

68 changes: 0 additions & 68 deletions sxt/multiexp/pippenger2/combination.t.cc

This file was deleted.

131 changes: 57 additions & 74 deletions sxt/multiexp/pippenger2/multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,72 +36,12 @@
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/combination.h"
#include "sxt/multiexp/pippenger2/combine_reduce.h"
#include "sxt/multiexp/pippenger2/partition_product.h"
#include "sxt/multiexp/pippenger2/partition_table_accessor.h"
#include "sxt/multiexp/pippenger2/reduce.h"

namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
// multiexponentiate_product_step
//--------------------------------------------------------------------------------------------------
// Computes `products.size()` bitwise multiexponentiation products for `scalars`.
//
// The generators are split into chunks (multiples of the accessor's window width) according to
// `split_options`. With a single chunk the products are written straight into `products`.
// Otherwise every chunk's partial products are staged in a buffer allocated from the pinned
// resource, copied to the device on `reduction_stream`, and merged into `products` via `combine`.
// The returned future completes once `reduction_stream` has been awaited.
template <bascrv::element T, class U>
  requires std::constructible_from<T, U>
xena::future<>
multiexponentiate_product_step(basct::span<T> products, basdv::stream& reduction_stream,
                               const partition_table_accessor<U>& accessor,
                               unsigned num_output_bytes, basct::cspan<uint8_t> scalars,
                               const basit::split_options& split_options) noexcept {
  const auto product_count = products.size();
  const auto num_generators = scalars.size() / num_output_bytes;

  // Chunk by groups of generators so that one chunk handles every output for its generators;
  // this keeps host->device copies of the precomputed sum table to a minimum.
  auto [chunk_first, chunk_last] = basit::split(
      basit::index_range{0, num_generators}.chunk_multiple(accessor.window_width()),
      split_options);
  const auto chunk_count = static_cast<size_t>(std::distance(chunk_first, chunk_last));
  basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
             product_count, num_generators, chunk_count);

  // Single chunk: nothing to combine, compute directly into the caller's buffer.
  if (chunk_count == 1) {
    co_await async_partition_product<T>(products, accessor, scalars, 0);
    co_return;
  }

  // Multiple chunks: each chunk writes its partial products into its own slot of the staging
  // buffer.
  memmg::managed_array<T> staged_partials{product_count * chunk_count,
                                          memr::get_pinned_resource()};
  size_t next_slot = 0;
  co_await xendv::concurrent_for_each(
      chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
        basl::info("computing {} multiproducts for generators [{}, {}] on device {}",
                   product_count, rng.a(), rng.b(), basdv::get_device());
        memmg::managed_array<T> chunk_products_dev{product_count, memr::get_device_resource()};
        auto chunk_scalars =
            scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes);
        co_await async_partition_product<T>(chunk_products_dev, accessor, chunk_scalars, rng.a());

        // Claim a slot only after the product is ready; there is no suspension point between
        // reading the counter and bumping it.
        basdv::stream copy_stream;
        const auto slot = next_slot++;
        basdv::async_copy_device_to_host(
            basct::subspan(staged_partials, product_count * slot, product_count),
            chunk_products_dev, copy_stream);
        co_await xendv::await_stream(copy_stream);
      });

  // Merge the per-chunk partial products on the reduction stream.
  basl::info("combining {} partial product chunks", chunk_count);
  memr::async_device_resource reduction_resource{reduction_stream};
  memmg::managed_array<T> staged_partials_dev{staged_partials.size(), &reduction_resource};
  basdv::async_copy_host_to_device(staged_partials_dev, staged_partials, reduction_stream);
  combine<T>(products, reduction_stream, staged_partials_dev);
  co_await xendv::await_stream(reduction_stream);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_impl_single_chunk
//--------------------------------------------------------------------------------------------------
Expand All @@ -116,6 +56,18 @@ multiexponentiate_impl_single_chunk(basct::span<T> res, const partition_table_ac
co_await combine_reduce<T>(res, element_num_bytes, partial_products_dev);
}

// Single-chunk path for outputs described by `output_bit_table`.
//
// Computes `num_products` partition products for `scalars` (starting at generator offset 0) into
// a buffer allocated from the device resource, then reduces them into `res` with
// `combine_reduce`. The parameter `n` is accepted but not used by this overload.
template <bascrv::element T, class U>
  requires std::constructible_from<T, U>
xena::future<> multiexponentiate_impl_single_chunk(basct::span<T> res,
                                                   const partition_table_accessor<U>& accessor,
                                                   basct::cspan<unsigned> output_bit_table,
                                                   basct::cspan<uint8_t> scalars, unsigned n,
                                                   unsigned num_products) noexcept {
  auto device_alloc = memr::get_device_resource();
  memmg::managed_array<T> products_dev{num_products, device_alloc};

  // Products first, then the reduction that consumes them.
  co_await async_partition_product<T>(products_dev, accessor, scalars, 0);
  co_await combine_reduce<T>(res, output_bit_table, products_dev);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_impl
//--------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -178,6 +130,7 @@ xena::future<> multiexponentiate_impl(basct::span<T> res,
// combine the partial products
basl::info("combining {} partial product chunks", num_chunks);
co_await combine_reduce<T>(res, element_num_bytes, partial_products);
basl::info("complete multiexponentiation");
}

template <bascrv::element T, class U>
Expand All @@ -194,22 +147,52 @@ multiexponentiate_impl(basct::span<T> res, const partition_table_accessor<U>& ac
scalars.size() % num_output_bytes == 0
// clang-format on
);
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> products{num_products, &resource};
co_await multiexponentiate_product_step<T>(products, stream, accessor, num_output_bytes, scalars,
split_options);
if (num_outputs == 0) {
co_return;
}
auto n = scalars.size() / num_output_bytes;
auto window_width = accessor.window_width();

// reduce products
basl::info("reducing {} products to {} outputs", num_products, num_products);
memmg::managed_array<T> res_dev{num_outputs, &resource};
reduce_products<T>(res_dev, stream, output_bit_table, products);
products.reset();
basl::info("completed {} reductions", num_outputs);
// compute bitwise products
//
// We split the work by groups of generators so that a single chunk will process
// all the outputs for those generators. This minimizes the amount of host->device
// copying we need to do for the table of precomputed sums.
auto [chunk_first, chunk_last] =
basit::split(basit::index_range{0, n}.chunk_multiple(window_width), split_options);
auto num_chunks = static_cast<size_t>(std::distance(chunk_first, chunk_last));
basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
num_products, n, num_chunks);

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
// handle special case of a single chunk
if (num_chunks == 1) {
co_return co_await multiexponentiate_impl_single_chunk(res, accessor, output_bit_table, scalars,
n, num_products);
}

// handle multiple chunks
memmg::managed_array<T> partial_products(num_products * num_chunks);
size_t chunk_index = 0;
// Compute every chunk's partial products and stage them in `partial_products`.
//
// For each generator range the lambda computes `num_products` partition products into a buffer
// from the device resource, then copies them into that chunk's slot of `partial_products` and
// waits for the copy stream. (A stray leading `co_return;` previously made the lambda exit
// before doing any of this work, so `partial_products` was never filled.)
co_await xendv::concurrent_for_each(
    chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
      basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products,
                 rng.a(), rng.b(), basdv::get_device());
      memmg::managed_array<T> partial_products_dev{num_products, memr::get_device_resource()};
      // Slice of the scalars belonging to this chunk's generators.
      auto scalars_slice =
          scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes);
      co_await async_partition_product<T>(partial_products_dev, accessor, scalars_slice, rng.a());
      basdv::stream stream;
      // `chunk_index` is read and incremented with no suspension point in between, so each
      // chunk receives a distinct slot.
      basdv::async_copy_device_to_host(
          basct::subspan(partial_products, num_products * chunk_index, num_products),
          partial_products_dev, stream);
      ++chunk_index;
      co_await xendv::await_stream(stream);
    });

// combine the partial products
basl::info("combining {} partial product chunks", num_chunks);
co_await combine_reduce<T>(res, output_bit_table, partial_products);
basl::info("complete multiexponentiation");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/combination.h"
#include "sxt/multiexp/pippenger2/combine_reduce.h"
#include "sxt/multiexp/pippenger2/partition_product.h"
#include "sxt/multiexp/pippenger2/partition_table_accessor.h"
Expand Down Expand Up @@ -158,6 +157,7 @@ xena::future<> multiexponentiate_impl(basct::span<T> res, const basit::split_opt
// combine the partial products
basl::info("combining {} partial product chunks", num_chunks);
co_await combine_reduce<T>(res, output_bit_table, partial_products);
basl::info("complete multiexponentiation");
}

//--------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 736d4b4

Please sign in to comment.