0.4.0 release (#27)
* Run clang-format.

* Remove extraneous TODO.

* Update binding utility scripts.

* Bump version defines.

* Update version info in docs.

* Resolve compilation warnings in autotune.cc with conditional NCCL paths (see the sketch after this list).

* Update command-line argument handling in Fortran tests from integer to logical.

* Add CMake build recommendation to README.md.
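
The autotune.cc warning fix mentioned above follows a common pattern: keep a variable inside the same preprocessor conditional as the only code that uses it, so builds that compile the path out have nothing left over to warn about. The sketch below is a minimal, generic illustration of that pattern only; `ENABLE_NCCL` and the `measure_*` functions are hypothetical names, not taken from cuDecomp's sources.

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-ins used only to illustrate the guarding pattern;
// ENABLE_NCCL and these functions are not cuDecomp identifiers.
double measure_default_backend() { return 1.0; }
#ifdef ENABLE_NCCL
double measure_nccl_backend() { return 0.5; }
#endif

std::vector<double> collect_trial_times() {
  std::vector<double> times;
  times.push_back(measure_default_backend());
#ifdef ENABLE_NCCL
  // The variable exists only when its use exists, so a build without the
  // NCCL path emits no -Wunused-variable warning.
  double nccl_time = measure_nccl_backend();
  times.push_back(nccl_time);
#endif
  return times;
}

int main() {
  std::cout << collect_trial_times().size() << " backend(s) timed\n";
  return 0;
}
```
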
romerojosh authored Mar 14, 2024
1 parent 6e17527 commit b8ffecc
Showing 16 changed files with 127 additions and 111 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
@@ -72,8 +72,6 @@ if (CRAY_CC_BIN)
message(STATUS "Found GDRCopy library: " ${GDRCOPY_LIBRARY})
endif()

-# TODO: Check for MPICH to define `-DMPICH` flag
-
# HPC SDK
find_package(NVHPC REQUIRED COMPONENTS CUDA MATH)

6 changes: 3 additions & 3 deletions README.md
@@ -14,7 +14,7 @@ Please contact us or open a GitHub issue if you are interested in using this lib

## Build

-### Method 1: Makefile with Configuration file
+### Method 1: Makefile with Configuration file (deprecated)
To build the library, you must first create a configuration file to point the install to dependent library paths and enable/disable features.
See the default [`nvhpcsdk.conf`](configs/nvhpcsdk.conf) for an example of settings to build the library using the [NVHPC SDK compilers and libraries](https://developer.nvidia.com/hpc-sdk).
The [`configs/`](configs) directory also contains several sample build configuration files for a number of GPU compute clusters, like Perlmutter, Summit, and Marconi 100.
@@ -25,9 +25,9 @@ With this configuration file created, you can build the library using the comman
$ make -j CONFIGFILE=<path to your configuration file>
```

-The library will be compiled and installed in a newly created `build/` directory.
+The library will be compiled and installed in a newly created `build/` directory. This build method is deprecated and will be removed in a future release.

-### Method 2: CMake
+### Method 2: CMake (recommended)
We also enable builds using CMake. A CMake build of the library without additional examples/tests can be completed using the following commands
```shell
$ mkdir build
4 changes: 1 addition & 3 deletions benchmark/benchmark.cu
@@ -220,9 +220,7 @@ int main(int argc, char** argv) {
cudecompGridDescAutotuneOptions_t options;
CHECK_CUDECOMP_EXIT(cudecompGridDescAutotuneOptionsSetDefaults(&options));
options.dtype = get_cudecomp_datatype(complex_t(0));
-for (int i = 0; i < 4; ++i) {
-  options.transpose_use_inplace_buffers[i] = !out_of_place;
-}
+for (int i = 0; i < 4; ++i) { options.transpose_use_inplace_buffers[i] = !out_of_place; }

if (comm_backend != 0) {
config.transpose_comm_backend = comm_backend;
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -24,7 +24,8 @@
author = 'NVIDIA Corporation'

# The full version, including alpha/beta/rc tags
-release = '2022'
+version = '0.4.0'
+release = version


# -- General configuration ---------------------------------------------------
24 changes: 12 additions & 12 deletions include/cudecomp.h
@@ -43,8 +43,8 @@
#include <mpi.h>

#define CUDECOMP_MAJOR 0
-#define CUDECOMP_MINOR 3
-#define CUDECOMP_PATCH 1
+#define CUDECOMP_MINOR 4
+#define CUDECOMP_PATCH 0

/**
* @brief This enum lists the different available transpose backend options.
@@ -140,10 +140,10 @@ typedef struct {
*/
typedef struct {
// General options
-int32_t n_warmup_trials; ///< number of warmup trials to run for each tested configuration during autotuning
-///< (default: 3)
-int32_t n_trials; ///< number of timed trials to run for each tested configuration during autotuning
-///< (default: 5)
+int32_t n_warmup_trials; ///< number of warmup trials to run for each tested configuration during autotuning
+                         ///< (default: 3)
+int32_t n_trials;        ///< number of timed trials to run for each tested configuration during autotuning
+                         ///< (default: 5)
cudecompAutotuneGridMode_t grid_mode; ///< which communication (transpose/halo) to use to autotune process grid
///< (default: CUDECOMP_AUTOTUNE_GRID_TRANSPOSE)
cudecompDataType_t dtype; ///< datatype to use during autotuning (default: CUDECOMP_DOUBLE)
@@ -152,19 +152,19 @@ typedef struct {
bool disable_nccl_backends; ///< flag to disable NCCL backend options during autotuning (default: false)
bool disable_nvshmem_backends; ///< flag to disable NVSHMEM backend options during autotuning (default: false)
double skip_threshold; ///< threshold used to skip testing slow configurations; skip configuration
-///< if `skip_threshold * t > t_best`, where `t` is the duration of the first timed trial
-///< for the configuration and `t_best` is the average trial time of the current best
-///< configuration (default: 0.0)
+///< if `skip_threshold * t > t_best`, where `t` is the duration of the first timed trial
+///< for the configuration and `t_best` is the average trial time of the current best
+///< configuration (default: 0.0)

// Transpose-specific options
bool autotune_transpose_backend; ///< flag to enable transpose backend autotuning (default: false)
bool transpose_use_inplace_buffers[4]; ///< flag to control whether transpose autotuning uses in-place or out-of-place
///< buffers during autotuning by transpose operation, considering
///< the following order: X-to-Y, Y-to-Z, Z-to-Y, Y-to-X
///< (default: [false, false, false, false])
-double transpose_op_weights[4]; ///< multiplicative weight to apply to trial time contribution by transpose operation
-///< in the following order: X-to-Y, Y-to-Z, Z-to-Y, Y-to-X
-///< (default: [1.0, 1.0, 1.0, 1.0])
+double transpose_op_weights[4]; ///< multiplicative weight to apply to trial time contribution by transpose operation
+///< in the following order: X-to-Y, Y-to-Z, Z-to-Y, Y-to-X
+///< (default: [1.0, 1.0, 1.0, 1.0])

// Halo-specific options
bool autotune_halo_backend; ///< flag to enable halo backend autotuning (default: false)
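
The `skip_threshold` documentation above boils down to a single predicate: after the first timed trial `t` of a candidate configuration, the candidate is dropped if `skip_threshold * t > t_best`, where `t_best` is the average trial time of the best configuration so far, and a threshold of 0.0 disables skipping. A standalone illustration of that rule, using our own function and variable names rather than cuDecomp code:

```cpp
#include <iostream>

// Skip a candidate configuration when its first timed trial t already looks
// too slow relative to the best average so far: skip_threshold * t > t_best.
// A skip_threshold of 0.0 never skips anything.
bool should_skip(double skip_threshold, double t, double t_best) {
  return skip_threshold * t > t_best;
}

int main() {
  const double t_best = 10.0; // ms, average trial time of current best config
  // With skip_threshold = 2.0: a 6 ms first trial is skipped (12 > 10),
  // while a 4 ms first trial is kept (8 <= 10).
  std::cout << std::boolalpha
            << should_skip(2.0, 6.0, t_best) << "\n"   // true
            << should_skip(2.0, 4.0, t_best) << "\n";  // false
  return 0;
}
```
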
9 changes: 4 additions & 5 deletions include/internal/comm_routines.h
@@ -86,11 +86,10 @@ nvshmemAlltoallV(const cudecompHandle_t& handle, const cudecompGridDesc_t& grid_
size_t send_bytes = send_counts[dst_rank] * sizeof(T);
size_t nchunks = (send_bytes + CUDECOMP_NVSHMEM_CHUNK_SZ - 1) / CUDECOMP_NVSHMEM_CHUNK_SZ;
for (size_t j = 0; j < nchunks; ++j) {
-nvshmemx_putmem_nbi_on_stream(
-    recv_buff + recv_offsets[dst_rank] + j * (CUDECOMP_NVSHMEM_CHUNK_SZ / sizeof(T)),
-    send_buff + send_offsets[dst_rank] + j * (CUDECOMP_NVSHMEM_CHUNK_SZ / sizeof(T)),
-    std::min(CUDECOMP_NVSHMEM_CHUNK_SZ, send_bytes - j * CUDECOMP_NVSHMEM_CHUNK_SZ),
-    dst_rank_global, stream);
+nvshmemx_putmem_nbi_on_stream(recv_buff + recv_offsets[dst_rank] + j * (CUDECOMP_NVSHMEM_CHUNK_SZ / sizeof(T)),
+                              send_buff + send_offsets[dst_rank] + j * (CUDECOMP_NVSHMEM_CHUNK_SZ / sizeof(T)),
+                              std::min(CUDECOMP_NVSHMEM_CHUNK_SZ, send_bytes - j * CUDECOMP_NVSHMEM_CHUNK_SZ),
+                              dst_rank_global, stream);
}
continue;
}
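
The reformatted call above is part of a loop that splits each peer's message into fixed-size chunks of `CUDECOMP_NVSHMEM_CHUNK_SZ` bytes, issuing one non-blocking put per chunk and trimming the last chunk to whatever remains. The chunking arithmetic in isolation, as a plain C++ sketch with made-up sizes (not the library's code):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  // Made-up values; in cuDecomp the chunk size is CUDECOMP_NVSHMEM_CHUNK_SZ.
  const std::size_t chunk_sz = 1 << 20;      // 1 MiB per put
  const std::size_t send_bytes = 2'500'000;  // total bytes for one peer

  // Ceiling division: number of puts needed to cover send_bytes.
  const std::size_t nchunks = (send_bytes + chunk_sz - 1) / chunk_sz;

  for (std::size_t j = 0; j < nchunks; ++j) {
    // Every chunk is full-sized except possibly the last one.
    const std::size_t bytes = std::min(chunk_sz, send_bytes - j * chunk_sz);
    std::cout << "chunk " << j << ": " << bytes << " bytes\n";
  }
  // Prints 1048576, 1048576, and 402848 bytes across three chunks.
  return 0;
}
```
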
5 changes: 3 additions & 2 deletions include/internal/common.h
@@ -61,13 +61,14 @@ struct cudecompHandle {
ncclComm_t nccl_comm = nullptr; // NCCL communicator (global)
ncclComm_t nccl_local_comm = nullptr; // NCCL communicator (intranode)
bool nccl_enable_ubr = false; // Flag to control NCCL user buffer registration usage
-std::unordered_map<void*, std::vector<std::pair<ncclComm_t, void*>>> nccl_ubr_handles; // map of allocated buffer address to NCCL registration handle(s)
+std::unordered_map<void*, std::vector<std::pair<ncclComm_t, void*>>>
+    nccl_ubr_handles; // map of allocated buffer address to NCCL registration handle(s)

cudaStream_t pl_stream = nullptr; // stream used for pipelined backends

cutensorHandle_t cutensor_handle; // cuTENSOR handle;
#if CUTENSOR_MAJOR >= 2
-cutensorPlanPreference_t cutensor_plan_pref; // cuTENSOR plan preference;
+cutensorPlanPreference_t cutensor_plan_pref;     // cuTENSOR plan preference;
#endif

std::vector<std::array<char, MPI_MAX_PROCESSOR_NAME>> hostnames; // list of hostnames by rank
26 changes: 12 additions & 14 deletions include/internal/transpose.h
@@ -77,18 +77,15 @@ template <typename T> static inline cutensorDataType_t getCutensorDataType() { r

static inline cutensorComputeDescriptor_t getCutensorComputeType(cutensorDataType_t cutensor_dtype) {
switch (cutensor_dtype) {
-case CUTENSOR_R_32F:
-case CUTENSOR_C_32F:
-  return CUTENSOR_COMPUTE_DESC_32F;
-case CUTENSOR_R_64F:
-case CUTENSOR_C_64F:
-default:
-  return CUTENSOR_COMPUTE_DESC_64F;
+case CUTENSOR_R_32F:
+case CUTENSOR_C_32F: return CUTENSOR_COMPUTE_DESC_32F;
+case CUTENSOR_R_64F:
+case CUTENSOR_C_64F:
+default: return CUTENSOR_COMPUTE_DESC_64F;
}
}

-template <typename T>
-static inline uint32_t getAlignment(const T* ptr) {
+template <typename T> static inline uint32_t getAlignment(const T* ptr) {
auto i_ptr = reinterpret_cast<std::uintptr_t>(ptr);
for (uint32_t d = 16; d > 0; d >>= 1) {
if (i_ptr % d == 0) return d;
@@ -116,14 +113,15 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t
CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, &desc_in, 3, extent_in.data(), strides_in_ptr,
cutensor_type, getAlignment(input)));
cutensorTensorDescriptor_t desc_out;
-CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, &desc_out, 3, extent_out.data(), strides_out_ptr,
-               cutensor_type, getAlignment(output)));
+CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, &desc_out, 3, extent_out.data(),
+               strides_out_ptr, cutensor_type, getAlignment(output)));

cutensorOperationDescriptor_t desc_op;
-CHECK_CUTENSOR(cutensorCreatePermutation(handle->cutensor_handle, &desc_op, desc_in, order_in.data(), CUTENSOR_OP_IDENTITY,
-               desc_out, order_out.data(), getCutensorComputeType(cutensor_type)));
+CHECK_CUTENSOR(cutensorCreatePermutation(handle->cutensor_handle, &desc_op, desc_in, order_in.data(),
+               CUTENSOR_OP_IDENTITY, desc_out, order_out.data(),
+               getCutensorComputeType(cutensor_type)));

-cutensorPlan_t plan;
+cutensorPlan_t plan;
CHECK_CUTENSOR(cutensorCreatePlan(handle->cutensor_handle, &plan, desc_op, handle->cutensor_plan_pref, 0));

T one(1);
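
For reference, getAlignment above reports the largest power-of-two divisor of a pointer's address, capped at 16 bytes, which is then passed to cuTENSOR as the tensor's alignment hint. A small standalone check of the same logic on a few synthetic addresses (our own test code, not part of the repository):

```cpp
#include <cstdint>
#include <iostream>

// Same idea as getAlignment in include/internal/transpose.h: walk down from 16
// and return the first power of two that divides the address.
std::uint32_t alignment_of(std::uintptr_t addr) {
  for (std::uint32_t d = 16; d > 0; d >>= 1) {
    if (addr % d == 0) return d;
  }
  return 1; // unreachable: d == 1 always divides
}

int main() {
  std::cout << alignment_of(0x1000) << "\n"; // 16
  std::cout << alignment_of(0x1008) << "\n"; // 8
  std::cout << alignment_of(0x1006) << "\n"; // 2
  std::cout << alignment_of(0x1005) << "\n"; // 1
  return 0;
}
```
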