From f1538d08a43f96223b180aa8ad9af33c233653aa Mon Sep 17 00:00:00 2001 From: WangXi Date: Sat, 13 Jul 2019 23:00:30 +0800 Subject: [PATCH] Add bit redop test --- README.md | 2 +- src/common.cu | 54 +++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 7a4bbbc..faca9ac 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ All tests support the same set of arguments : * `-i,--stepbytes ` fixed increment between sizes. Default : (max-min)/10. * `-f,--stepfactor ` multiplication factor between sizes. Default : disabled. * NCCL operations arguments - * `-o,--op ` Specify which reduction operation to perform. Only relevant for reduction operations like Allreduce, Reduce or ReduceScatter. Default : Sum. + * `-o,--op ` Specify which reduction operation to perform. Only relevant for reduction operations like Allreduce, Reduce or ReduceScatter. Default : Sum. * `-d,--datatype ` Specify which datatype to use. Default : Float. * `-r,--root ` Specify which root to use. Only for operations with a root like broadcast or reduce. Default : 0. * Performance diff --git a/src/common.cu b/src/common.cu index 5a3ae52..144d064 100644 --- a/src/common.cu +++ b/src/common.cu @@ -18,8 +18,8 @@ const char *test_typenames[ncclNumTypes] = {"int8", "uint8", "int32", "uint32", ncclDataType_t test_types[ncclNumTypes] = {ncclChar, ncclInt, ncclHalf, ncclFloat, ncclDouble, ncclInt64, ncclUint64}; const char *test_typenames[ncclNumTypes] = {"char", "int", "half", "float", "double", "int64", "uint64"}; #endif -ncclRedOp_t test_ops[ncclNumOps] = {ncclSum, ncclProd, ncclMax, ncclMin}; -const char *test_opnames[ncclNumOps] = {"sum", "prod", "max", "min"}; +ncclRedOp_t test_ops[ncclNumOps] = {ncclSum, ncclProd, ncclMax, ncclMin, ncclBitAnd, ncclBitOr, ncclBitXor}; +const char *test_opnames[ncclNumOps] = {"sum", "prod", "max", "min", "band", "bor", "bxor"}; thread_local int is_main_thread = 0; @@ -184,6 +184,12 @@ template __device__ T ncclOpMax(T a, T b) { return a>b ? a : b; } template __device__ T ncclOpMin(T a, T b) { return a +__device__ T ncclOpBitAnd(T a, T b) { return a&b; } +template +__device__ T ncclOpBitOr(T a, T b) { return a|b; } +template +__device__ T ncclOpBitXor(T a, T b) { return a^b; } // Definitions for half template<> @@ -195,6 +201,45 @@ __device__ half ncclOpMax(half a, half b) { return __half2float(a)>__half2float( template<> __device__ half ncclOpMin(half a, half b) { return __half2float(a)<__half2float(b) ? a : b; } +// Definitions for bit op with floating number +template +union bitConverter; + +template<> +union bitConverter { + half storage; + int16_t a; +}; +template<> +union bitConverter { + float storage; + int a; +}; +template<> +union bitConverter { + double storage; + int64_t a; +}; + +#define BIT_OPS(dtype, name, op) \ +template<> \ +__device__ dtype ncclOpBit##name(dtype a, dtype b) { \ + union bitConverter ca, cb, cr; \ + ca.storage = a; \ + cb.storage = b; \ + cr.a = ca.a op cb.a; \ + return cr.storage; \ +} + +#define BIT_OP_TYPE(name, op) \ + BIT_OPS(half, name, op) \ + BIT_OPS(float, name, op) \ + BIT_OPS(double, name, op) + +BIT_OP_TYPE(And, &) +BIT_OP_TYPE(Or, |) +BIT_OP_TYPE(Xor, ^) + template __global__ void InitDataReduceKernel(T* data, const size_t N, const size_t offset, const int rep, const int nranks) { for (size_t o=blockIdx.x*blockDim.x+threadIdx.x; o> -#define OPS(type) KERN(type, ncclOpSum), KERN(type, ncclOpProd), KERN(type, ncclOpMax), KERN(type, ncclOpMin) +#define OPS(type) KERN(type, ncclOpSum), KERN(type, ncclOpProd), KERN(type, ncclOpMax), KERN(type, ncclOpMin), \ + KERN(type, ncclOpBitAnd), KERN(type, ncclOpBitOr), KERN(type, ncclOpBitXor) static void* const redInitDataKerns[ncclNumOps*ncclNumTypes] = { OPS(int8_t), OPS(uint8_t), OPS(int32_t), OPS(uint32_t), OPS(int64_t), OPS(uint64_t), OPS(half), OPS(float), OPS(double) @@ -658,7 +704,7 @@ int main(int argc, char* argv[]) { "[-w,--warmup_iters ] \n\t" "[-p,--parallel_init <0/1>] \n\t" "[-c,--check <0/1>] \n\t" - "[-o,--op ] \n\t" + "[-o,--op ] \n\t" "[-d,--datatype ] \n\t" "[-r,--root ] \n\t" "[-z,--blocking <0/1>] \n\t"