Skip to content

Commit

Permalink
Reimplementation of block_mat_transpose() that uses fewer registers
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Jan 2, 2025
1 parent b8d6b35 commit c01b725
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 2 deletions.
64 changes: 64 additions & 0 deletions include/rxmesh/kernels/rxmesh_queries.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,70 @@
namespace rxmesh {
namespace detail {

/**
 * @brief Block-cooperative transpose of a fixed-width incidence matrix using
 * caller-provided scratch buffers (counting-sort style: count, scan, scatter).
 *
 * `mat` is read as `num_rows * rowOffset` 16-bit column ids, where entry `i`
 * belongs to row `i / rowOffset`. On completion:
 *   - `output[k]` holds row ids grouped by column id (the order of rows inside
 *     one column is not deterministic because slots are claimed atomically),
 *   - `mat[0 .. num_cols]` is overwritten with the contents of `temp_size`,
 *     i.e. the per-column start offsets produced by the exclusive scan.
 *
 * Preconditions:
 *   - must be called by every thread of the block (contains `__syncthreads()`),
 *   - `mat` must be 4-byte aligned because it is read through a `uint32_t*`,
 *   - every column id stored in `mat` must be `< num_cols`,
 *   - offsets and row ids are stored as `uint16_t`, so `nnz` has to fit in
 *     16 bits for the result to be meaningful.
 *
 * No barrier is issued after the final write-back to `mat`; callers must
 * synchronize before other threads read `mat`.
 *
 * @tparam rowOffset    number of entries stored per row of `mat`
 * @tparam blockThreads number of threads in the block (used as loop stride)
 * @param num_rows        number of rows in `mat`
 * @param num_cols        number of distinct column ids
 * @param mat             in: column ids; out: first `num_cols + 1` entries
 *                        receive the offsets
 * @param output          out: row ids grouped by column, `nnz` entries written
 * @param temp_size       scratch, `num_cols + 1` entries (counts, then offsets)
 * @param temp_local      scratch, `num_cols` entries (per-column write cursor)
 * @param row_active_mask not referenced by this implementation
 *                        NOTE(review): presumably meant to skip inactive rows
 *                        as in the sibling overload - confirm
 * @param shift           not referenced by this implementation
 */
template <uint32_t rowOffset, uint32_t blockThreads>
__device__ __forceinline__ void block_mat_transpose(
    const uint32_t  num_rows,
    const uint32_t  num_cols,
    uint16_t*       mat,
    uint16_t*       output,
    uint16_t*       temp_size,   // size = num_cols +1
    uint16_t*       temp_local,  // size = num_cols
    const uint32_t* row_active_mask,
    int             shift)
{
    const uint32_t nnz = num_rows * rowOffset;

    // Only complete pairs are loaded as 32-bit words. Rounding the pair count
    // up would make the last load touch two bytes past the end of `mat` when
    // nnz is odd.
    const uint32_t num_pairs = nnz / 2;

    const uint32_t* mat_32 = reinterpret_cast<const uint32_t*>(mat);

    fill_n<blockThreads>(temp_size, num_cols + 1, uint16_t(0));
    fill_n<blockThreads>(temp_local, num_cols, uint16_t(0));
    __syncthreads();

    // Pass 1: histogram of column ids, two ids per 32-bit load.
    for (uint32_t i = threadIdx.x; i < num_pairs; i += blockThreads) {
        const uint32_t c  = mat_32[i];
        const uint16_t c0 = detail::extract_low_bits<16>(c);
        const uint16_t c1 = detail::extract_high_bits<16>(c);

        assert(c0 < num_cols);
        assert(c1 < num_cols);

        atomicAdd(temp_size + c0, 1u);
        atomicAdd(temp_size + c1, 1u);
    }

    // Odd tail: the last entry has no partner, so read it with a 16-bit load.
    if ((nnz & 1u) != 0 && threadIdx.x == 0) {
        const uint16_t c_last = mat[nnz - 1];

        assert(c_last < num_cols);

        atomicAdd(temp_size + c_last, 1u);
    }
    __syncthreads();

    // Turn per-column counts into start offsets.
    // NOTE(review): the assert in pass 2 reads temp_size[num_cols], so the
    // helper presumably stores the total there - confirm.
    cub_block_exclusive_sum<uint16_t, blockThreads>(temp_size, num_cols);

    // The helper's trailing synchronization is not visible from here; make
    // sure every thread observes the final offsets before reading them.
    __syncthreads();

    // Pass 2: scatter each entry's row id into its column segment.
    for (uint32_t i = threadIdx.x; i < nnz; i += blockThreads) {
        const uint16_t col_id = mat[i];

        assert(col_id < num_cols);

        // Claim the next free slot inside this column's segment.
        const uint16_t local_id = atomicAdd(temp_local + col_id, 1u);

        const uint16_t prefix = temp_size[col_id];

        assert(local_id < temp_size[col_id + 1] - temp_size[col_id]);

        // Divide first, then narrow, so the index itself is never truncated.
        const uint16_t row_id = uint16_t(i / rowOffset);

        output[local_id + prefix] = row_id;
    }

    // All reads of `mat` must finish before it is overwritten below.
    __syncthreads();

    for (uint32_t i = threadIdx.x; i < num_cols + 1; i += blockThreads) {
        mat[i] = temp_size[i];
    }
}


template <uint32_t rowOffset,
uint32_t blockThreads,
int itemPerThread = TRANSPOSE_ITEM_PER_THREAD>
Expand Down
36 changes: 34 additions & 2 deletions tests/RXMesh_test/test_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "rxmesh/kernels/collective.cuh"
#include "rxmesh/kernels/rxmesh_queries.cuh"
#include "rxmesh/kernels/shmem_allocator.cuh"
#include "rxmesh/kernels/util.cuh"
#include "rxmesh/util/macros.h"
#include "rxmesh/util/util.h"
Expand All @@ -20,6 +21,31 @@ __global__ static void test_block_mat_transpose_kernel(uint16_t* d_src,
num_rows, num_cols, d_src, d_output, d_row_bitmask, 0);
}

/**
 * @brief Test kernel that drives the scratch-buffer overload of
 * rxmesh::detail::block_mat_transpose.
 *
 * Two scratch arrays are taken from a ShmemAllocator: one with
 * `num_cols + 1` entries and one with `num_cols` entries, in that order.
 * They are forwarded together with the source matrix, the output buffer and
 * the row bitmask; the shift argument is fixed to 0.
 *
 * @tparam rowOffset    forwarded to block_mat_transpose
 * @tparam blockThreads forwarded to block_mat_transpose
 * @param d_src         matrix to transpose
 * @param num_rows      number of rows in d_src
 * @param num_cols      number of columns; also sizes the scratch arrays
 * @param d_output      buffer receiving the transposed entries
 * @param d_row_bitmask row mask forwarded unchanged
 */
template <uint32_t rowOffset, uint32_t blockThreads>
__global__ static void test_block_mat_transpose_kernel_shmem(
    uint16_t*      d_src,
    const uint32_t num_rows,
    const uint32_t num_cols,
    uint16_t*      d_output,
    uint32_t*      d_row_bitmask)
{
    using namespace rxmesh;

    ShmemAllocator allocator;

    // Allocation order matters: the (num_cols + 1)-entry buffer comes first.
    uint16_t* const s_col_offsets = allocator.alloc<uint16_t>(num_cols + 1);
    uint16_t* const s_col_cursor  = allocator.alloc<uint16_t>(num_cols);

    rxmesh::detail::block_mat_transpose<rowOffset, blockThreads>(
        num_rows,
        num_cols,
        d_src,
        d_output,
        s_col_offsets,
        s_col_cursor,
        d_row_bitmask,
        0);
}

template <typename T, uint32_t blockThreads>
__global__ static void test_block_exclusive_sum_kernel(T* d_src,
const uint32_t size)
Expand Down Expand Up @@ -250,8 +276,14 @@ TEST(Util, BlockMatrixTranspose)
cudaMemcpyHostToDevice));


test_block_mat_transpose_kernel<rowOffset, threads, item_per_thread>
<<<blocks, threads, numRows * rowOffset * sizeof(uint32_t)>>>(
// test_block_mat_transpose_kernel<rowOffset, threads, item_per_thread>
// <<<blocks, threads, 0>>>(d_src, numRows, numCols, d_offset,
// d_bitmask);

const size_t shmem = 2 * (numCols + 1) * sizeof(uint16_t) +
2 * ShmemAllocator::default_alignment;
test_block_mat_transpose_kernel_shmem<rowOffset, threads>
<<<blocks, threads, shmem>>>(
d_src, numRows, numCols, d_offset, d_bitmask);

CUDA_ERROR(cudaDeviceSynchronize());
Expand Down

0 comments on commit c01b725

Please sign in to comment.