Skip to content

Commit

Permalink
feat: support fixed size reductions (PROOF-923) (#211)
Browse files Browse the repository at this point in the history
* support fixed size reductions

* fix offset
  • Loading branch information
rnburn authored Jan 8, 2025
1 parent 1e54f6b commit 69940af
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 17 deletions.
152 changes: 136 additions & 16 deletions sxt/multiexp/pippenger2/combine_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,31 @@
#include "sxt/memory/resource/async_device_resource.h"

namespace sxt::mtxpp2 {
//--------------------------------------------------------------------------------------------------
// combine_reduce_output
//--------------------------------------------------------------------------------------------------
// Fold the partial products of a single output into one element and store it in *res.
//
// `partials` must point one past the highest-bit partial of the output. The bits are
// consumed from most significant to least significant: before each lower bit is folded
// in, the accumulator is doubled, and for every bit the contributions of all
// `reduction_size` slices (spaced `num_partials` elements apart) are added.
template <bascrv::element T>
__device__ void combine_reduce_output(T* __restrict__ res, const T* __restrict__ partials,
                                      unsigned num_partials, unsigned reduction_size,
                                      unsigned bit_width) noexcept {
  // start at the most significant bit, which sits immediately before `partials`
  const T* cursor = partials - 1;
  T acc = *cursor;
  for (unsigned slice = 1; slice < reduction_size; ++slice) {
    // take a copy since add_inplace receives its second operand as a local in the original
    T term = cursor[slice * num_partials];
    add_inplace(acc, term);
  }

  // walk down the remaining bit_width - 1 bits
  for (unsigned remaining = bit_width - 1u; remaining > 0u; --remaining) {
    --cursor;
    double_element(acc, acc);
    for (unsigned slice = 0; slice < reduction_size; ++slice) {
      T term = cursor[slice * num_partials];
      add_inplace(acc, term);
    }
  }

  *res = acc;
}

//--------------------------------------------------------------------------------------------------
// combine_reduce_chunk_kernel
//--------------------------------------------------------------------------------------------------
Expand All @@ -56,22 +81,20 @@ __device__ void combine_reduce_chunk_kernel(T* __restrict__ res, const T* __rest
partials += bit_table_partial_sums[output_index] - partials_offset;

// combine reduce
unsigned bit_index = bit_width - 1u;
--partials;
T e = *partials;
for (unsigned reduction_index = 1; reduction_index < reduction_size; ++reduction_index) {
auto ep = partials[reduction_index * num_partials];
add_inplace(e, ep);
}
for (; bit_index-- > 0u;) {
--partials;
double_element(e, e);
for (unsigned reduction_index = 0; reduction_index < reduction_size; ++reduction_index) {
auto ep = partials[reduction_index * num_partials];
add_inplace(e, ep);
}
}
*res = e;
combine_reduce_output(res, partials, num_partials, reduction_size, bit_width);
}

// Combine and reduce the partial products of the fixed-width output at `output_index`.
//
// `res` and `partials` are the base pointers of the chunk: output i owns the partials
// [i * bit_width, (i + 1) * bit_width) of every reduction slice, where consecutive
// slices are `num_partials` elements apart.
template <bascrv::element T>
__device__ void combine_reduce_chunk_kernel(T* __restrict__ res, const T* __restrict__ partials,
                                            unsigned bit_width, unsigned num_partials,
                                            unsigned reduction_size,
                                            unsigned output_index) noexcept {
  // adjust pointers
  res += output_index;
  // combine_reduce_output expects a pointer one past the output's highest bit, so the
  // offset has to scale with output_index; a constant `bit_width` offset would make
  // every output read the partials of output 0.
  partials += bit_width * (output_index + 1u);

  // combine reduce
  combine_reduce_output(res, partials, num_partials, reduction_size, bit_width);
}

//--------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -134,6 +157,58 @@ xena::future<> combine_reduce_chunk(basct::span<T> res,
co_await xendv::await_stream(stream);
}

// Combine and reduce the partial products for a chunk of outputs that all share the
// same bit width (8 * element_num_bytes).
//
// res               receives one element per output of the chunk
// element_num_bytes width of every output in bytes; must be positive
// partial_products  `reduction_size` slices of equal length laid out back to back
// reduction_size    number of slices that are summed per bit; must be positive
// partials_offset   index within a slice of the first partial belonging to this chunk
//
// If `partial_products` is not already on the active device, the part needed by this
// chunk is copied over first; otherwise it must already contain exactly the chunk's
// partials. The returned future completes once the results were copied back into `res`.
template <bascrv::element T>
xena::future<> combine_reduce_chunk(basct::span<T> res, unsigned element_num_bytes,
                                    basct::cspan<T> partial_products, unsigned reduction_size,
                                    unsigned partials_offset) noexcept {
  // Validate the divisor and the bit width up front: a zero reduction_size would be an
  // integer division by zero below, and a zero bit width would make the kernel read
  // before the start of the partials.
  SXT_RELEASE_ASSERT(
      // clang-format off
      element_num_bytes > 0 &&
      reduction_size > 0
      // clang-format on
  );
  auto num_partials = partial_products.size() / reduction_size;
  auto num_outputs = res.size();
  auto bit_width = 8u * element_num_bytes;
  auto slice_num_partials = num_outputs * bit_width;
  SXT_RELEASE_ASSERT(
      // clang-format off
      num_outputs > 0 &&
      partial_products.size() == num_partials * reduction_size &&
      partials_offset < num_partials
      // clang-format on
  );
  basdv::stream stream;

  // copy data
  memr::async_device_resource resource{stream};
  memmg::managed_array<T> partials_dev_data{&resource};
  basct::cspan<T> partials_dev = partial_products;
  if (!basdv::is_active_device_pointer(partials_dev.data())) {
    // NOTE(review): presumably copies, from each slice of length num_partials, the
    // slice_num_partials entries starting at partials_offset -- confirm against
    // strided_copy_host_to_device.
    partials_dev_data.resize(slice_num_partials * reduction_size);
    co_await xendv::strided_copy_host_to_device<T>(partials_dev_data, stream, partial_products,
                                                   num_partials, slice_num_partials,
                                                   partials_offset);
    partials_dev = partials_dev_data;
  } else {
    // device-resident input is used in place, so it must hold exactly this chunk's partials
    SXT_RELEASE_ASSERT(partial_products.size() == slice_num_partials * reduction_size);
  }

  // combine reduce chunk
  memmg::managed_array<T> res_dev{num_outputs, &resource};
  auto f = [
               // clang-format off
    num_partials = slice_num_partials,
    bit_width = bit_width,
    reduction_size = reduction_size,
    res = res_dev.data(),
    partials = partials_dev.data()
               // clang-format on
  ] __device__
           __host__(unsigned /*num_outputs*/, unsigned output_index) noexcept {
             // hand over a pointer one past the last bit of this output
             combine_reduce_output(res + output_index, partials + bit_width * (output_index + 1u),
                                   num_partials, reduction_size, bit_width);
           };
  algi::launch_for_each_kernel(stream, f, num_outputs);
  basdv::async_copy_device_to_host(res, res_dev, stream);
  co_await xendv::await_stream(stream);
}

//--------------------------------------------------------------------------------------------------
// combine_reduce
//--------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -183,6 +258,41 @@ xena::future<> combine_reduce(basct::span<T> res, const basit::split_options& sp
});
}

// Combine and reduce partial products into outputs that all have the same width of
// `element_num_bytes` bytes.
//
// The reduction size is derived from the input sizes as
//   partial_products.size() / (res.size() * 8 * element_num_bytes).
// Partials that already live on the active device are processed as a single chunk;
// otherwise the outputs are split according to `split_options` and the chunks are
// processed concurrently.
//
// An empty `res` completes immediately without touching `partial_products`.
template <bascrv::element T>
xena::future<> combine_reduce(basct::span<T> res, const basit::split_options& split_options,
                              unsigned element_num_bytes,
                              basct::cspan<T> partial_products) noexcept {
  auto num_outputs = res.size();
  auto bit_width = element_num_bytes * 8u;

  if (res.empty()) {
    co_return;
  }

  // A zero width would make the divisor below zero.
  SXT_RELEASE_ASSERT(element_num_bytes > 0);
  auto reduction_size = partial_products.size() / (num_outputs * bit_width);
  // Fewer partials than one per output bit leaves nothing to reduce and would lead to a
  // division by zero in combine_reduce_chunk.
  SXT_RELEASE_ASSERT(reduction_size > 0);

  // don't split if partials are already in device memory
  if (basdv::is_active_device_pointer(partial_products.data())) {
    co_return co_await combine_reduce_chunk(res, element_num_bytes, partial_products,
                                            reduction_size, 0);
  }

  // split
  auto [chunk_first, chunk_last] = basit::split(basit::index_range{0, num_outputs}, split_options);

  // combine reduce
  co_await xendv::concurrent_for_each(
      chunk_first, chunk_last, [&](basit::index_range rng) noexcept -> xena::future<> {
        auto output_first = rng.a();

        auto res_chunk = res.subspan(output_first, rng.size());
        // every output occupies bit_width consecutive partials within a slice
        auto partials_offset = output_first * bit_width;

        co_await combine_reduce_chunk(res_chunk, element_num_bytes, partial_products,
                                      reduction_size, partials_offset);
      });
}

template <bascrv::element T>
xena::future<> combine_reduce(basct::span<T> res, basct::cspan<unsigned> output_bit_table,
basct::cspan<T> partial_products) noexcept {
Expand All @@ -192,4 +302,14 @@ xena::future<> combine_reduce(basct::span<T> res, basct::cspan<unsigned> output_
};
co_await combine_reduce(res, split_options, output_bit_table, partial_products);
}

// Convenience overload for fixed-width outputs: chunks hold at most 1024 outputs and
// the split factor is taken from basdv::get_num_devices().
template <bascrv::element T>
xena::future<> combine_reduce(basct::span<T> res, unsigned element_num_bytes,
                              basct::cspan<T> partial_products) noexcept {
  basit::split_options default_options{
      .max_chunk_size = 1024,
      .split_factor = basdv::get_num_devices(),
  };
  co_return co_await combine_reduce(res, default_options, element_num_bytes, partial_products);
}
} // namespace sxt::mtxpp2
95 changes: 94 additions & 1 deletion sxt/multiexp/pippenger2/combine_reduce.t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,100 @@
using namespace sxt;
using namespace sxt::mtxpp2;

TEST_CASE("we can combine and reduce partial products") {
// Exercises the combine_reduce overloads that take a fixed output width
// (element_num_bytes) instead of a per-output bit table.
//
// Layout used throughout: the number of partial products equals
// reduction size * number of outputs * bit width, and within one output the partial at
// position i is weighted by 2^i.
TEST_CASE("we can combine and reduce partial products with outputs of fixed size") {
using E = bascrv::element97;

// shared fixture: every section starts from one output of one byte and no partials
unsigned element_num_bytes = 1;
std::vector<E> partial_products;
std::vector<E> res(1);

SECTION("we handle no outputs") {
// with an empty result span the future is expected to be ready without running the scheduler
res.clear();
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
REQUIRE(fut.ready());
}

SECTION("we can combine and reduce a single element") {
// 8 partials = 1 output x 8 bits; only bit 0 is set explicitly
partial_products.resize(8);
partial_products[0] = 3u;
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
}

SECTION("we can combine and reduce a two byte output") {
// 16 partials = 1 output x 16 bits; index 8 is bit 8 and therefore weighted by 2^8
partial_products.resize(16);
partial_products[0] = 3u;
partial_products[8] = 4u;
element_num_bytes = 2;
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u + (1u << 8u) * 4u);
}

SECTION("we can combine and reduce elements already on device") {
partial_products.resize(8);
partial_products[0] = 3u;
// copy the partials into a vector backed by the managed device resource;
// presumably this drives the branch for device-resident input -- TODO confirm
std::pmr::vector<E> partial_products_dev{partial_products.begin(), partial_products.end(),
memr::get_managed_device_resource()};
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products_dev);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
}

SECTION("we can combine and reduce a single output with a reduction size of two") {
// 16 partials for 1 output x 8 bits gives a reduction size of 2; indices 0 and 8 are
// bit 0 of the first and second slice, so they are simply summed
partial_products.resize(16);
partial_products[0] = 3u;
partial_products[8] = 4u;
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 7u);
}

SECTION("we can combine and reduce an output with a bit width of 8") {
// index 1 is bit 1 and contributes twice its value: 3 + 2 * 4 == 11
partial_products.resize(8);
partial_products[0] = 3u;
partial_products[1] = 4u;
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 11u);
}

SECTION("we can combine and reduce multiple outputs") {
// 16 partials = 2 outputs x 8 bits; index 8 is bit 0 of the second output
partial_products.resize(16);
partial_products[0] = 3u;
partial_products[8] = 4u;
res.resize(2);
auto fut = combine_reduce<E>(res, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
REQUIRE(res[1] == 4u);
}

SECTION("we can combine and reduce in chunks") {
// same data as the previous section, but with explicit split options;
// presumably max_chunk_size = 1 places each output in its own chunk -- TODO confirm
partial_products.resize(16);
partial_products[0] = 3u;
partial_products[8] = 4u;
res.resize(2);
basit::split_options split_options{
.max_chunk_size = 1u,
.split_factor = 2u,
};
auto fut = combine_reduce<E>(res, split_options, element_num_bytes, partial_products);
xens::get_scheduler().run();
REQUIRE(fut.ready());
REQUIRE(res[0] == 3u);
REQUIRE(res[1] == 4u);
}
}

TEST_CASE("we can combine and reduce partial products with outputs of varying size") {
using E = bascrv::element97;

std::vector<unsigned> output_bit_table;
Expand Down

0 comments on commit 69940af

Please sign in to comment.