diff --git a/sxt/multiexp/pippenger2/BUILD b/sxt/multiexp/pippenger2/BUILD index ea6cb418..1747d57a 100644 --- a/sxt/multiexp/pippenger2/BUILD +++ b/sxt/multiexp/pippenger2/BUILD @@ -14,32 +14,6 @@ sxt_cc_component( ], ) -sxt_cc_component( - name = "combination", - test_deps = [ - "//sxt/base/curve:example_element", - "//sxt/base/device:stream", - "//sxt/base/device:synchronization", - "//sxt/base/test:unit_test", - "//sxt/execution/schedule:scheduler", - "//sxt/memory/resource:managed_device_resource", - ], - deps = [ - "//sxt/algorithm/iteration:for_each", - "//sxt/base/container:span", - "//sxt/base/container:span_utility", - "//sxt/base/curve:element", - "//sxt/base/device:memory_utility", - "//sxt/base/error:assert", - "//sxt/base/macro:cuda_callable", - "//sxt/base/type:raw_stream", - "//sxt/execution/async:coroutine", - "//sxt/execution/device:synchronization", - "//sxt/memory/management:managed_array", - "//sxt/memory/resource:async_device_resource", - ], -) - sxt_cc_component( name = "combine_reduce", test_deps = [ @@ -225,7 +199,6 @@ sxt_cc_component( "//sxt/ristretto/random:element", ], deps = [ - ":combination", ":combine_reduce", ":partition_product", ":partition_table_accessor", @@ -260,7 +233,6 @@ sxt_cc_component( "//sxt/ristretto/random:element", ], deps = [ - ":combination", ":combine_reduce", ":partition_table_accessor", ":reduce", diff --git a/sxt/multiexp/pippenger2/combination.cc b/sxt/multiexp/pippenger2/combination.cc deleted file mode 100644 index a4e8e240..00000000 --- a/sxt/multiexp/pippenger2/combination.cc +++ /dev/null @@ -1,17 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/multiexp/pippenger2/combination.h" diff --git a/sxt/multiexp/pippenger2/combination.h b/sxt/multiexp/pippenger2/combination.h deleted file mode 100644 index d5b0d0da..00000000 --- a/sxt/multiexp/pippenger2/combination.h +++ /dev/null @@ -1,74 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "sxt/algorithm/iteration/for_each.h" -#include "sxt/base/container/span.h" -#include "sxt/base/container/span_utility.h" -#include "sxt/base/curve/element.h" -#include "sxt/base/device/memory_utility.h" -#include "sxt/base/error/assert.h" -#include "sxt/base/macro/cuda_callable.h" -#include "sxt/base/type/raw_stream.h" -#include "sxt/execution/async/coroutine.h" -#include "sxt/execution/device/synchronization.h" -#include "sxt/memory/management/managed_array.h" -#include "sxt/memory/resource/async_device_resource.h" - -namespace sxt::mtxpp2 { -//-------------------------------------------------------------------------------------------------- -// combine_impl -//-------------------------------------------------------------------------------------------------- -template -CUDA_CALLABLE void combine_impl(T* __restrict__ reduction, const T* __restrict__ elements, - unsigned step, unsigned reduction_size) noexcept { - T res = elements[0]; - for (unsigned i = 1; i < reduction_size; ++i) { - auto e = elements[step * i]; - add_inplace(res, e); - } - *reduction = res; -} - -//-------------------------------------------------------------------------------------------------- -// combine -//-------------------------------------------------------------------------------------------------- -template -void combine(basct::span res, bast::raw_stream_t stream, basct::cspan elements) noexcept { - auto n = static_cast(res.size()); - SXT_DEBUG_ASSERT( - // clang-format off - elements.size() >= n && - elements.size() % n == 0 && - basdv::is_active_device_pointer(res.data()) && - basdv::is_active_device_pointer(elements.data()) - // clang-format on - ); - auto reduction_size = static_cast(elements.size() / n); - auto f = [ - // clang-format off - reductions = res.data(), - elements = elements.data(), - reduction_size = reduction_size - // clang-format on - ] __device__ - __host__(unsigned n, unsigned index) noexcept { - combine_impl(reductions + index, elements + index, n, reduction_size); - }; - algi::launch_for_each_kernel(stream, f, n); -} -} // namespace sxt::mtxpp2 diff --git a/sxt/multiexp/pippenger2/combination.t.cc b/sxt/multiexp/pippenger2/combination.t.cc deleted file mode 100644 index 18a52a4e..00000000 --- a/sxt/multiexp/pippenger2/combination.t.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU. - * - * Copyright 2024-present Space and Time Labs, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "sxt/multiexp/pippenger2/combination.h" - -#include - -#include "sxt/base/curve/example_element.h" -#include "sxt/base/device/stream.h" -#include "sxt/base/device/synchronization.h" -#include "sxt/base/test/unit_test.h" -#include "sxt/execution/schedule/scheduler.h" -#include "sxt/memory/resource/managed_device_resource.h" - -using namespace sxt; -using namespace sxt::mtxpp2; - -TEST_CASE("we can combine elements") { - using E = bascrv::element97; - - std::pmr::vector reduction{1, memr::get_managed_device_resource()}; - std::pmr::vector elements{memr::get_managed_device_resource()}; - - std::pmr::vector expected; - - basdv::stream stream; - - SECTION("we can reduce a single element") { - elements = {123u}; - combine(reduction, stream, elements); - basdv::synchronize_stream(stream); - - expected = {123u}; - REQUIRE(reduction == expected); - } - - SECTION("we can reduce two elements") { - elements = {3u, 4u}; - combine(reduction, stream, elements); - basdv::synchronize_stream(stream); - - expected = {7u}; - REQUIRE(reduction == expected); - } - - SECTION("we can reduce multiple elements") { - reduction.resize(3); - elements = {3u, 4u, 1u, 2u, 6u, 5u}; - combine(reduction, stream, elements); - basdv::synchronize_stream(stream); - - expected = {3 + 2, 4 + 6, 1 + 5}; - REQUIRE(reduction == expected); - } -} diff --git a/sxt/multiexp/pippenger2/multiexponentiation.h b/sxt/multiexp/pippenger2/multiexponentiation.h index e2242c92..edd525de 100644 --- a/sxt/multiexp/pippenger2/multiexponentiation.h +++ b/sxt/multiexp/pippenger2/multiexponentiation.h @@ -36,72 +36,12 @@ #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/device_resource.h" #include "sxt/memory/resource/pinned_resource.h" -#include "sxt/multiexp/pippenger2/combination.h" #include "sxt/multiexp/pippenger2/combine_reduce.h" #include "sxt/multiexp/pippenger2/partition_product.h" #include "sxt/multiexp/pippenger2/partition_table_accessor.h" #include "sxt/multiexp/pippenger2/reduce.h" namespace sxt::mtxpp2 { -//-------------------------------------------------------------------------------------------------- -// multiexponentiate_product_step -//-------------------------------------------------------------------------------------------------- -template - requires std::constructible_from -xena::future<> -multiexponentiate_product_step(basct::span products, basdv::stream& reduction_stream, - const partition_table_accessor& accessor, - unsigned num_output_bytes, basct::cspan scalars, - const basit::split_options& split_options) noexcept { - auto num_products = products.size(); - auto n = scalars.size() / num_output_bytes; - auto window_width = accessor.window_width(); - - // compute bitwise products - // - // We split the work by groups of generators so that a single chunk will process - // all the outputs for those generators. This minimizes the amount of host->device - // copying we need to do for the table of precomputed sums. - auto [chunk_first, chunk_last] = - basit::split(basit::index_range{0, n}.chunk_multiple(window_width), split_options); - auto num_chunks = static_cast(std::distance(chunk_first, chunk_last)); - basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks", - num_products, n, num_chunks); - - // handle no chunk case - if (num_chunks == 1) { - co_await async_partition_product(products, accessor, scalars, 0); - co_return; - } - - // handle multiple chunks - memmg::managed_array partial_products{num_products * num_chunks, memr::get_pinned_resource()}; - size_t chunk_index = 0; - co_await xendv::concurrent_for_each( - chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> { - basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products, - rng.a(), rng.b(), basdv::get_device()); - memmg::managed_array partial_products_dev{num_products, memr::get_device_resource()}; - auto scalars_slice = - scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes); - co_await async_partition_product(partial_products_dev, accessor, scalars_slice, rng.a()); - basdv::stream stream; - basdv::async_copy_device_to_host( - basct::subspan(partial_products, num_products * chunk_index, num_products), - partial_products_dev, stream); - ++chunk_index; - co_await xendv::await_stream(stream); - }); - - // combine the partial products - basl::info("combining {} partial product chunks", num_chunks); - memr::async_device_resource resource{reduction_stream}; - memmg::managed_array partial_products_dev{partial_products.size(), &resource}; - basdv::async_copy_host_to_device(partial_products_dev, partial_products, reduction_stream); - combine(products, reduction_stream, partial_products_dev); - co_await xendv::await_stream(reduction_stream); -} - //-------------------------------------------------------------------------------------------------- // multiexponentiate_impl_single_chunk //-------------------------------------------------------------------------------------------------- @@ -116,6 +56,18 @@ multiexponentiate_impl_single_chunk(basct::span res, const partition_table_ac co_await combine_reduce(res, element_num_bytes, partial_products_dev); } +template + requires std::constructible_from +xena::future<> multiexponentiate_impl_single_chunk(basct::span res, + const partition_table_accessor& accessor, + basct::cspan output_bit_table, + basct::cspan scalars, unsigned n, + unsigned num_products) noexcept { + memmg::managed_array partial_products_dev{num_products, memr::get_device_resource()}; + co_await async_partition_product(partial_products_dev, accessor, scalars, 0); + co_await combine_reduce(res, output_bit_table, partial_products_dev); +} + //-------------------------------------------------------------------------------------------------- // multiexponentiate_impl //-------------------------------------------------------------------------------------------------- @@ -178,6 +130,7 @@ xena::future<> multiexponentiate_impl(basct::span res, // combine the partial products basl::info("combining {} partial product chunks", num_chunks); co_await combine_reduce(res, element_num_bytes, partial_products); + basl::info("complete multiexponentiation"); } template @@ -194,22 +147,52 @@ multiexponentiate_impl(basct::span res, const partition_table_accessor& ac scalars.size() % num_output_bytes == 0 // clang-format on ); - basdv::stream stream; - memr::async_device_resource resource{stream}; - memmg::managed_array products{num_products, &resource}; - co_await multiexponentiate_product_step(products, stream, accessor, num_output_bytes, scalars, - split_options); + if (num_outputs == 0) { + co_return; + } + auto n = scalars.size() / num_output_bytes; + auto window_width = accessor.window_width(); - // reduce products - basl::info("reducing {} products to {} outputs", num_products, num_products); - memmg::managed_array res_dev{num_outputs, &resource}; - reduce_products(res_dev, stream, output_bit_table, products); - products.reset(); - basl::info("completed {} reductions", num_outputs); + // compute bitwise products + // + // We split the work by groups of generators so that a single chunk will process + // all the outputs for those generators. This minimizes the amount of host->device + // copying we need to do for the table of precomputed sums. + auto [chunk_first, chunk_last] = + basit::split(basit::index_range{0, n}.chunk_multiple(window_width), split_options); + auto num_chunks = static_cast(std::distance(chunk_first, chunk_last)); + basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks", + num_products, n, num_chunks); - // copy result - basdv::async_copy_device_to_host(res, res_dev, stream); - co_await xendv::await_stream(stream); + // handle special case of a single chunk + if (num_chunks == 1) { + co_return co_await multiexponentiate_impl_single_chunk(res, accessor, output_bit_table, scalars, + n, num_products); + } + + // handle multiple chunks + memmg::managed_array partial_products(num_products * num_chunks); + size_t chunk_index = 0; + co_await xendv::concurrent_for_each( + chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> { + co_return; + basl::info("computing {} multiproducts for generators [{}, {}] on device {}", num_products, + rng.a(), rng.b(), basdv::get_device()); + memmg::managed_array partial_products_dev{num_products, memr::get_device_resource()}; + auto scalars_slice = + scalars.subspan(num_output_bytes * rng.a(), rng.size() * num_output_bytes); + co_await async_partition_product(partial_products_dev, accessor, scalars_slice, rng.a()); + basdv::stream stream; + basdv::async_copy_device_to_host( + basct::subspan(partial_products, num_products * chunk_index, num_products), + partial_products_dev, stream); + ++chunk_index; + co_await xendv::await_stream(stream); + }); + + // combine the partial products + basl::info("combining {} partial product chunks", num_chunks); + co_await combine_reduce(res, output_bit_table, partial_products); basl::info("complete multiexponentiation"); } diff --git a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h index 9b556a3d..2d279792 100644 --- a/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h +++ b/sxt/multiexp/pippenger2/variable_length_multiexponentiation.h @@ -30,7 +30,6 @@ #include "sxt/memory/management/managed_array.h" #include "sxt/memory/resource/async_device_resource.h" #include "sxt/memory/resource/pinned_resource.h" -#include "sxt/multiexp/pippenger2/combination.h" #include "sxt/multiexp/pippenger2/combine_reduce.h" #include "sxt/multiexp/pippenger2/partition_product.h" #include "sxt/multiexp/pippenger2/partition_table_accessor.h" @@ -158,6 +157,7 @@ xena::future<> multiexponentiate_impl(basct::span res, const basit::split_opt // combine the partial products basl::info("combining {} partial product chunks", num_chunks); co_await combine_reduce(res, output_bit_table, partial_products); + basl::info("complete multiexponentiation"); } //--------------------------------------------------------------------------------------------------