Skip to content

Commit

Permalink
use the new block_mat_transpose() implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahdhn committed Jan 3, 2025
1 parent c01b725 commit 4b58217
Show file tree
Hide file tree
Showing 11 changed files with 379 additions and 255 deletions.
4 changes: 2 additions & 2 deletions apps/SECHistogram/sec_kernels.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
#include "rxmesh/cavity_manager.cuh"
#include "rxmesh/cavity_manager2.cuh"

#include "link_condition.cuh"

Expand All @@ -12,7 +12,7 @@ __global__ static void sec(rxmesh::Context context,
using namespace rxmesh;
auto block = cooperative_groups::this_thread_block();
ShmemAllocator shrd_alloc;
CavityManager<blockThreads, CavityOp::EV> cavity(
CavityManager2<blockThreads, CavityOp::EV> cavity(
block, context, shrd_alloc, true);

const uint32_t pid = cavity.patch_id();
Expand Down
2 changes: 1 addition & 1 deletion apps/SECPriority/secp_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ __global__ static void secp(rxmesh::Context context,
auto block = cooperative_groups::this_thread_block();
ShmemAllocator shrd_alloc;
CavityManager2<blockThreads, CavityOp::EV> cavity(
block, context, shrd_alloc, 0, true);
block, context, shrd_alloc, true);

const uint32_t pid = cavity.patch_id();

Expand Down
1 change: 0 additions & 1 deletion include/rxmesh/cavity_manager2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ struct CavityManager2
CavityManager2(cooperative_groups::thread_block& block,
Context& context,
ShmemAllocator& shrd_alloc,
int iteration,
bool preserve_cavity,
bool allow_touching_cavities = true,
uint32_t current_p = 0);
Expand Down
1 change: 0 additions & 1 deletion include/rxmesh/cavity_manager_impl2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ __device__ __forceinline__ CavityManager2<blockThreads, cop>::CavityManager2(
cooperative_groups::thread_block& block,
Context& context,
ShmemAllocator& shrd_alloc,
int iteration,
bool preserve_cavity,
bool allow_touching_cavities,
uint32_t current_p)
Expand Down
54 changes: 31 additions & 23 deletions include/rxmesh/iterator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct Iterator
{
using LocalT = typename HandleT::LocalT;

__device__ Iterator()
__device__ __inline__ Iterator()
: m_context(Context()),
m_local_id(INVALID16),
m_patch_output(nullptr),
Expand All @@ -28,9 +28,9 @@ struct Iterator
{
}

__device__ Iterator(const Context& context,
const uint16_t local_id,
const uint32_t patch_id)
__device__ __inline__ Iterator(const Context& context,
const uint16_t local_id,
const uint32_t patch_id)
: m_context(context),
m_local_id(local_id),
m_patch_output(nullptr),
Expand All @@ -46,17 +46,17 @@ struct Iterator
{
}

__device__ Iterator(const Context& context,
const uint16_t local_id,
const LocalT* patch_output,
const uint16_t* patch_offset,
const uint32_t offset_size,
const uint32_t patch_id,
const uint32_t* output_owned_bitmask,
const LPHashTable& output_lp_hashtable,
const LPPair* s_table,
const PatchStash patch_stash,
int shift = 0)
__device__ __inline__ Iterator(const Context& context,
const uint16_t local_id,
const LocalT* patch_output,
const uint16_t* patch_offset,
const uint32_t offset_size,
const uint32_t patch_id,
const uint32_t* output_owned_bitmask,
const LPHashTable& output_lp_hashtable,
const LPPair* s_table,
const PatchStash patch_stash,
int shift = 0)
: m_context(context),
m_local_id(local_id),
m_patch_output(patch_output),
Expand All @@ -73,12 +73,12 @@ struct Iterator
Iterator(const Iterator& orig) = default;


__device__ uint16_t size() const
__device__ __inline__ uint16_t size() const
{
return m_end - m_begin;
}

__device__ HandleT operator[](const uint16_t i) const
__device__ __inline__ HandleT operator[](const uint16_t i) const
{
if (i + m_begin >= m_end) {
return HandleT();
Expand All @@ -89,16 +89,24 @@ struct Iterator
if (lid == INVALID16) {
return HandleT();
}
HandleT ret(m_patch_id, lid);

if (detail::is_owned(lid, m_output_owned_bitmask)) {
HandleT ret(m_patch_id, lid);
return ret;
} else {
return m_context.get_owner_handle(ret, nullptr, m_s_table);
assert(m_s_table);
LPPair lp = m_output_lp_hashtable.find(lid, m_s_table);
if (lp.is_sentinel()) {
return HandleT();
}
return HandleT(m_patch_stash.get_patch(lp),
{lp.local_id_in_owner_patch()});

// return m_context.get_owner_handle(ret, nullptr, m_s_table);
}
}

__device__ uint16_t local(const uint16_t i) const
__device__ __inline__ uint16_t local(const uint16_t i) const
{
if (i + m_begin >= m_end) {
return INVALID16;
Expand All @@ -109,15 +117,15 @@ struct Iterator
return lid;
}

__device__ HandleT back() const
__device__ __inline__ HandleT back() const
{
return ((*this)[size() - 1]);
}

__device__ HandleT front() const
__device__ __inline__ HandleT front() const
{
return ((*this)[0]);
}
}


private:
Expand Down
73 changes: 38 additions & 35 deletions include/rxmesh/kernels/query_dispatcher.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -105,44 +105,43 @@ __device__ __inline__ void query_block_dispatcher(
output_lp_hashtable = patch_info.lp_f;
}


// load table async
auto alloc_then_load_table = [&](bool with_wait) {
s_table = shrd_alloc.template alloc<LPPair>(
output_lp_hashtable.get_capacity());
output_lp_hashtable.load_in_shared_memory(s_table, with_wait);
};

if constexpr (op != Op::FV && op != Op::VV && op != Op::FF &&
op != Op::EVDiamond) {
if (op != Op::EV && oriented) {
alloc_then_load_table(false);
}
}
//if constexpr (op != Op::FV && op != Op::VV && op != Op::FF &&
// op != Op::EVDiamond) {
// if (op != Op::EV && oriented) {
// alloc_then_load_table(false);
// }
//}


// we cache the result of (is_active && is_owned && is_compute_set) in
// shared memory to check on it later
bool is_participant = false;
block_loop<uint16_t,
blockThreads,
true>(num_src_in_patch, [&](const uint16_t local_id) {
bool is_par = false;
if (local_id < num_src_in_patch) {
bool is_del = is_deleted(local_id, input_active_mask);
bool is_own =
allow_not_owned || is_owned(local_id, input_owned_mask);
bool is_act = compute_active_set({patch_info.patch_id, local_id});
is_par = !is_del && is_own && is_act;
}
is_participant = is_participant || is_par;
uint32_t warp_mask = __ballot_sync(0xFFFFFFFF, is_par);
uint32_t lane_id = threadIdx.x % 32;
if (lane_id == 0) {
uint32_t mask_id = local_id / 32;
s_participant_bitmask[mask_id] = warp_mask;
}
});
block_loop<uint16_t, blockThreads, true>(
num_src_in_patch, [&](const uint16_t local_id) {
bool is_par = false;
if (local_id < num_src_in_patch) {
bool is_del = is_deleted(local_id, input_active_mask);
bool is_own =
allow_not_owned || is_owned(local_id, input_owned_mask);
bool is_act =
compute_active_set({patch_info.patch_id, local_id});
is_par = !is_del && is_own && is_act;
}
is_participant = is_participant || is_par;
uint32_t warp_mask = __ballot_sync(0xFFFFFFFF, is_par);
uint32_t lane_id = threadIdx.x % 32;
if (lane_id == 0) {
uint32_t mask_id = local_id / 32;
s_participant_bitmask[mask_id] = warp_mask;
}
});


if (__syncthreads_or(is_participant) == 0) {
Expand All @@ -161,15 +160,19 @@ __device__ __inline__ void query_block_dispatcher(
s_output_value,
oriented);

if constexpr (op == Op::FV || op == Op::VV || op == Op::FF ||
op == Op::EVDiamond) {
block.sync();
alloc_then_load_table(true);
}
if (op == Op::EV && oriented) {
block.sync();
alloc_then_load_table(true);
}
block.sync();
alloc_then_load_table(true);

//if constexpr (op == Op::FV || op == Op::VV || op == Op::FF ||
// op == Op::EVDiamond || op == Op::VE || op == Op::VF) {
// block.sync();
// alloc_then_load_table(true);
//}
//if (op == Op::EV && oriented) {
// printf("\n YESS \n");
// block.sync();
// alloc_then_load_table(true);
//}
block.sync();
}

Expand Down
Loading

0 comments on commit 4b58217

Please sign in to comment.