Skip to content

Commit

Permalink
more C++ code clean-up
Browse files: browse the repository at this point in the history
  • Loading branch information
ipdemes committed Nov 14, 2023
1 parent 90f27bc commit fa19be7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 49 deletions.
1 change: 0 additions & 1 deletion cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3713,7 +3713,6 @@ def sum(
where=where,
)

@add_boilerplate()
def _nansum(
self,
axis: Any = None,
Expand Down
2 changes: 1 addition & 1 deletion src/cunumeric/unary/scalar_unary_red_template.inl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct ScalarUnaryRed {
}
};

template <VariantKind KIND, UnaryRedCode OP_CODE, int HAS_WHERE>
template <VariantKind KIND, UnaryRedCode OP_CODE, bool HAS_WHERE>
struct ScalarUnaryRedImpl {
template <Type::Code CODE, int DIM>
void operator()(ScalarUnaryRedArgs& args) const
Expand Down
58 changes: 11 additions & 47 deletions src/cunumeric/unary/unary_red.cu
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,10 @@ static void __device__ __forceinline__ collapse_dims(LHS& result,
#endif
}

template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM>
template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM, bool HAS_WHERE>
static __device__ __forceinline__ Point<DIM> local_reduce(LHS& result,
AccessorRO<RHS, DIM> in,
AccessorRO<bool, DIM> where,
LHS identity,
const ThreadBlocks<DIM>& blocks,
const Rect<DIM>& domain,
Expand All @@ -280,33 +281,10 @@ static __device__ __forceinline__ Point<DIM> local_reduce(LHS& result,
Point<DIM> point = blocks.point(bid, tid, domain.lo);
if (!domain.contains(point)) return point;

bool mask = true;
if constexpr (HAS_WHERE) mask = (where[point] == true);
while (point[collapsed_dim] <= domain.hi[collapsed_dim]) {
LHS value = OP::convert(point, collapsed_dim, identity, in[point]);
REDOP::template fold<true>(result, value);
blocks.next_point(point);
}

collapse_dims<REDOP, LHS>(result, point, domain, collapsed_dim, identity, tid);
return point;
}

template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM>
static __device__ __forceinline__ Point<DIM> local_reduce_where(LHS& result,
AccessorRO<RHS, DIM> in,
AccessorRO<bool, DIM> where,
LHS identity,
const ThreadBlocks<DIM>& blocks,
const Rect<DIM>& domain,
int32_t collapsed_dim)
{
const coord_t tid = threadIdx.x;
const coord_t bid = blockIdx.x;

Point<DIM> point = blocks.point(bid, tid, domain.lo);
if (!domain.contains(point)) return point;

while (point[collapsed_dim] <= domain.hi[collapsed_dim]) {
if (where[point] == true) {
if (mask) {
LHS value = OP::convert(point, collapsed_dim, identity, in[point]);
REDOP::template fold<true>(result, value);
}
Expand All @@ -316,33 +294,19 @@ static __device__ __forceinline__ Point<DIM> local_reduce_where(LHS& result,
collapse_dims<REDOP, LHS>(result, point, domain, collapsed_dim, identity, tid);
return point;
}
template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM>

template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM, bool HAS_WHERE>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
reduce_with_rd_acc(AccessorRD<REDOP, false, DIM> out,
AccessorRO<RHS, DIM> in,
AccessorRO<bool, DIM> where,
LHS identity,
ThreadBlocks<DIM> blocks,
Rect<DIM> domain,
int32_t collapsed_dim)
{
auto result = identity;
auto point =
local_reduce<OP, REDOP, LHS, RHS, DIM>(result, in, identity, blocks, domain, collapsed_dim);
if (result != identity) out.reduce(point, result);
}

template <typename OP, typename REDOP, typename LHS, typename RHS, int32_t DIM>
static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM)
reduce_with_rd_acc_where(AccessorRD<REDOP, false, DIM> out,
AccessorRO<RHS, DIM> in,
AccessorRO<bool, DIM> where,
LHS identity,
ThreadBlocks<DIM> blocks,
Rect<DIM> domain,
int32_t collapsed_dim)
{
auto result = identity;
auto point = local_reduce_where<OP, REDOP, LHS, RHS, DIM>(
auto point = local_reduce<OP, REDOP, LHS, RHS, DIM, HAS_WHERE>(
result, in, where, identity, blocks, domain, collapsed_dim);
if (result != identity) out.reduce(point, result);
}
Expand All @@ -362,7 +326,7 @@ struct UnaryRedImplBody<VariantKind::GPU, OP_CODE, CODE, DIM, HAS_WHERE> {
int collapsed_dim,
size_t volume) const
{
auto Kernel = reduce_with_rd_acc<OP, LG_OP, LHS, RHS, DIM>;
auto Kernel = reduce_with_rd_acc<OP, LG_OP, LHS, RHS, DIM, HAS_WHERE>;
auto stream = get_cached_stream();

ThreadBlocks<DIM> blocks;
Expand All @@ -374,7 +338,7 @@ struct UnaryRedImplBody<VariantKind::GPU, OP_CODE, CODE, DIM, HAS_WHERE> {
lhs, rhs, where, LG_OP::identity, blocks, rect, collapsed_dim);
else
Kernel<<<blocks.num_blocks(), blocks.num_threads(), 0, stream>>>(
lhs, rhs, LG_OP::identity, blocks, rect, collapsed_dim);
lhs, rhs, AccessorRO<bool, DIM>(), LG_OP::identity, blocks, rect, collapsed_dim);
CHECK_CUDA_STREAM(stream);
}
};
Expand Down

0 comments on commit fa19be7

Please sign in to comment.