Skip to content

Commit

Permalink
feat: chunk combination-reduction step of variable length multi-exponentiations (PROOF-923) (#208)
Browse files Browse the repository at this point in the history

* rework mx

* rework mx

* rework mx

* rework mx

* drop cruft

* reformat

* drop print statement

* rename
  • Loading branch information
rnburn authored Jan 6, 2025
1 parent 3bc8138 commit 1e54f6b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 66 deletions.
1 change: 1 addition & 0 deletions sxt/multiexp/pippenger2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ sxt_cc_component(
],
deps = [
":combination",
":combine_reduce",
":partition_table_accessor",
":reduce",
":variable_length_computation",
Expand Down
103 changes: 39 additions & 64 deletions sxt/multiexp/pippenger2/variable_length_multiexponentiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
#include "sxt/execution/device/for_each.h"
#include "sxt/execution/device/synchronization.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/async_device_resource.h"
#include "sxt/memory/resource/pinned_resource.h"
#include "sxt/multiexp/pippenger2/combination.h"
#include "sxt/multiexp/pippenger2/combine_reduce.h"
#include "sxt/multiexp/pippenger2/partition_product.h"
#include "sxt/multiexp/pippenger2/partition_table_accessor.h"
#include "sxt/multiexp/pippenger2/reduce.h"
Expand Down Expand Up @@ -80,19 +82,40 @@ async_partition_product_chunk(basct::span<T> products, const partition_table_acc
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_product_step
// multiexponentiate_impl_single_chunk
//--------------------------------------------------------------------------------------------------
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<>
multiexponentiate_product_step(basct::span<T> products, basdv::stream& reduction_stream,
const partition_table_accessor<U>& accessor,
unsigned num_output_bytes, basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths, basct::cspan<uint8_t> scalars,
const basit::split_options& split_options) noexcept {
auto num_products = products.size();
auto n = scalars.size() / num_output_bytes;
xena::future<> multiexponentiate_impl_single_chunk(basct::span<T> res,
const partition_table_accessor<U>& accessor,
basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths,
basct::cspan<uint8_t> scalars, unsigned n,
unsigned num_products) noexcept {
memmg::managed_array<T> partial_products{num_products, memr::get_device_resource()};
co_await async_partition_product_chunk<T>(partial_products, accessor, output_bit_table,
output_lengths, scalars, 0, n);
co_await combine_reduce<T>(res, output_bit_table, partial_products);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiate_impl
//--------------------------------------------------------------------------------------------------
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<> multiexponentiate_impl(basct::span<T> res, const basit::split_options& split_options,
const partition_table_accessor<U>& accessor,
basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths,
basct::cspan<uint8_t> scalars) noexcept {
auto num_outputs = res.size();
if (num_outputs == 0) {
co_return;
}
auto num_products = count_products(output_bit_table);
auto window_width = accessor.window_width();
auto num_output_bytes = basn::divide_up<size_t>(num_products, 8);
auto n = scalars.size() / num_output_bytes;

// compute bitwise products
//
Expand All @@ -105,15 +128,14 @@ multiexponentiate_product_step(basct::span<T> products, basdv::stream& reduction
basl::info("computing {} bitwise multiexponentiation products of length {} using {} chunks",
num_products, n, num_chunks);

// handle no chunk case
// handle special case of a single chunk
if (num_chunks == 1) {
co_await async_partition_product_chunk(products, accessor, output_bit_table, output_lengths,
scalars, 0, n);
co_return;
co_return co_await multiexponentiate_impl_single_chunk(
res, accessor, output_bit_table, output_lengths, scalars, n, num_products);
}

// handle multiple chunks
memmg::managed_array<T> partial_products{num_products * num_chunks, memr::get_pinned_resource()};
memmg::managed_array<T> partial_products(num_products * num_chunks);
size_t chunk_index = 0;
co_await xendv::concurrent_for_each(
chunk_first, chunk_last, [&](const basit::index_range& rng) noexcept -> xena::future<> {
Expand All @@ -135,54 +157,7 @@ multiexponentiate_product_step(basct::span<T> products, basdv::stream& reduction

// combine the partial products
basl::info("combining {} partial product chunks", num_chunks);
memr::async_device_resource resource{reduction_stream};
memmg::managed_array<T> partial_products_dev{partial_products.size(), &resource};
basdv::async_copy_host_to_device(partial_products_dev, partial_products, reduction_stream);
combine<T>(products, reduction_stream, partial_products_dev);
co_await xendv::await_stream(reduction_stream);
}

//--------------------------------------------------------------------------------------------------
// multiexponentiation_impl
//--------------------------------------------------------------------------------------------------
template <bascrv::element T, class U>
requires std::constructible_from<T, U>
xena::future<>
multiexponentiate_impl(basct::span<T> res, const partition_table_accessor<U>& accessor,
basct::cspan<unsigned> output_bit_table,
basct::cspan<unsigned> output_lengths, basct::cspan<uint8_t> scalars,
const basit::split_options& split_options) noexcept {
auto num_outputs = res.size();
auto num_products = count_products(output_bit_table);
auto num_output_bytes = basn::divide_up<size_t>(num_products, 8);
if (num_outputs == 0) {
co_return;
}
SXT_DEBUG_ASSERT(
// clang-format off
scalars.size() % num_output_bytes == 0
// clang-format on
);

// compute products
basdv::stream stream;
memr::async_device_resource resource{stream};
memmg::managed_array<T> products{num_products, &resource};
co_await multiexponentiate_product_step<T>(products, stream, accessor, num_output_bytes,
output_bit_table, output_lengths, scalars,
split_options);

// reduce products
basl::info("reducing {} products to {} outputs", num_products, num_products);
memmg::managed_array<T> res_dev{num_outputs, &resource};
reduce_products<T>(res_dev, stream, output_bit_table, products);
products.reset();
basl::info("completed {} reductions", num_outputs);

// copy result
basdv::async_copy_device_to_host(res, res_dev, stream);
co_await xendv::await_stream(stream);
basl::info("complete multiexponentiation");
co_await combine_reduce<T>(res, output_bit_table, partial_products);
}

//--------------------------------------------------------------------------------------------------
Expand All @@ -207,8 +182,8 @@ xena::future<> async_multiexponentiate(basct::span<T> res,
.max_chunk_size = 1024,
.split_factor = basdv::get_num_devices(),
};
return multiexponentiate_impl(res, accessor, output_bit_table, output_lengths, scalars,
split_options);
return multiexponentiate_impl(res, split_options, accessor, output_bit_table, output_lengths,
scalars);
}

//--------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ TEST_CASE("we can compute multiexponentiations with varying lengths") {
scalars.resize(32);
scalars[0] = 1;
scalars[16] = 1;
auto fut = multiexponentiate_impl<E>(res, *accessor, output_bit_table, output_lengths, scalars,
options);
auto fut = multiexponentiate_impl<E>(res, options, *accessor, output_bit_table, output_lengths,
scalars);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == generators[0].value + generators[16].value);
Expand Down

0 comments on commit 1e54f6b

Please sign in to comment.