From d9ff93b919188cc21734cae3e2b4cf0abdb354b7 Mon Sep 17 00:00:00 2001 From: shenan Date: Tue, 20 Oct 2020 16:18:55 +0000 Subject: [PATCH 01/20] FEAT: HMPI INITIAL COMMIT HyperMPI initial commit. --- LICENSE | 2 + README | 2 + config/ompi_check_ucx.m4 | 35 +- contrib/platform/mellanox/optimized | 4 +- ompi/mca/coll/ucx/Makefile.am | 49 +++ ompi/mca/coll/ucx/coll_ucx.h | 178 ++++++++ ompi/mca/coll/ucx/coll_ucx_component.c | 445 +++++++++++++++++++ ompi/mca/coll/ucx/coll_ucx_freelist.h | 31 ++ ompi/mca/coll/ucx/coll_ucx_module.c | 509 ++++++++++++++++++++++ ompi/mca/coll/ucx/coll_ucx_op.c | 450 +++++++++++++++++++ ompi/mca/coll/ucx/coll_ucx_request.c | 177 ++++++++ ompi/mca/coll/ucx/coll_ucx_request.h | 68 +++ ompi/mca/coll/ucx/configure.m4 | 37 ++ ompi/mca/pml/ucx/pml_ucx.c | 8 +- opal/mca/common/ucx/common_ucx.c | 13 +- opal/mca/common/ucx/common_ucx.h | 45 +- orte/mca/rmaps/base/rmaps_base_ranking.c | 30 +- oshmem/mca/atomic/ucx/atomic_ucx_cswap.c | 2 + oshmem/mca/atomic/ucx/atomic_ucx_module.c | 2 + oshmem/mca/spml/ucx/spml_ucx.c | 5 +- oshmem/mca/sshmem/ucx/sshmem_ucx_module.c | 3 +- 21 files changed, 2059 insertions(+), 36 deletions(-) create mode 100644 ompi/mca/coll/ucx/Makefile.am create mode 100644 ompi/mca/coll/ucx/coll_ucx.h create mode 100644 ompi/mca/coll/ucx/coll_ucx_component.c create mode 100644 ompi/mca/coll/ucx/coll_ucx_freelist.h create mode 100644 ompi/mca/coll/ucx/coll_ucx_module.c create mode 100644 ompi/mca/coll/ucx/coll_ucx_op.c create mode 100644 ompi/mca/coll/ucx/coll_ucx_request.c create mode 100644 ompi/mca/coll/ucx/coll_ucx_request.h create mode 100644 ompi/mca/coll/ucx/configure.m4 diff --git a/LICENSE b/LICENSE index 29b02918cee..634cae748e7 100644 --- a/LICENSE +++ b/LICENSE @@ -57,6 +57,8 @@ Copyright (c) 2017-2018 Amazon.com, Inc. or its affiliates. All Rights reserved. Copyright (c) 2019 Triad National Security, LLC. All rights reserved. +Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights + reserved. 
$COPYRIGHT$ diff --git a/README b/README index 150fbda2ae2..0d39eb05af8 100644 --- a/README +++ b/README @@ -23,6 +23,8 @@ Copyright (c) 2017 Research Organization for Information Science and Technology (RIST). All rights reserved. Copyright (c) 2019 Triad National Security, LLC. All rights reserved. +Copyright (c) 2020 Huawei Technologies Co.,Ltd. All rights + reserved. $COPYRIGHT$ diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index 7f04ba3a52c..52f80f0c083 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -13,6 +13,10 @@ # # $HEADER$ # +#2020.06.09-Changed process for coll_ucx +# Huawei Technologies Co., Ltd. 2020. +# + # OMPI_CHECK_UCX(prefix, [action-if-found], [action-if-not-found]) # -------------------------------------------------------- @@ -41,6 +45,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [ompi_check_ucx_dir=])], [true])]) ompi_check_ucx_happy="no" + ompi_check_ucg_happy="no" + ucx_libs="-luct -lucm -lucs" AS_IF([test -z "$ompi_check_ucx_dir"], [OPAL_CHECK_PACKAGE([ompi_check_ucx], [ucp/api/ucp.h], @@ -51,6 +57,15 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [], [ompi_check_ucx_happy="yes"], [ompi_check_ucx_happy="no"]) + OPAL_CHECK_PACKAGE([ompi_check_ucg], + [ucg/api/ucg.h], + [ucg], + [ucg_request_check_status], + [-lucg -lucp $ucx_libs], + [], + [], + [ompi_check_ucg_happy="yes"], + [ompi_check_ucg_happy="no"]) AS_IF([test "$ompi_check_ucx_happy" = yes], [AC_MSG_CHECKING(for UCX version compatibility) AC_REQUIRE_CPP @@ -83,6 +98,15 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [$ompi_check_ucx_libdir], [ompi_check_ucx_happy="yes"], [ompi_check_ucx_happy="no"]) + OPAL_CHECK_PACKAGE([ompi_check_ucg], + [ucg/api/ucg.h], + [ucg], + [ucg_request_check_status], + [-lucg -lucp $ucx_libs], + [$ompi_check_ucx_dir], + [$ompi_check_ucx_libdir], + [ompi_check_ucg_happy="yes"], + [ompi_check_ucg_happy="no"]) CPPFLAGS="$ompi_check_ucx_$1_save_CPPFLAGS" LDFLAGS="$ompi_check_ucx_$1_save_LDFLAGS" @@ -133,10 +157,13 @@ AC_DEFUN([OMPI_CHECK_UCX],[ 
OPAL_SUMMARY_ADD([[Transports]],[[Open UCX]],[$1],[$ompi_check_ucx_happy])])]) AS_IF([test "$ompi_check_ucx_happy" = "yes"], - [$1_CPPFLAGS="[$]$1_CPPFLAGS $ompi_check_ucx_CPPFLAGS" - $1_LDFLAGS="[$]$1_LDFLAGS $ompi_check_ucx_LDFLAGS" - $1_LIBS="[$]$1_LIBS $ompi_check_ucx_LIBS" - AC_DEFINE([HAVE_UCX], [1], [have ucx]) + [AS_IF([test "$ompi_check_ucg_happy" = "yes"], + [$1_CPPFLAGS="[$]$1_CPPFLAGS $ompi_check_ucg_CPPFLAGS" + $1_LDFLAGS="[$]$1_LDFLAGS $ompi_check_ucg_LDFLAGS" + $1_LIBS="[$]$1_LIBS $ompi_check_ucg_LIBS"], + [$1_CPPFLAGS="[$]$1_CPPFLAGS $ompi_check_ucx_CPPFLAGS" + $1_LDFLAGS="[$]$1_LDFLAGS $ompi_check_ucx_LDFLAGS" + $1_LIBS="[$]$1_LIBS $ompi_check_ucx_LIBS"]) $2], [AS_IF([test ! -z "$with_ucx" && test "$with_ucx" != "no"], [AC_MSG_ERROR([UCX support requested but not found. Aborting])]) diff --git a/contrib/platform/mellanox/optimized b/contrib/platform/mellanox/optimized index f49a0576c64..86beb4a57bb 100644 --- a/contrib/platform/mellanox/optimized +++ b/contrib/platform/mellanox/optimized @@ -52,9 +52,9 @@ else enable_picky=no enable_heterogeneous=no enable_ft_thread=no - with_mpi_param_check=no + with_mpi_param_check=yes CXXFLAGS="-O3 -g" CCASFLAGS="-O3 -g" FCFLAGS="-O3 -g" - CFLAGS="-O3 -g" + CFLAGS="-O3 -g $CFLAGS" fi diff --git a/ompi/mca/coll/ucx/Makefile.am b/ompi/mca/coll/ucx/Makefile.am new file mode 100644 index 00000000000..e82d63f82b8 --- /dev/null +++ b/ompi/mca/coll/ucx/Makefile.am @@ -0,0 +1,49 @@ +# -*- shell-script -*- +# +# +# Copyright (c) 2011 Mellanox Technologies. All rights reserved. +# Copyright (c) 2013 Cisco Systems, Inc. All rights reserved. +# Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# +# + +AM_CPPFLAGS = $(coll_ucx_CPPFLAGS) -DCOLL_UCX_HOME=\"$(coll_ucx_HOME)\" $(coll_ucx_extra_CPPFLAGS) + +#dist_ompidata_DATA = help-coll-ucx.txt +coll_ucx_sources = \ + coll_ucx.h \ + coll_ucx_request.h \ + coll_ucx_freelist.h \ + coll_ucx_op.c \ + coll_ucx_module.c \ + coll_ucx_request.c \ + coll_ucx_component.c + +# Make the output library in this directory, and name it either +# mca__.la (for DSO builds) or libmca__.la +# (for static builds). + +if MCA_BUILD_ompi_coll_ucx_DSO +component_noinst = +component_install = mca_coll_ucx.la +else +component_noinst = libmca_coll_ucx.la +component_install = +endif + +mcacomponentdir = $(ompilibdir) +mcacomponent_LTLIBRARIES = $(component_install) +mca_coll_ucx_la_SOURCES = $(coll_ucx_sources) +mca_coll_ucx_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la $(coll_ucx_LIBS) \ + $(OPAL_TOP_BUILDDIR)/opal/mca/common/ucx/lib@OPAL_LIB_PREFIX@mca_common_ucx.la +mca_coll_ucx_la_LDFLAGS = -module -avoid-version $(coll_ucx_LDFLAGS) + +noinst_LTLIBRARIES = $(component_noinst) +libmca_coll_ucx_la_SOURCES =$(coll_ucx_sources) +libmca_coll_ucx_la_LIBADD = $(coll_ucx_LIBS) +libmca_coll_ucx_la_LDFLAGS = -module -avoid-version $(coll_ucx_LDFLAGS) diff --git a/ompi/mca/coll/ucx/coll_ucx.h b/ompi/mca/coll/ucx/coll_ucx.h new file mode 100644 index 00000000000..5bef95720cd --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx.h @@ -0,0 +1,178 @@ +/** + Copyright (c) 2011 Mellanox Technologies. All rights reserved. + Copyright (c) 2015 Research Organization for Information Science + and Technology (RIST). All rights reserved. + Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ $COPYRIGHT$ + + Additional copyrights may follow + + $HEADER$ + */ + +#ifndef MCA_COLL_UCX_H +#define MCA_COLL_UCX_H + +#include "ompi_config.h" + +#include "mpi.h" +#include "ompi/mca/mca.h" +#include "opal/memoryhooks/memory.h" +#include "opal/mca/memory/base/base.h" +#include "ompi/mca/coll/coll.h" +#include "ompi/request/request.h" +#include "ompi/mca/pml/pml.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/communicator/communicator.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/attribute/attribute.h" +#include "ompi/op/op.h" + +#include "orte/runtime/orte_globals.h" +#include "opal/mca/common/ucx/common_ucx.h" + +#include "ucg/api/ucg_mpi.h" +#include "ucs/datastruct/list.h" +#include "coll_ucx_freelist.h" + +#ifndef UCX_VERSION +#define UCX_VERSION(major, minor) (((major)< + +#include +#include + +#include "opal/mca/common/ucx/common_ucx.h" +#include "opal/mca/installdirs/installdirs.h" + +#include "coll_ucx.h" +#include "coll_ucx_request.h" + + +/* + * Public string showing the coll ompi_hcol component version number + */ +const char *mca_coll_ucx_component_version_string = + "Open MPI UCX collective MCA component version " OMPI_VERSION; + + +static int ucx_open(void); +static int ucx_close(void); +static int ucx_register(void); +int mca_coll_ucx_init_query(bool enable_progress_threads, + bool enable_mpi_threads); +mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, int *priority); + +int mca_coll_ucx_output = -1; +mca_coll_ucx_component_t mca_coll_ucx_component = { + + /* First, the mca_component_t struct containing meta information + about the component itfca */ + { + .collm_version = { + MCA_COLL_BASE_VERSION_2_0_0, + + /* Component name and version */ + .mca_component_name = "ucx", + MCA_BASE_MAKE_VERSION(component, OMPI_MAJOR_VERSION, OMPI_MINOR_VERSION, + OMPI_RELEASE_VERSION), + + /* Component open and close functions */ + .mca_open_component = ucx_open, + .mca_close_component = ucx_close, + 
.mca_register_component_params = ucx_register, + }, + .collm_data = { + /* The component is not checkpoint ready */ + MCA_BASE_METADATA_PARAM_NONE + }, + + /* Initialization querying functions */ + .collm_init_query = mca_coll_ucx_init_query, + .collm_comm_query = mca_coll_ucx_comm_query, + }, + .priority = 91, /* priority */ + .verbose = 0, /* verbose level */ + .num_disconnect = 0, /* ucx_enable */ + .enable_topo_map = 1, /* enable topology map */ + .topo_map = NULL +}; + +int mca_coll_ucx_init_query(bool enable_progress_threads, + bool enable_mpi_threads) +{ + return OMPI_SUCCESS; +} + +mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, int *priority) +{ + /* basic checks */ + if ((OMPI_COMM_IS_INTER(comm)) || (ompi_comm_size(comm) < 2)) { + return NULL; + } + + /* create a new module for this communicator */ + COLL_UCX_VERBOSE(10, "Creating ucx_context for comm %p, comm_id %d, comm_size %d", + (void*)comm, comm->c_contextid, ompi_comm_size(comm)); + mca_coll_ucx_module_t *ucx_module = OBJ_NEW(mca_coll_ucx_module_t); + if (!ucx_module) { + return NULL; + } + + *priority = mca_coll_ucx_component.priority; + return &(ucx_module->super); +} + +static int ucx_register(void) +{ + int status; + mca_coll_ucx_component.verbose = 0; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "verbosity", + "Verbosity of the UCX component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.verbose); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + mca_coll_ucx_component.priority = 91; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "priority", + "Priority of the UCX component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.priority); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + mca_coll_ucx_component.num_disconnect = 1; + 
status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "num_disconnect", + "How may disconnects go in parallel", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.num_disconnect); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + mca_coll_ucx_component.enable_topo_map = 1; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "enable_topo_map", + "Enable global topology map for ucg", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.enable_topo_map); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + opal_common_ucx_mca_var_register(&mca_coll_ucx_component.super.collm_version); + return OMPI_SUCCESS; +} + +static int ucx_open(void) +{ + mca_coll_ucx_component.output = opal_output_open(NULL); + opal_output_set_verbosity(mca_coll_ucx_component.output, mca_coll_ucx_component.verbose); + + opal_common_ucx_mca_register(); + + return mca_coll_ucx_open(); +} + +static int ucx_close(void) +{ + if (mca_coll_ucx_component.ucg_worker == NULL) { + return OMPI_ERROR; + } + mca_coll_ucx_cleanup(); + + opal_common_ucx_mca_deregister(); + + return mca_coll_ucx_close(); +} + +static int mca_coll_ucx_send_worker_address(void) +{ + ucg_address_t *address = NULL; + ucs_status_t status; + size_t addrlen; + int rc; + + status = ucg_worker_get_address(mca_coll_ucx_component.ucg_worker, &address, &addrlen); + if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to get worker address"); + return OMPI_ERROR; + } + + OPAL_MODEX_SEND(rc, OPAL_PMIX_GLOBAL, &mca_coll_ucx_component.super.collm_version, + (void*)address, addrlen); + if (OPAL_SUCCESS != rc) { + COLL_UCX_ERROR("Open MPI couldn't distribute EP connection details"); + return OMPI_ERROR; + } + + ucg_worker_release_address(mca_coll_ucx_component.ucg_worker, address); + + return OMPI_SUCCESS; +} + +static int 
mca_coll_ucx_recv_worker_address(ompi_proc_t *proc, + ucg_address_t **address_p, + size_t *addrlen_p) +{ + int ret; + + *address_p = NULL; + OPAL_MODEX_RECV(ret, &mca_coll_ucx_component.super.collm_version, + &proc->super.proc_name, (void**)address_p, addrlen_p); + if (ret != OPAL_SUCCESS) { + COLL_UCX_ERROR("Failed to receive UCX worker address: %s (%d)", opal_strerror(ret), ret); + } + return ret; +} + +int mca_coll_ucx_open(void) +{ + ucg_context_attr_t attr; + ucg_params_t params; + ucg_config_t *config = NULL; + ucs_status_t status; + + COLL_UCX_VERBOSE(1, "mca_coll_ucx_open"); + + /* Read options */ + status = ucg_config_read("MPI", NULL, &config); + if (UCS_OK != status) { + return OMPI_ERROR; + } + + /* Initialize UCX context */ + params.field_mask = UCP_PARAM_FIELD_FEATURES | + UCP_PARAM_FIELD_REQUEST_SIZE | + UCP_PARAM_FIELD_REQUEST_INIT | + UCP_PARAM_FIELD_REQUEST_CLEANUP | + // UCP_PARAM_FIELD_TAG_SENDER_MASK | + UCP_PARAM_FIELD_MT_WORKERS_SHARED | + UCP_PARAM_FIELD_ESTIMATED_NUM_EPS; + params.features = UCP_FEATURE_TAG | + UCP_FEATURE_RMA | + UCP_FEATURE_AMO32 | + UCP_FEATURE_AMO64 | + UCP_FEATURE_GROUPS; + params.request_size = sizeof(ompi_request_t); + params.request_init = mca_coll_ucx_request_init; + params.request_cleanup = mca_coll_ucx_request_cleanup; + params.mt_workers_shared = 0; /* we do not need mt support for context + since it will be protected by worker */ + params.estimated_num_eps = ompi_proc_world_size(); + + status = ucg_init(¶ms, config, &mca_coll_ucx_component.ucg_context); + ucg_config_release(config); + config = NULL; + if (UCS_OK != status) { + return OMPI_ERROR; + } + + /* Query UCX attributes */ + attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE; + status = ucg_context_query(mca_coll_ucx_component.ucg_context, &attr); + if (UCS_OK != status) { + goto out; + } + + mca_coll_ucx_component.request_size = attr.request_size; + + /* Initialize UCX worker */ + if (OMPI_SUCCESS != mca_coll_ucx_init()) { + goto out; + } + + 
ucs_list_head_init(&mca_coll_ucx_component.group_head); + return OMPI_SUCCESS; + +out: + ucg_cleanup(mca_coll_ucx_component.ucg_context); + mca_coll_ucx_component.ucg_context = NULL; + return OMPI_ERROR; +} + +int mca_coll_ucx_close(void) +{ + COLL_UCX_VERBOSE(1, "mca_coll_ucx_close"); + + if (mca_coll_ucx_component.ucg_worker != NULL) { + mca_coll_ucx_cleanup(); + mca_coll_ucx_component.ucg_worker = NULL; + } + + if (mca_coll_ucx_component.ucg_context != NULL) { + ucg_cleanup(mca_coll_ucx_component.ucg_context); + mca_coll_ucx_component.ucg_context = NULL; + } + return OMPI_SUCCESS; +} + +int mca_coll_ucx_progress(void) +{ + mca_coll_ucx_module_t *module = NULL; + ucs_list_for_each(module, &mca_coll_ucx_component.group_head, ucs_list) { + ucg_group_progress(module->ucg_group); + } + return OMPI_SUCCESS; +} + +int mca_coll_ucx_init_worker(void) +{ + int rc; + ucg_worker_attr_t attr; + + attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; + rc = ucg_worker_query(mca_coll_ucx_component.ucg_worker, &attr); + if (UCS_OK != rc) { + COLL_UCX_ERROR("Failed to query UCP worker thread level"); + rc = OMPI_ERROR; + return rc; + } + + /* UCX does not support multithreading, disqualify current PML for now */ + if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) { + COLL_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE"); + rc = OMPI_ERR_NOT_SUPPORTED; + return rc; + } + + /* Share my UCP address, so it could be later obtained via @ref mca_coll_ucx_resolve_address */ + rc = mca_coll_ucx_send_worker_address(); + return rc; +} + +int mca_coll_ucx_init(void) +{ + ucg_worker_params_t params; + ucs_status_t status; + int rc; + + COLL_UCX_VERBOSE(1, "mca_coll_ucx_init"); + params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + params.thread_mode = UCS_THREAD_MODE_SINGLE; + if (ompi_mpi_thread_multiple) { + params.thread_mode = UCS_THREAD_MODE_MULTI; + } else { + params.thread_mode = UCS_THREAD_MODE_SINGLE; + } + + status = 
ucg_worker_create(mca_coll_ucx_component.ucg_context, ¶ms, &mca_coll_ucx_component.ucg_worker); + if (UCS_OK != status) { + COLL_UCX_WARN("Failed to create UCG worker"); + rc = OMPI_ERROR; + goto err; + } + + status = mca_coll_ucx_init_worker(); + if (UCS_OK != status) { + COLL_UCX_WARN("Failed to init UCG worker"); + rc = OMPI_ERROR; + goto err_destroy_worker; + } + + /* Initialize the free lists */ + OBJ_CONSTRUCT(&mca_coll_ucx_component.persistent_ops, mca_coll_ucx_freelist_t); + + /* Create a completed request to be returned from isend */ + OBJ_CONSTRUCT(&mca_coll_ucx_component.completed_send_req, ompi_request_t); + mca_coll_ucx_completed_request_init(&mca_coll_ucx_component.completed_send_req); + + rc = opal_progress_register(mca_coll_ucx_progress); + if (OPAL_SUCCESS != rc) { + COLL_UCX_ERROR("Failed to progress register"); + goto err_destroy_worker; + } + + COLL_UCX_VERBOSE(2, "created ucp context %p, worker %p", (void *)mca_coll_ucx_component.ucg_context, + (void *)mca_coll_ucx_component.ucg_worker); + return rc; + +err_destroy_worker: + ucg_worker_destroy(mca_coll_ucx_component.ucg_worker); + mca_coll_ucx_component.ucg_worker = NULL; +err: + return rc; +} + +void mca_coll_ucx_cleanup(void) +{ + COLL_UCX_VERBOSE(1, "mca_coll_ucx_cleanup"); + + opal_progress_unregister(mca_coll_ucx_progress); + + mca_coll_ucx_component.completed_send_req.req_state = OMPI_REQUEST_INVALID; + OMPI_REQUEST_FINI(&mca_coll_ucx_component.completed_send_req); + OBJ_DESTRUCT(&mca_coll_ucx_component.completed_send_req); + OBJ_DESTRUCT(&mca_coll_ucx_component.persistent_ops); + + if (mca_coll_ucx_component.ucg_worker) { + ucg_worker_destroy(mca_coll_ucx_component.ucg_worker); + mca_coll_ucx_component.ucg_worker = NULL; + } + if (mca_coll_ucx_component.topo_map) { + for (unsigned i = 0; i < mca_coll_ucx_component.world_member_count; i++) { + free(mca_coll_ucx_component.topo_map[i]); + mca_coll_ucx_component.topo_map[i] = NULL; + } + free(mca_coll_ucx_component.topo_map); + 
mca_coll_ucx_component.topo_map = NULL; + } +} + +ucs_status_t mca_coll_ucx_resolve_address(void *cb_group_obj, ucg_group_member_index_t rank, ucg_address_t **addr, + size_t *addr_len) +{ + /* Sanity checks */ + ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; + if (rank == (ucg_group_member_index_t)comm->c_my_rank) { + return UCS_ERR_UNSUPPORTED; + } + + /* Check the cache for a previously established connection to that rank */ + ompi_proc_t *proc_peer = + (struct ompi_proc_t*)ompi_comm_peer_lookup((ompi_communicator_t*)cb_group_obj, rank); + *addr = proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL]; + *addr_len = 1; + if (*addr) { + return UCS_OK; + } + + /* Obtain the UCP address of the remote */ + int ret = mca_coll_ucx_recv_worker_address(proc_peer, addr, addr_len); + if (ret < 0) { + COLL_UCX_ERROR("mca_coll_ucx_recv_worker_address(proc=%d rank=%lu) failed", + proc_peer->super.proc_name.vpid, rank); + return UCS_ERR_INVALID_ADDR; + } + + /* Cache the connection for future invocations with this rank */ + proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = *addr; + return UCS_OK; +} + +void mca_coll_ucx_release_address(ucg_address_t *addr) +{ + /* no need to free - the address is stored in proc_peer->proc_endpoints */ +} + +ucg_worker_h mca_coll_ucx_get_component_worker() +{ + return mca_coll_ucx_component.ucg_worker; +} diff --git a/ompi/mca/coll/ucx/coll_ucx_freelist.h b/ompi/mca/coll/ucx/coll_ucx_freelist.h new file mode 100644 index 00000000000..9aca9e89c05 --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_freelist.h @@ -0,0 +1,31 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef COLL_UCX_FREELIST_H_ +#define COLL_UCX_FREELIST_H_ + +#include "ompi_config.h" +#include "opal/class/opal_free_list.h" + + +#define mca_coll_ucx_freelist_t opal_free_list_t + +#define COLL_UCX_FREELIST_GET(_freelist) \ + opal_free_list_get (_freelist) + +#define COLL_UCX_FREELIST_RETURN(_freelist, _item) \ + opal_free_list_return(_freelist, _item) + +#define COLL_UCX_FREELIST_INIT(_fl, _type, _initial, _max, _batch) \ + opal_free_list_init(_fl, sizeof(_type), 8, OBJ_CLASS(_type), \ + 0, 0, _initial, _max, _batch, NULL, 0, NULL, NULL, NULL) + + +#endif /* COLL_UCX_FREELIST_H_ */ diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c new file mode 100644 index 00000000000..3cc84c474e6 --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2011 Mellanox Technologies. All rights reserved. + * Copyright (c) 2014 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "ompi_config.h" +#include "ompi/constants.h" +#include "ompi/datatype/ompi_datatype.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/op/op.h" + +#include "coll_ucx.h" +#include "coll_ucx_request.h" + +#include +#include +#include +#include +#include +#include +#include + +static int mca_coll_ucg_obtain_addr_from_hostname(const char *hostname, struct in_addr *ip_addr) +{ + struct addrinfo hints; + struct addrinfo *res = NULL, *cur = NULL; + struct sockaddr_in *addr = NULL; + int ret; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_INET; + hints.ai_flags = AI_PASSIVE; + hints.ai_protocol = 0; + hints.ai_socktype = SOCK_DGRAM; + ret = getaddrinfo(hostname, NULL, &hints, &res); + if (ret < 0) { + COLL_UCX_ERROR("%s", gai_strerror(ret)); + return OMPI_ERROR; + } + + for (cur = res; cur != NULL; cur = cur->ai_next) { + addr = (struct sockaddr_in *)cur->ai_addr; + } + + *ip_addr = addr->sin_addr; + freeaddrinfo(res); + return OMPI_SUCCESS; +} + +static int mca_coll_ucg_obtain_node_index(unsigned member_count, struct ompi_communicator_t *comm, uint16_t *node_index) +{ + ucg_group_member_index_t rank_idx, rank2_idx; + uint16_t same_node_flag; + uint16_t node_idx = 0; + uint16_t init_node_idx = (uint16_t) - 1; + int status, status2; + struct in_addr ip_address, ip_address2; + + /* initialize: -1: unnumbering flag */ + for (rank_idx = 0; rank_idx < member_count; rank_idx++) { + node_index[rank_idx] = init_node_idx; + } + + for (rank_idx = 0; rank_idx < member_count; rank_idx++) { + if (node_index[rank_idx] == init_node_idx) { + struct ompi_proc_t *rank_iter = + (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank_idx); + /* super.proc_hostname give IP address or real hostname */ + /* transform hostname to IP address for uniform format */ + status = mca_coll_ucg_obtain_addr_from_hostname(rank_iter->super.proc_hostname, &ip_address); + for 
(rank2_idx = rank_idx; rank2_idx < member_count; rank2_idx++) { + struct ompi_proc_t *rank2_iter = + (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank2_idx); + + status2 = mca_coll_ucg_obtain_addr_from_hostname(rank2_iter->super.proc_hostname, &ip_address2); + if (status != OMPI_SUCCESS || status2 != OMPI_SUCCESS) { + return OMPI_ERROR; + } + + /* if rank_idx and rank2_idx in same node, same_flag = 1 */ + same_node_flag = (memcmp(&ip_address, &ip_address2, sizeof(ip_address))) ? 0 : 1; + if (same_node_flag == 1 && node_index[rank2_idx] == init_node_idx) { + node_index[rank2_idx] = node_idx; + } + } + node_idx++; + } + } + + /* make sure every rank has its node_index */ + for (rank_idx = 0; rank_idx < member_count; rank_idx++) { + /* some rank do NOT have node_index */ + if (node_index[rank_idx] == init_node_idx) { + return OMPI_ERROR; + } + } + return OMPI_SUCCESS; +} + +static int mca_coll_ucx_create_topo_map(const uint16_t *node_index, const char *topo_info, unsigned loc_size, unsigned rank_cnt) +{ + mca_coll_ucx_component.topo_map = (char**)malloc(sizeof(char*) * rank_cnt); + if (mca_coll_ucx_component.topo_map == NULL) { + return OMPI_ERROR; + } + + unsigned i, j; + for (i = 0; i < rank_cnt; i++) { + mca_coll_ucx_component.topo_map[i] = (char*)malloc(sizeof(char) * rank_cnt); + if (mca_coll_ucx_component.topo_map[i] == NULL) { + for (j = 0; j < i; j++) { + free(mca_coll_ucx_component.topo_map[j]); + mca_coll_ucx_component.topo_map[j] = NULL; + } + free(mca_coll_ucx_component.topo_map); + mca_coll_ucx_component.topo_map = NULL; + return OMPI_ERROR; + } + for (j = 0; j <= i; j++) { + if (i == j) { + mca_coll_ucx_component.topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_SELF; + continue; + } + + if (node_index[i] != node_index[j]) { + mca_coll_ucx_component.topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; + mca_coll_ucx_component.topo_map[j][i] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; + continue; + } + + opal_hwloc_locality_t rel_loc = 
opal_hwloc_compute_relative_locality(topo_info + i * loc_size, topo_info + j * loc_size); + enum ucg_group_member_distance distance; + if (OPAL_PROC_ON_LOCAL_L3CACHE(rel_loc)) { + distance = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; + } else if (OPAL_PROC_ON_LOCAL_SOCKET(rel_loc)) { + distance = UCG_GROUP_MEMBER_DISTANCE_SOCKET; + } else if (OPAL_PROC_ON_LOCAL_HOST(rel_loc)) { + distance = UCG_GROUP_MEMBER_DISTANCE_HOST; + } else { + distance = UCG_GROUP_MEMBER_DISTANCE_NET; + } + mca_coll_ucx_component.topo_map[i][j] = (char)distance; + mca_coll_ucx_component.topo_map[j][i] = (char)distance; + } + } + return OMPI_SUCCESS; +} + +static int mca_coll_ucx_print_topo_map(unsigned rank_cnt, char **topo_map) +{ + int status = OMPI_SUCCESS; + + /* Print topo map for rank 0. */ + if (ompi_comm_rank(MPI_COMM_WORLD) == 0) { + unsigned i; + for (i = 0; i < rank_cnt; i++) { + char *topo_print = (char*)malloc(rank_cnt + 1); + if (topo_print == NULL) { + status = OMPI_ERROR; + return status; + } + for (unsigned j = 0; j < rank_cnt; j++) { + topo_print[j] = '0' + (int)topo_map[i][j]; + } + topo_print[rank_cnt] = '\0'; + COLL_UCX_VERBOSE(8, "%s\n", topo_print); + free(topo_print); + topo_print = NULL; + } + } + return status; +} + +static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) +{ + if (mca_coll_ucx_component.topo_map != NULL) { + return OMPI_SUCCESS; + } + + /* Derive the 'loc' string from pmix and gather all 'loc' string from all the ranks in the world. 
*/ + int status = OMPI_SUCCESS; + uint16_t *node_index = NULL; + unsigned LOC_SIZE = 64; + unsigned rank_cnt = mca_coll_ucx_component.world_member_count = ompi_comm_size(MPI_COMM_WORLD); + char *topo_info = (char*)malloc(sizeof(char) * LOC_SIZE * rank_cnt); + if (topo_info == NULL) { + status = OMPI_ERROR; + goto end; + } + memset(topo_info, 0, sizeof(char) * LOC_SIZE * rank_cnt); + int ret; + char *val = NULL; + OPAL_MODEX_RECV_VALUE_OPTIONAL(ret, OPAL_PMIX_LOCALITY_STRING, + &opal_proc_local_get()->proc_name, &val, OPAL_STRING); + if (val == NULL || ret != OMPI_SUCCESS) { + status = OMPI_ERROR; + goto end; + } + + ret = ompi_coll_base_allgather_intra_bruck(val, LOC_SIZE, MPI_CHAR, topo_info, LOC_SIZE, MPI_CHAR, MPI_COMM_WORLD, &module->super); + if (ret != OMPI_SUCCESS) { + status = OMPI_ERROR; + goto end; + } + + /* Obtain node index to indicate each 'loc' belongs to which node, + as 'loc' only has info of local machine and contains no network info. */ + node_index = (uint16_t*)malloc(rank_cnt * sizeof(uint16_t)); + if (node_index == NULL) { + status = OMPI_ERROR; + goto end; + } + + ret = mca_coll_ucg_obtain_node_index(rank_cnt, MPI_COMM_WORLD, node_index); + if (ret != OMPI_SUCCESS) { + status = OMPI_ERROR; + goto end; + } + + /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. 
*/ + ret = mca_coll_ucx_create_topo_map(node_index, topo_info, LOC_SIZE, rank_cnt); + if (ret != OMPI_SUCCESS) { + status = OMPI_ERROR; + goto end; + } + + ret = mca_coll_ucx_print_topo_map(rank_cnt, mca_coll_ucx_component.topo_map); + if (ret != OMPI_SUCCESS) { + status = OMPI_ERROR; + } + +end: + if (val) { + free(val); + val = NULL; + } + + if (node_index) { + free(node_index); + node_index = NULL; + } + if (topo_info) { + free(topo_info); + topo_info = NULL; + } + return status; +} + +static int mca_coll_ucx_find_rank_in_comm_world(struct ompi_communicator_t *comm, int comm_rank) +{ + struct ompi_proc_t *proc = (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, comm_rank); + if (proc == NULL) { + return -1; + } + + unsigned i; + for (i = 0; i < ompi_comm_size(MPI_COMM_WORLD); i++) { + struct ompi_proc_t *rank_iter = (struct ompi_proc_t*)ompi_comm_peer_lookup(MPI_COMM_WORLD, i); + if (rank_iter == proc) { + return i; + } + } + + return -1; +} + +static int mca_coll_ucx_create_comm_topo(ucg_group_params_t *args, struct ompi_communicator_t *comm) +{ + int status; + if (comm == MPI_COMM_WORLD) { + if (args->topo_map != NULL) { + free(args->topo_map); + } + args->topo_map = mca_coll_ucx_component.topo_map; + return OMPI_SUCCESS; + } + + /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ + unsigned i; + for (i = 0; i < args->member_count; i++) { + /* Find the rank in the MPI_COMM_WORLD for rank i in the comm. 
*/ + int world_rank_i = mca_coll_ucx_find_rank_in_comm_world(comm, i); + if (world_rank_i == -1) { + return OMPI_ERROR; + } + for (unsigned j = 0; j <= i; j++) { + int world_rank_j = mca_coll_ucx_find_rank_in_comm_world(comm, j); + if (world_rank_j == -1) { + return OMPI_ERROR; + } + args->topo_map[i][j] = mca_coll_ucx_component.topo_map[world_rank_i][world_rank_j]; + args->topo_map[j][i] = mca_coll_ucx_component.topo_map[world_rank_j][world_rank_i]; + } + } + + status = mca_coll_ucx_print_topo_map(args->member_count, args->topo_map); + return status; +} + +static void mca_coll_ucg_create_distance_array(struct ompi_communicator_t *comm, ucg_group_member_index_t my_idx, ucg_group_params_t *args) +{ + ucg_group_member_index_t rank_idx; + for (rank_idx = 0; rank_idx < args->member_count; rank_idx++) { + struct ompi_proc_t *rank_iter = (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank_idx); + rank_iter->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = NULL; + if (rank_idx == my_idx) { + args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SELF; + } else if (OPAL_PROC_ON_LOCAL_L3CACHE(rank_iter->super.proc_flags)) { + args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; + } else if (OPAL_PROC_ON_LOCAL_SOCKET(rank_iter->super.proc_flags)) { + args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SOCKET; + } else if (OPAL_PROC_ON_LOCAL_HOST(rank_iter->super.proc_flags)) { + args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_HOST; + } else { + args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_NET; + } + } +} + +static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) +{ + args->member_count = ompi_comm_size(comm); + args->cid = ompi_comm_get_cid(comm); + args->mpi_reduce_f = ompi_op_reduce; + args->resolve_address_f = mca_coll_ucx_resolve_address; + args->release_address_f = mca_coll_ucx_release_address; + args->cb_group_obj = comm; + args->op_is_commute_f = ompi_op_is_commute; +} + +static void 
mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) +{ + unsigned i; + + if (args->distance != NULL) { + free(args->distance); + args->distance = NULL; + } + + if (args->node_index != NULL) { + free(args->node_index); + args->node_index = NULL; + } + + if (comm != MPI_COMM_WORLD && args->topo_map != NULL) { + for (i = 0; i < args->member_count; i++) { + if (args->topo_map[i] != NULL) { + free(args->topo_map[i]); + args->topo_map[i] = NULL; + } + } + free(args->topo_map); + args->topo_map = NULL; + } +} + +static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) +{ + ucs_status_t error; + ucg_group_params_t args; + ucg_group_member_index_t my_idx; + int status = OMPI_SUCCESS; + unsigned i; + +#if OMPI_GROUP_SPARSE + COLL_UCX_ERROR("Sparse process groups are not supported"); + return UCS_ERR_UNSUPPORTED; +#endif + + /* Fill in group initialization parameters */ + my_idx = ompi_comm_rank(comm); + mca_coll_ucg_init_group_param(comm, &args); + args.distance = malloc(args.member_count * sizeof(*args.distance)); + args.node_index = malloc(args.member_count * sizeof(*args.node_index)); + args.is_bind_to_none = (OPAL_BIND_TO_NONE == OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy)); + args.topo_map = NULL; + + if (args.distance == NULL || args.node_index == NULL) { + MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); + status = OMPI_ERROR; + goto out; + } + + if (mca_coll_ucx_component.enable_topo_map && (OPAL_BIND_TO_CORE == OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy))) { + /* Initialize global topology map. 
*/ + args.topo_map = (char**)malloc(sizeof(char*) * args.member_count); + if (args.topo_map == NULL) { + MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); + status = OMPI_ERROR; + goto out; + } + + for (i = 0; i < args.member_count; i++) { + args.topo_map[i] = (char*)malloc(sizeof(char) * args.member_count); + if (args.topo_map[i] == NULL) { + MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); + status = OMPI_ERROR; + goto out; + } + } + + status = mca_coll_ucx_init_global_topo(module); + if (status != OMPI_SUCCESS) { + MCA_COMMON_UCX_WARN("Failed to create global topology."); + status = OMPI_ERROR; + goto out; + } + + if (status == OMPI_SUCCESS) { + status = mca_coll_ucx_create_comm_topo(&args, comm); + if (status != OMPI_SUCCESS) { + MCA_COMMON_UCX_WARN("Failed to create communicator topology."); + status = OMPI_ERROR; + goto out; + } + } + } + + /* Generate (temporary) rank-distance array */ + mca_coll_ucg_create_distance_array(comm, my_idx, &args); + + /* Generate node_index for each process */ + status = mca_coll_ucg_obtain_node_index(args.member_count, comm, args.node_index); + + if (status != OMPI_SUCCESS) { + status = OMPI_ERROR; + goto out; + } + + error = ucg_group_create(mca_coll_ucx_component.ucg_worker, &args, &module->ucg_group); + + /* Examine comm_new return value */ + if (error != UCS_OK) { + MCA_COMMON_UCX_WARN("ucg_new failed: %s", ucs_status_string(error)); + status = OMPI_ERROR; + goto out; + } + + ucs_list_add_tail(&mca_coll_ucx_component.group_head, &module->ucs_list); + status = OMPI_SUCCESS; + +out: + mca_coll_ucg_arg_free(comm, &args); + return status; +} + +/* + * Initialize module on the communicator + */ +static int mca_coll_ucx_module_enable(mca_coll_base_module_t *module, + struct ompi_communicator_t *comm) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*) module; + int rc; + + /* prepare the placeholder for the array of request* */ + 
module->base_data = OBJ_NEW(mca_coll_base_comm_t); + if (NULL == module->base_data) { + return OMPI_ERROR; + } + + rc = mca_coll_ucg_create(ucx_module, comm); + if (rc != OMPI_SUCCESS) { + return rc; + } + + COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.persistent_ops, mca_coll_ucx_persistent_op_t, + 128, -1, 128); + + COLL_UCX_VERBOSE(1, "UCX Collectives Module initialized"); + return OMPI_SUCCESS; +} + +static int mca_coll_ucx_ft_event(int state) +{ + return OMPI_SUCCESS; +} + +static void mca_coll_ucx_module_construct(mca_coll_ucx_module_t *module) +{ + size_t nonzero = sizeof(module->super.super); + memset((void*)module + nonzero, 0, sizeof(*module) - nonzero); + + module->super.coll_module_enable = mca_coll_ucx_module_enable; + module->super.ft_event = mca_coll_ucx_ft_event; + module->super.coll_allreduce = mca_coll_ucx_allreduce; + module->super.coll_barrier = mca_coll_ucx_barrier; + module->super.coll_bcast = mca_coll_ucx_bcast; + ucs_list_head_init(&module->ucs_list); +} + +static void mca_coll_ucx_module_destruct(mca_coll_ucx_module_t *module) +{ + if (module->ucg_group) { + ucg_group_destroy(module->ucg_group); + } + ucs_list_del(&module->ucs_list); +} + +OBJ_CLASS_INSTANCE(mca_coll_ucx_module_t, + mca_coll_base_module_t, + mca_coll_ucx_module_construct, + mca_coll_ucx_module_destruct); diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c new file mode 100644 index 00000000000..0dfebe42464 --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -0,0 +1,450 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED. + * Copyright (c) 2016 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "coll_ucx.h" +#include "coll_ucx_request.h" + +#include "ompi/message/message.h" +#include + +static inline int mca_coll_ucx_is_datatype_supported(struct ompi_datatype_t *dtype, int count) +{ + return ompi_datatype_is_contiguous_memory_layout(dtype, count); +} + +int mca_coll_ucx_start(size_t count, ompi_request_t** requests) +{ + mca_coll_ucx_persistent_op_t *preq = NULL; + ompi_request_t *tmp_req = NULL; + size_t i; + + for (i = 0; i < count; ++i) { + preq = (mca_coll_ucx_persistent_op_t *)requests[i]; + if ((preq == NULL) || (OMPI_REQUEST_COLL != preq->ompi.req_type)) { + /* Skip irrelevant requests */ + continue; + } + + COLL_UCX_ASSERT(preq->ompi.req_state != OMPI_REQUEST_INVALID); + preq->ompi.req_state = OMPI_REQUEST_ACTIVE; + mca_coll_ucx_request_reset(&preq->ompi); + + tmp_req = ucg_collective_start_nb(preq->coll_desc); + if (tmp_req == NULL) { + COLL_UCX_VERBOSE(8, "collective completed immediately, completing persistent request %p", (void*)preq); + mca_coll_ucx_set_coll_status(&preq->ompi.req_status, UCS_OK); + ompi_request_complete(&preq->ompi, true); + } else if (!UCS_PTR_IS_ERR(tmp_req)) { + if (REQUEST_COMPLETE(tmp_req)) { + /* tmp_req is already completed */ + COLL_UCX_VERBOSE(8, "completing persistent request %p", (void*)preq); + mca_coll_ucx_persistent_op_complete(preq, tmp_req); + } else { + /* tmp_req would be completed by callback and trigger completion + * of preq */ + COLL_UCX_VERBOSE(8, "temporary request %p will complete persistent request %p", + (void*)tmp_req, (void*)preq); + tmp_req->req_complete_cb_data = preq; + preq->tmp_req = tmp_req; + } + } else { + COLL_UCX_ERROR("ucx collective failed: %s", ucs_status_string(UCS_PTR_STATUS(tmp_req))); + return OMPI_ERROR; + } + } + + return OMPI_SUCCESS; +} + +/** + * For each type of collectives there are 3 varieties of function calls: + * blocking, non-blocking and persistent initialization. 
For example, for + * the allreduce collective operations, those would be called: + * - mca_coll_ucx_allreduce + * - mca_coll_ucx_iallreduce + * - mca_coll_ucx_iallreduce_init + * + * In the blocking version, request is placed on the stack, awaiting completion. + * For non-blocking, request is allocated by UCX, awaiting completion. + * For persistent requests, the collective starts later - only then the + * (internal) request is created (by UCX) and placed as "tmp_req" inside + * the persistent (external) request structure. + */ +#define COLL_UCX_TRACE(_msg, _sbuf, _rbuf, _count, _datatype, _comm, ...) \ + COLL_UCX_VERBOSE(8, _msg " sbuf %p rbuf %p count %i type '%s' comm %d '%s'", \ + __VA_ARGS__, (_sbuf), (_rbuf), (_count), (_datatype)->name, \ + (_comm)->c_contextid, (_comm)->c_name); + +#define COLL_UCX_REQ_ALLOCA(ucx_module) \ + ((char *)alloca(mca_coll_ucx_component.request_size) + \ + mca_coll_ucx_component.request_size); + +int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { + COLL_UCX_ERROR("UCX component does not support discontinuous datatype. 
Please use other coll component."); + return OMPI_ERR_NOT_SUPPORTED; + } + COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + ptrdiff_t dtype_size; + ucg_coll_h coll = NULL; + ompi_datatype_type_extent(dtype, &dtype_size); + ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, + op, 0, 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx allreduce start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx allreduce", (void)0); + COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce END"); + + return OMPI_SUCCESS; +} + +int mca_coll_ucx_iallreduce(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("iallreduce request *%p", sbuf, rbuf, count, dtype, comm, + (void*)request); + + ptrdiff_t dtype_size; + ucg_coll_h coll = NULL; + ompi_datatype_type_extent(dtype, &dtype_size); + ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, + mca_coll_ucx_coll_completion, op, 0, 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ompi_request_t *req = (ompi_request_t*)ucg_collective_start_nb(coll); + if 
(OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx allreduce start failed: %s", + ucs_status_string(UCS_PTR_STATUS(req))); + return OMPI_ERROR; + } + + if (req == NULL) { + COLL_UCX_VERBOSE(8, "returning completed request"); + *request = &mca_coll_ucx_component.completed_send_req; + return OMPI_SUCCESS; + } + + COLL_UCX_VERBOSE(8, "got request %p", (void*)req); + *request = req; + return OMPI_SUCCESS; +} + +int mca_coll_ucx_allreduce_init(const void *sbuf, void *rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_info_t *info, + struct ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + mca_coll_ucx_persistent_op_t *req = + (mca_coll_ucx_persistent_op_t *) + COLL_UCX_FREELIST_GET(&mca_coll_ucx_component.persistent_ops); + if (OPAL_UNLIKELY(req == NULL)) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + + COLL_UCX_TRACE("iallreduce_init request *%p=%p", + sbuf, rbuf, count, dtype, comm, (void*)request, (void*)req); + + ptrdiff_t dtype_size; + ompi_datatype_type_extent(dtype, &dtype_size); + ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, + mca_coll_ucx_pcoll_completion, op, 0, + UCG_GROUP_COLLECTIVE_MODIFIER_PERSISTENT, &req->coll_desc); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx allreduce failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + req->ompi.req_state = OMPI_REQUEST_INACTIVE; + *request = &req->ompi; + return OMPI_SUCCESS; +} + +int mca_coll_ucx_reduce(const void *sbuf, void* rbuf, int count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce"); + + ucs_status_ptr_t req = 
COLL_UCX_REQ_ALLOCA(ucx_module); + + ptrdiff_t dtype_size; + ucg_coll_h coll = NULL; + ompi_datatype_type_extent(dtype, &dtype_size); + ucs_status_t ret = ucg_coll_reduce_init(sbuf, rbuf, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, + op, root, 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx reduce init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx reduce start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx reduce", (void)0); +} + +int mca_coll_ucx_scatter(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, struct ompi_datatype_t *rdtype, + int root, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("%s", sbuf, rbuf, scount, sdtype, comm, "scatter"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + + ucg_coll_h coll = NULL; + ptrdiff_t sdtype_size, rdtype_size; + ompi_datatype_type_extent(sdtype, &sdtype_size); + ompi_datatype_type_extent(rdtype, &rdtype_size); + ucs_status_t ret = ucg_coll_scatter_init(sbuf, scount, (size_t)sdtype_size, sdtype, + rbuf, rcount, (size_t)rdtype_size, rdtype, + ucx_module->ucg_group, 0, 0, root, + 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx reduce init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx reduce start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return 
OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx scatter", (void)0); +} + +int mca_coll_ucx_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, int root, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("%s", sbuf, rbuf, scount, sdtype, comm, "gather"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + + ucg_coll_h coll = NULL; + ptrdiff_t sdtype_size, rdtype_size; + ompi_datatype_type_extent(sdtype, &sdtype_size); + ompi_datatype_type_extent(rdtype, &rdtype_size); + ucs_status_t ret = ucg_coll_gather_init(sbuf, scount, (size_t)sdtype_size, sdtype, + rbuf, rcount, (size_t)rdtype_size, rdtype, + ucx_module->ucg_group, 0, 0, root, + 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx reduce init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx reduce start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx gather", (void)0); +} + +int mca_coll_ucx_allgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("%s", sbuf, rbuf, scount, sdtype, comm, "allgather"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + + ucg_coll_h coll = NULL; + ptrdiff_t 
sdtype_size, rdtype_size; + ompi_datatype_type_extent(sdtype, &sdtype_size); + ompi_datatype_type_extent(rdtype, &rdtype_size); + ucs_status_t ret = ucg_coll_allgather_init(sbuf, scount, (size_t)sdtype_size, sdtype, + rbuf, rcount, (size_t)rdtype_size, rdtype, + ucx_module->ucg_group, 0, 0, 0, + 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx allgather init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx allgather start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx allgather", (void)0); +} + +int mca_coll_ucx_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + COLL_UCX_TRACE("%s", sbuf, rbuf, scount, sdtype, comm, "alltoall"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + + ucg_coll_h coll = NULL; + ptrdiff_t sdtype_size, rdtype_size; + ompi_datatype_type_extent(sdtype, &sdtype_size); + ompi_datatype_type_extent(rdtype, &rdtype_size); + ucs_status_t ret = ucg_coll_alltoall_init(sbuf, scount, (size_t)sdtype_size, sdtype, + rbuf, rcount, (size_t)rdtype_size, rdtype, + ucx_module->ucg_group, 0, 0, 0, + 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx alltoall init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx alltoall start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if 
(ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx alltoall", (void)0); +} + +int mca_coll_ucx_barrier(struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + + ucg_coll_h coll = NULL; + ucs_status_t ret = ucg_coll_barrier_init(0, ucx_module->ucg_group, 0, 0, 0, 0, &coll); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx barrier init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx barrier start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx barrier", (void)0); +} + +int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + + if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { + COLL_UCX_ERROR("UCX component does not support discontinuous datatype. 
Please use other coll component."); + return OMPI_ERR_NOT_SUPPORTED; + } + COLL_UCX_TRACE("%s", buff, buff, count, dtype, comm, "bcast"); + + ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); + ptrdiff_t dtype_size; + ucg_coll_h coll = NULL; + ompi_datatype_type_extent(dtype, &dtype_size); + ucs_status_t ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, + 0, root, 0, &coll); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx bcast init failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + ret = ucg_collective_start_nbr(coll, req); + if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { + COLL_UCX_ERROR("ucx bcast start failed: %s", ucs_status_string(ret)); + return OMPI_ERROR; + } + + if (ucs_unlikely(ret == UCS_OK)) { + return OMPI_SUCCESS; + } + + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx bcast", (void)0); +} diff --git a/ompi/mca/coll/ucx/coll_ucx_request.c b/ompi/mca/coll/ucx/coll_ucx_request.c new file mode 100644 index 00000000000..1c5161eee7e --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_request.c @@ -0,0 +1,177 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED. + * Copyright (c) 2016 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include +#include "coll_ucx_request.h" +#include "ompi/message/message.h" + + +static int mca_coll_ucx_request_free(ompi_request_t **rptr) +{ + ompi_request_t *req = *rptr; + + COLL_UCX_VERBOSE(9, "free request *%p=%p", (void*)rptr, (void*)req); + + *rptr = MPI_REQUEST_NULL; + mca_coll_ucx_request_reset(req); + ucg_request_free(req); + return OMPI_SUCCESS; +} + +static int mca_coll_ucx_request_cancel(ompi_request_t *req, int flag) +{ + ucg_request_cancel(mca_coll_ucx_component.ucg_worker, req); + return OMPI_SUCCESS; +} + +void mca_coll_ucx_coll_completion(void *request, ucs_status_t status) +{ + ompi_request_t *req = request; + + COLL_UCX_VERBOSE(8, "send request %p completed with status %s", (void*)req, ucs_status_string(status)); + + mca_coll_ucx_set_coll_status(&req->req_status, status); + COLL_UCX_ASSERT(!(REQUEST_COMPLETE(req))); + ompi_request_complete(req, true); +} + +static void mca_coll_ucx_persistent_op_detach(mca_coll_ucx_persistent_op_t *preq, ompi_request_t *tmp_req) +{ + tmp_req->req_complete_cb_data = NULL; + preq->tmp_req = NULL; +} + +inline void mca_coll_ucx_persistent_op_complete(mca_coll_ucx_persistent_op_t *preq, ompi_request_t *tmp_req) +{ + preq->ompi.req_status = tmp_req->req_status; + mca_coll_ucx_request_reset(tmp_req); + mca_coll_ucx_persistent_op_detach(preq, tmp_req); + ucg_request_free(tmp_req); + ompi_request_complete(&preq->ompi, true); +} + +static inline void mca_coll_ucx_preq_completion(ompi_request_t *tmp_req) +{ + mca_coll_ucx_persistent_op_t *preq = NULL; + + ompi_request_complete(tmp_req, false); + preq = (mca_coll_ucx_persistent_op_t*)tmp_req->req_complete_cb_data; + if (preq != NULL) { + COLL_UCX_ASSERT(preq->tmp_req != NULL); + mca_coll_ucx_persistent_op_complete(preq, tmp_req); + } +} + +void mca_coll_ucx_pcoll_completion(void *request, ucs_status_t status) +{ + ompi_request_t *tmp_req = request; + + COLL_UCX_VERBOSE(8, "persistent 
collective request %p completed with status %s", + (void*)tmp_req, ucs_status_string(status)); + + mca_coll_ucx_set_coll_status(&tmp_req->req_status, status); + mca_coll_ucx_preq_completion(tmp_req); +} + +static void mca_coll_ucx_request_init_common(ompi_request_t* ompi_req, bool op_persistent, ompi_request_state_t state, + ompi_request_free_fn_t req_free, ompi_request_cancel_fn_t req_cancel) +{ + OMPI_REQUEST_INIT(ompi_req, op_persistent); + ompi_req->req_type = OMPI_REQUEST_COLL; + ompi_req->req_state = state; + ompi_req->req_start = mca_coll_ucx_start; + ompi_req->req_free = req_free; + ompi_req->req_cancel = req_cancel; + /* This field is used to attach persistant request to a temporary req. + * Receive (ucg_tag_recv_nb) may call completion callback + * before the field is set. If the field is not NULL then mca_coll_ucx_preq_completion() + * will try to complete bogus persistant request. + */ + ompi_req->req_complete_cb_data = NULL; +} + +void mca_coll_ucx_request_init(void *request) +{ + ompi_request_t* ompi_req = request; + OBJ_CONSTRUCT(ompi_req, ompi_request_t); + mca_coll_ucx_request_init_common(ompi_req, false, OMPI_REQUEST_ACTIVE, mca_coll_ucx_request_free, + mca_coll_ucx_request_cancel); +} + +void mca_coll_ucx_request_cleanup(void *request) +{ + ompi_request_t* ompi_req = request; + ompi_req->req_state = OMPI_REQUEST_INVALID; + OMPI_REQUEST_FINI(ompi_req); + OBJ_DESTRUCT(ompi_req); +} + +static int mca_coll_ucx_persistent_op_free(ompi_request_t **rptr) +{ + mca_coll_ucx_persistent_op_t* preq = (mca_coll_ucx_persistent_op_t*)*rptr; + ompi_request_t *tmp_req = preq->tmp_req; + + preq->ompi.req_state = OMPI_REQUEST_INVALID; + if (tmp_req != NULL) { + mca_coll_ucx_persistent_op_detach(preq, tmp_req); + ucg_request_free(tmp_req); + } + + COLL_UCX_FREELIST_RETURN(&mca_coll_ucx_component.persistent_ops, &preq->ompi.super); + *rptr = MPI_REQUEST_NULL; + return OMPI_SUCCESS; +} + +static int mca_coll_ucx_persistent_op_cancel(ompi_request_t *req, int flag) +{ + 
mca_coll_ucx_persistent_op_t* preq = (mca_coll_ucx_persistent_op_t*)req; + + if (preq->tmp_req != NULL) { + ucg_request_cancel(preq->ucg_worker, preq->tmp_req); + } + return OMPI_SUCCESS; +} + +static void mca_coll_ucx_persisternt_op_construct(mca_coll_ucx_persistent_op_t* req) +{ + mca_coll_ucx_request_init_common(&req->ompi, true, OMPI_REQUEST_INACTIVE, mca_coll_ucx_persistent_op_free, + mca_coll_ucx_persistent_op_cancel); + req->tmp_req = NULL; +} + +static void mca_coll_ucx_persisternt_op_destruct(mca_coll_ucx_persistent_op_t* req) +{ + req->ompi.req_state = OMPI_REQUEST_INVALID; + OMPI_REQUEST_FINI(&req->ompi); +} + +OBJ_CLASS_INSTANCE(mca_coll_ucx_persistent_op_t, ompi_request_t, mca_coll_ucx_persisternt_op_construct, + mca_coll_ucx_persisternt_op_destruct); + +static int mca_coll_completed_request_free(struct ompi_request_t** rptr) +{ + *rptr = MPI_REQUEST_NULL; + return OMPI_SUCCESS; +} + +static int mca_coll_completed_request_cancel(struct ompi_request_t* ompi_req, int flag) +{ + return OMPI_SUCCESS; +} + +void mca_coll_ucx_completed_request_init(ompi_request_t *ompi_req) +{ + mca_coll_ucx_request_init_common(ompi_req, false, OMPI_REQUEST_ACTIVE, mca_coll_completed_request_free, + mca_coll_completed_request_cancel); + ompi_request_complete(ompi_req, false); +} diff --git a/ompi/mca/coll/ucx/coll_ucx_request.h b/ompi/mca/coll/ucx/coll_ucx_request.h new file mode 100644 index 00000000000..d988419746e --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_request.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2015. ALL RIGHTS RESERVED. + * Copyright (c) 2016 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef COLL_UCX_REQUEST_H_ +#define COLL_UCX_REQUEST_H_ + +#include "coll_ucx.h" + +enum { + MCA_PML_UCX_REQUEST_FLAG_SEND = (1 << 0), /* Persistent send */ + MCA_PML_UCX_REQUEST_FLAG_FREE_CALLED = (1 << 1), + MCA_PML_UCX_REQUEST_FLAG_COMPLETED = (1 << 2) +}; + +struct coll_ucx_persistent_op { + ompi_request_t ompi; + ompi_request_t *tmp_req; + ucg_coll_h coll_desc; + ucg_worker_h ucg_worker; + unsigned flags; +}; + + +void mca_coll_ucx_coll_completion(void *request, ucs_status_t status); + +void mca_coll_ucx_pcoll_completion(void *request, ucs_status_t status); + +void mca_coll_ucx_persistent_op_complete(mca_coll_ucx_persistent_op_t *preq, ompi_request_t *tmp_req); + +void mca_coll_ucx_completed_request_init(ompi_request_t *ompi_req); + +void mca_coll_ucx_request_init(void *request); + +void mca_coll_ucx_request_cleanup(void *request); + + +static inline void mca_coll_ucx_request_reset(ompi_request_t *req) +{ + req->req_complete = REQUEST_PENDING; +} + +static inline void mca_coll_ucx_set_coll_status(ompi_status_public_t* mpi_status, + ucs_status_t status) +{ + if (OPAL_LIKELY(status == UCS_OK)) { + mpi_status->MPI_ERROR = MPI_SUCCESS; + mpi_status->_cancelled = false; + } else if (status == UCS_ERR_CANCELED) { + mpi_status->_cancelled = true; + } else { + mpi_status->MPI_ERROR = MPI_ERR_INTERN; + } +} + +OBJ_CLASS_DECLARATION(mca_coll_ucx_persistent_op_t); + + +#endif /* COLL_UCX_REQUEST_H_ */ diff --git a/ompi/mca/coll/ucx/configure.m4 b/ompi/mca/coll/ucx/configure.m4 new file mode 100644 index 00000000000..6e71a579bd9 --- /dev/null +++ b/ompi/mca/coll/ucx/configure.m4 @@ -0,0 +1,37 @@ +# -*- shell-script -*- +# +# +# Copyright (c) 2011 Mellanox Technologies. All rights reserved. +# Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+# $COPYRIGHT$ +# +# Additional copyrights may follow +# +# $HEADER$ +# + +AC_DEFUN([MCA_ompi_coll_ucx_POST_CONFIG], [ + AS_IF([test "$1" = "1"], [OMPI_REQUIRE_ENDPOINT_TAG([COLL])]) +]) + +# MCA_coll_ucx_CONFIG([action-if-can-compile], +# [action-if-cant-compile]) +# ------------------------------------------------ +AC_DEFUN([MCA_ompi_coll_ucx_CONFIG],[ + AC_CONFIG_FILES([ompi/mca/coll/ucx/Makefile]) + + OMPI_CHECK_UCX([coll_ucx], + [coll_ucx_happy="yes"], + [coll_ucx_happy="no"]) + + AS_IF([test "$coll_ucx_happy" = "yes"], + [$1], + [$2]) + + # substitute in the things needed to build ucx + AC_SUBST([coll_ucx_CFLAGS]) + AC_SUBST([coll_ucx_CPPFLAGS]) + AC_SUBST([coll_ucx_LDFLAGS]) + AC_SUBST([coll_ucx_LIBS]) +])dnl + diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index fb7b7f84615..5cf18197438 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -5,6 +5,7 @@ * reserved. * Copyright (c) 2018 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -668,7 +669,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat } static ucs_status_ptr_t -mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count, +mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count, ompi_datatype_t *datatype, uint64_t pml_tag) { ompi_request_t *req; @@ -797,7 +798,8 @@ mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count, return OMPI_SUCCESS; } else if (!UCS_PTR_IS_ERR(req)) { PML_UCX_VERBOSE(8, "got request %p", (void*)req); - MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", ompi_request_free(&req)); + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCP, ompi_pml_ucx.ucp_worker, "ucx send", + ompi_request_free(&req)); } else { PML_UCX_ERROR("ucx send failed: %s", ucs_status_string(UCS_PTR_STATUS(req))); return OMPI_ERROR; @@ -820,7 +822,7 @@ mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count, return OMPI_SUCCESS; } - MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", (void)0); + MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCP, ompi_pml_ucx.ucp_worker, "ucx send", (void)0); } #endif diff --git a/opal/mca/common/ucx/common_ucx.c b/opal/mca/common/ucx/common_ucx.c index bf5d6c04943..e3ffcce30a2 100644 --- a/opal/mca/common/ucx/common_ucx.c +++ b/opal/mca/common/ucx/common_ucx.c @@ -1,5 +1,6 @@ /* * Copyright (C) Mellanox Technologies Ltd. 2018. ALL RIGHTS RESERVED. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -163,8 +164,7 @@ OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence(ucp_worker_h worker) volatile int fenced = 0; int ret = OPAL_SUCCESS; - if (OPAL_SUCCESS != (ret = opal_pmix.fence_nb(NULL, 0, - opal_common_ucx_mca_fence_complete_cb, (void*)&fenced))){ + if (OPAL_SUCCESS != (ret = opal_pmix.fence_nb(NULL, 0, opal_common_ucx_mca_fence_complete_cb, (void*)&fenced))) { return ret; } @@ -175,13 +175,14 @@ OPAL_DECLSPEC int opal_common_ucx_mca_pmix_fence(ucp_worker_h worker) return ret; } -static void opal_common_ucx_wait_all_requests(void **reqs, int count, ucp_worker_h worker) +static void opal_common_ucx_wait_all_requests(void **reqs, int count, ucp_worker_h worker, + enum opal_common_ucx_req_type type) { int i; MCA_COMMON_UCX_VERBOSE(2, "waiting for %d disconnect requests", count); for (i = 0; i < count; ++i) { - opal_common_ucx_wait_request(reqs[i], worker, "ucp_disconnect_nb"); + opal_common_ucx_wait_request(reqs[i], worker, type, "ucp_disconnect_nb"); reqs[i] = NULL; } } @@ -225,7 +226,7 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t * } else { dreqs[num_reqs++] = dreq; if (num_reqs >= max_disconnect) { - opal_common_ucx_wait_all_requests(dreqs, num_reqs, worker); + opal_common_ucx_wait_all_requests(dreqs, num_reqs, worker, OPAL_COMMON_UCX_REQUEST_TYPE_UCP); num_reqs = 0; } } @@ -234,7 +235,7 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t * /* num_reqs == 0 is processed by opal_common_ucx_wait_all_requests routine, * so suppress coverity warning */ /* coverity[uninit_use_in_call] */ - opal_common_ucx_wait_all_requests(dreqs, num_reqs, worker); + opal_common_ucx_wait_all_requests(dreqs, num_reqs, worker, OPAL_COMMON_UCX_REQUEST_TYPE_UCP); free(dreqs); return OPAL_SUCCESS; diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index 202131ac890..6bd53c973ac 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ 
b/opal/mca/common/ucx/common_ucx.h
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) 2018 Mellanox Technologies. All rights reserved.
  * All rights reserved.
+ * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -16,6 +17,9 @@
 #include <stdint.h>
 
 #include <ucp/api/ucp.h>
+#if HAVE_UCG_API_UCG_H
+#include <ucg/api/ucg.h>
+#endif
 
 #include "opal/mca/mca.h"
 #include "opal/util/output.h"
@@ -56,17 +60,25 @@ BEGIN_C_DECLS
 }
 
 /* progress loop to allow call UCX/opal progress */
+enum opal_common_ucx_req_type {
+    OPAL_COMMON_UCX_REQUEST_TYPE_UCP = 0,
+#if HAVE_UCG_API_UCG_H
+    OPAL_COMMON_UCX_REQUEST_TYPE_UCG = 1
+#endif
+};
+
+/* progress loop to allow call UCX/opal progress, while testing requests by type */
 /* used C99 for-statement variable initialization */
 #define MCA_COMMON_UCX_PROGRESS_LOOP(_worker) \
     for (unsigned iter = 0;; (++iter % opal_common_ucx.progress_iterations) ? \
         (void)ucp_worker_progress(_worker) : opal_progress())
 
-#define MCA_COMMON_UCX_WAIT_LOOP(_request, _worker, _msg, _completed) \
+#define MCA_COMMON_UCX_WAIT_LOOP(_request, _req_type, _worker, _msg, _completed) \
     do { \
         ucs_status_t status; \
         /* call UCX progress */ \
         MCA_COMMON_UCX_PROGRESS_LOOP(_worker) { \
-            status = opal_common_ucx_request_status(_request); \
+            status = opal_common_ucx_request_status(_request, _req_type); \
             if (UCS_INPROGRESS != status) { \
                 _completed; \
                 if (OPAL_LIKELY(UCS_OK == status)) { \
@@ -110,20 +122,31 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *
 OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);
 
 static inline
-ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request)
+ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request,
+                                            enum opal_common_ucx_req_type type)
 {
+    switch (type) {
+    case OPAL_COMMON_UCX_REQUEST_TYPE_UCP:
 #if !HAVE_DECL_UCP_REQUEST_CHECK_STATUS
-    ucp_tag_recv_info_t info;
+        ucp_tag_recv_info_t info;
 
-    return ucp_request_test(request,
&info); + return ucp_request_test(request, &info); #else - return ucp_request_check_status(request); + return ucp_request_check_status(request); #endif +#if HAVE_UCG_API_UCG_H + case OPAL_COMMON_UCX_REQUEST_TYPE_UCG: + return ucg_request_check_status(request); +#endif + default: + break; + } + return OPAL_ERROR; } static inline int opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker, - const char *msg) + enum opal_common_ucx_req_type type, const char *msg) { /* check for request completed or failed */ if (OPAL_LIKELY(UCS_OK == request)) { @@ -135,7 +158,7 @@ int opal_common_ucx_wait_request(ucs_status_ptr_t request, ucp_worker_h worker, return OPAL_ERROR; } - MCA_COMMON_UCX_WAIT_LOOP(request, worker, msg, ucp_request_free(request)); + MCA_COMMON_UCX_WAIT_LOOP(request, type, worker, msg, ucp_request_free(request)); } static inline @@ -145,7 +168,7 @@ int opal_common_ucx_ep_flush(ucp_ep_h ep, ucp_worker_h worker) ucs_status_ptr_t request; request = ucp_ep_flush_nb(ep, 0, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(request, worker, "ucp_ep_flush_nb"); + return opal_common_ucx_wait_request(request, worker, OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_ep_flush_nb"); #else ucs_status_t status; @@ -161,7 +184,7 @@ int opal_common_ucx_worker_flush(ucp_worker_h worker) ucs_status_ptr_t request; request = ucp_worker_flush_nb(worker, 0, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(request, worker, "ucp_worker_flush_nb"); + return opal_common_ucx_wait_request(request, worker, OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_worker_flush_nb"); #else ucs_status_t status; @@ -180,7 +203,7 @@ int opal_common_ucx_atomic_fetch(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, request = ucp_atomic_fetch_nb(ep, opcode, value, result, op_size, remote_addr, rkey, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(request, worker, "ucp_atomic_fetch_nb"); + return opal_common_ucx_wait_request(request, 
worker, OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_atomic_fetch_nb"); } static inline diff --git a/orte/mca/rmaps/base/rmaps_base_ranking.c b/orte/mca/rmaps/base/rmaps_base_ranking.c index e4f67d9f4d5..a185b075d78 100644 --- a/orte/mca/rmaps/base/rmaps_base_ranking.c +++ b/orte/mca/rmaps/base/rmaps_base_ranking.c @@ -13,6 +13,7 @@ * Copyright (c) 2014-2018 Intel, Inc. All rights reserved. * Copyright (c) 2017 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -84,7 +85,6 @@ static int rank_span(orte_job_t *jdata, * just loop across the nodes and objects until all procs * are mapped */ - vpid = 0; for (n=0; n < jdata->apps->size; n++) { if (NULL == (app = (orte_app_context_t*)opal_pointer_array_get_item(jdata->apps, n))) { @@ -142,7 +142,8 @@ static int rank_span(orte_job_t *jdata, return ORTE_ERROR; } /* ignore procs not on this object */ - if (!hwloc_bitmap_intersects(obj->cpuset, locale->cpuset)) { + if (NULL == locale || + !hwloc_bitmap_intersects(obj->cpuset, locale->cpuset)) { opal_output_verbose(5, orte_rmaps_base_framework.framework_output, "mca:rmaps:rank_span: proc at position %d is not on object %d", j, i); @@ -175,6 +176,11 @@ static int rank_span(orte_job_t *jdata, } } } + + /* Are all the procs ranked? 
we don't want to crash on INVALID ranks */ + if (cnt < app->num_procs) { + return ORTE_ERR_NOT_SUPPORTED; + } } return ORTE_SUCCESS; @@ -206,7 +212,6 @@ static int rank_fill(orte_job_t *jdata, * 0 1 4 5 8 9 12 13 * 2 3 6 7 10 11 14 15 */ - vpid = 0; for (n=0; n < jdata->apps->size; n++) { if (NULL == (app = (orte_app_context_t*)opal_pointer_array_get_item(jdata->apps, n))) { @@ -263,7 +268,8 @@ static int rank_fill(orte_job_t *jdata, return ORTE_ERROR; } /* ignore procs not on this object */ - if (!hwloc_bitmap_intersects(obj->cpuset, locale->cpuset)) { + if (NULL == locale || + !hwloc_bitmap_intersects(obj->cpuset, locale->cpuset)) { opal_output_verbose(5, orte_rmaps_base_framework.framework_output, "mca:rmaps:rank_fill: proc at position %d is not on object %d", j, i); @@ -293,6 +299,11 @@ static int rank_fill(orte_job_t *jdata, } } } + + /* Are all the procs ranked? we don't want to crash on INVALID ranks */ + if (cnt < app->num_procs) { + return ORTE_ERR_NOT_SUPPORTED; + } } return ORTE_SUCCESS; @@ -331,7 +342,6 @@ static int rank_by(orte_job_t *jdata, * 0 2 1 3 8 10 9 11 * 4 6 5 7 12 14 13 15 */ - vpid = 0; for (n=0, napp=0; napp < jdata->num_apps && n < jdata->apps->size; n++) { if (NULL == (app = (orte_app_context_t*)opal_pointer_array_get_item(jdata->apps, n))) { @@ -378,7 +388,8 @@ static int rank_by(orte_job_t *jdata, * algorithm, but this works for now. 
*/ i = 0; - while (cnt < app->num_procs && i < (int)node->num_procs) { + while (cnt < app->num_procs && + ((i < (int)node->num_procs) || (i < num_objs))) { /* get the next object */ obj = (hwloc_obj_t)opal_pointer_array_get_item(&objs, i % num_objs); if (NULL == obj) { @@ -423,7 +434,7 @@ static int rank_by(orte_job_t *jdata, !hwloc_bitmap_intersects(obj->cpuset, locale->cpuset)) { opal_output_verbose(5, orte_rmaps_base_framework.framework_output, "mca:rmaps:rank_by: proc at position %d is not on object %d", - j, i); + j, i % num_objs); continue; } /* assign the vpid */ @@ -458,6 +469,11 @@ static int rank_by(orte_job_t *jdata, } /* cleanup */ OBJ_DESTRUCT(&objs); + + /* Are all the procs ranked? we don't want to crash on INVALID ranks */ + if (cnt < app->num_procs) { + return ORTE_ERR_NOT_SUPPORTED; + } } return ORTE_SUCCESS; } diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index b3cddcd6d2b..b81c0d59d9c 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -1,6 +1,7 @@ /* * Copyright (c) 2013 Mellanox Technologies, Inc. * All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -51,5 +52,6 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx, } return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], + OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_atomic_fetch_nb"); } diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 34ed0b551b9..d6470180c6e 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -1,6 +1,7 @@ /* * Copyright (c) 2015 Mellanox Technologies, Inc. * All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
* $COPYRIGHT$ * * Additional copyrights may follow @@ -81,6 +82,7 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx, rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], + OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_atomic_fetch_nb"); } diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 60453e92438..367d11f7f77 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -4,6 +4,7 @@ * Copyright (c) 2014-2018 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2016 ARM, Inc. All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -825,7 +826,7 @@ int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_add #if HAVE_DECL_UCP_GET_NB request = ucp_get_nb(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_get_nb"); + return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_get_nb"); #else status = ucp_get(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -887,7 +888,7 @@ int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_add #if HAVE_DECL_UCP_PUT_NB request = ucp_put_nb(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); - res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_put_nb"); + res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], OPAL_COMMON_UCX_REQUEST_TYPE_UCP, "ucp_put_nb"); #else status = ucp_put(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); diff --git 
a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index a069bf5cd2e..c7812a89357 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -1,6 +1,7 @@ /* * Copyright (c) 2017 Mellanox Technologies, Inc. * All rights reserved. + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -226,7 +227,7 @@ segment_create(map_segment_t *ds_buf, int ret; #if HAVE_UCX_DEVICE_MEM - int ret = OSHMEM_ERROR; + ret = OSHMEM_ERROR; if (hint & SHMEM_HINT_DEVICE_NIC_MEM) { if (size > UINT_MAX) { return OSHMEM_ERR_BAD_PARAM; From affeec0ec20b361096d62c395d57c2cb66649c4e Mon Sep 17 00:00:00 2001 From: Alex Margolin Date: Sun, 1 Nov 2020 16:19:30 +0200 Subject: [PATCH 02/20] UCG: fix configure to include UCG Signed-off-by: Alex Margolin --- config/ompi_check_ucx.m4 | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index 52f80f0c083..1081949d26a 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -46,7 +46,6 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [true])]) ompi_check_ucx_happy="no" ompi_check_ucg_happy="no" - ucx_libs="-luct -lucm -lucs" AS_IF([test -z "$ompi_check_ucx_dir"], [OPAL_CHECK_PACKAGE([ompi_check_ucx], [ucp/api/ucp.h], @@ -60,8 +59,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ OPAL_CHECK_PACKAGE([ompi_check_ucg], [ucg/api/ucg.h], [ucg], - [ucg_request_check_status], - [-lucg -lucp $ucx_libs], + [ucg_collective_destroy], + [-lucp -luct -lucm -lucs], [], [], [ompi_check_ucg_happy="yes"], @@ -101,8 +100,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ OPAL_CHECK_PACKAGE([ompi_check_ucg], [ucg/api/ucg.h], [ucg], - [ucg_request_check_status], - [-lucg -lucp $ucx_libs], + [ucg_collective_destroy], + [-lucp -luct -lucm -lucs], [$ompi_check_ucx_dir], [$ompi_check_ucx_libdir], [ompi_check_ucg_happy="yes"], From 4980a36522520de156c4215059dd6366280aede8 Mon Sep 17 00:00:00 
2001 From: Alex Margolin Date: Sun, 1 Nov 2020 23:49:57 +0200 Subject: [PATCH 03/20] UCG: support for non-contiguous datatypes Signed-off-by: Alex Margolin --- ompi/mca/coll/ucx/Makefile.am | 2 + ompi/mca/coll/ucx/coll_ucx.h | 9 +- ompi/mca/coll/ucx/coll_ucx_component.c | 29 ++- ompi/mca/coll/ucx/coll_ucx_datatype.c | 271 +++++++++++++++++++++++++ ompi/mca/coll/ucx/coll_ucx_datatype.h | 84 ++++++++ ompi/mca/coll/ucx/coll_ucx_module.c | 36 +++- ompi/mca/coll/ucx/coll_ucx_op.c | 13 -- 7 files changed, 416 insertions(+), 28 deletions(-) create mode 100644 ompi/mca/coll/ucx/coll_ucx_datatype.c create mode 100644 ompi/mca/coll/ucx/coll_ucx_datatype.h diff --git a/ompi/mca/coll/ucx/Makefile.am b/ompi/mca/coll/ucx/Makefile.am index e82d63f82b8..60bee91e955 100644 --- a/ompi/mca/coll/ucx/Makefile.am +++ b/ompi/mca/coll/ucx/Makefile.am @@ -18,10 +18,12 @@ AM_CPPFLAGS = $(coll_ucx_CPPFLAGS) -DCOLL_UCX_HOME=\"$(coll_ucx_HOME)\" $(coll_u coll_ucx_sources = \ coll_ucx.h \ coll_ucx_request.h \ + coll_ucx_datatype.h \ coll_ucx_freelist.h \ coll_ucx_op.c \ coll_ucx_module.c \ coll_ucx_request.c \ + coll_ucx_datatype.c \ coll_ucx_component.c # Make the output library in this directory, and name it either diff --git a/ompi/mca/coll/ucx/coll_ucx.h b/ompi/mca/coll/ucx/coll_ucx.h index 5bef95720cd..9e3bf616651 100644 --- a/ompi/mca/coll/ucx/coll_ucx.h +++ b/ompi/mca/coll/ucx/coll_ucx.h @@ -26,9 +26,9 @@ #include "ompi/communicator/communicator.h" #include "ompi/datatype/ompi_datatype.h" #include "ompi/attribute/attribute.h" -#include "ompi/op/op.h" #include "orte/runtime/orte_globals.h" +#include "ompi/datatype/ompi_datatype_internal.h" #include "opal/mca/common/ucx/common_ucx.h" #include "ucg/api/ucg_mpi.h" @@ -71,6 +71,13 @@ typedef struct mca_coll_ucx_component { mca_coll_ucx_freelist_t persistent_ops; ompi_request_t completed_send_req; size_t request_size; + + /* Datatypes */ + int datatype_attr_keyval; + ucp_datatype_t predefined_types[OMPI_DATATYPE_MPI_MAX_PREDEFINED]; + + /* 
Converters pool */ + mca_coll_ucx_freelist_t convs; } mca_coll_ucx_component_t; OMPI_MODULE_DECLSPEC extern mca_coll_ucx_component_t mca_coll_ucx_component; diff --git a/ompi/mca/coll/ucx/coll_ucx_component.c b/ompi/mca/coll/ucx/coll_ucx_component.c index 4be598320b5..ee73c8fee5f 100644 --- a/ompi/mca/coll/ucx/coll_ucx_component.c +++ b/ompi/mca/coll/ucx/coll_ucx_component.c @@ -21,6 +21,7 @@ #include "coll_ucx.h" #include "coll_ucx_request.h" +#include "coll_ucx_datatype.h" /* @@ -266,6 +267,12 @@ int mca_coll_ucx_open(void) goto out; } + int i; + mca_coll_ucx_component.datatype_attr_keyval = MPI_KEYVAL_INVALID; + for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { + mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID; + } + ucs_list_head_init(&mca_coll_ucx_component.group_head); return OMPI_SUCCESS; @@ -279,6 +286,14 @@ int mca_coll_ucx_close(void) { COLL_UCX_VERBOSE(1, "mca_coll_ucx_close"); + int i; + for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { + if (mca_coll_ucx_component.predefined_types[i] != COLL_UCX_DATATYPE_INVALID) { + ucp_dt_destroy(mca_coll_ucx_component.predefined_types[i]); + mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID; + } + } + if (mca_coll_ucx_component.ucg_worker != NULL) { mca_coll_ucx_cleanup(); mca_coll_ucx_component.ucg_worker = NULL; @@ -355,11 +370,10 @@ int mca_coll_ucx_init(void) } /* Initialize the free lists */ - OBJ_CONSTRUCT(&mca_coll_ucx_component.persistent_ops, mca_coll_ucx_freelist_t); - - /* Create a completed request to be returned from isend */ - OBJ_CONSTRUCT(&mca_coll_ucx_component.completed_send_req, ompi_request_t); - mca_coll_ucx_completed_request_init(&mca_coll_ucx_component.completed_send_req); + OBJ_CONSTRUCT(&mca_coll_ucx_component.convs, mca_coll_ucx_freelist_t); + COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.convs, + mca_coll_ucx_convertor_t, + 128, -1, 128); rc = opal_progress_register(mca_coll_ucx_progress); if (OPAL_SUCCESS != rc) { @@ -384,10 +398,7 @@ 
void mca_coll_ucx_cleanup(void) opal_progress_unregister(mca_coll_ucx_progress); - mca_coll_ucx_component.completed_send_req.req_state = OMPI_REQUEST_INVALID; - OMPI_REQUEST_FINI(&mca_coll_ucx_component.completed_send_req); - OBJ_DESTRUCT(&mca_coll_ucx_component.completed_send_req); - OBJ_DESTRUCT(&mca_coll_ucx_component.persistent_ops); + OBJ_DESTRUCT(&mca_coll_ucx_component.convs); if (mca_coll_ucx_component.ucg_worker) { ucg_worker_destroy(mca_coll_ucx_component.ucg_worker); diff --git a/ompi/mca/coll/ucx/coll_ucx_datatype.c b/ompi/mca/coll/ucx/coll_ucx_datatype.c new file mode 100644 index 00000000000..05eb985cf93 --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_datatype.c @@ -0,0 +1,271 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED. + * Copyright (c) 2019 Research Organization for Information Science + * and Technology (RIST). All rights reserved. + * Copyright (c) 2020 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * + * Copyright (c) 2020 Huawei Technologies Co., Ltd. All rights reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "coll_ucx_datatype.h"
+#include "coll_ucx_request.h"
+
+#include "ompi/runtime/mpiruntime.h"
+#include "ompi/attribute/attribute.h"
+
+#include <inttypes.h>
+#include <math.h>
+
+static void* coll_ucx_generic_datatype_start_pack(void *context, const void *buffer,
+                                                  size_t count)
+{
+    ompi_datatype_t *datatype = context;
+    mca_coll_ucx_convertor_t *convertor;
+
+    convertor = (mca_coll_ucx_convertor_t *)COLL_UCX_FREELIST_GET(&mca_coll_ucx_component.convs);
+
+    OMPI_DATATYPE_RETAIN(datatype);
+    convertor->datatype = datatype;
+    opal_convertor_copy_and_prepare_for_send(ompi_proc_local_proc->super.proc_convertor,
+                                             &datatype->super, count, buffer, 0,
+                                             &convertor->opal_conv);
+    return convertor;
+}
+
+static void* coll_ucx_generic_datatype_start_unpack(void *context, void *buffer,
+                                                    size_t count)
+{
+    ompi_datatype_t *datatype = context;
+    mca_coll_ucx_convertor_t *convertor;
+
+    convertor = (mca_coll_ucx_convertor_t *)COLL_UCX_FREELIST_GET(&mca_coll_ucx_component.convs);
+
+    OMPI_DATATYPE_RETAIN(datatype);
+    convertor->datatype = datatype;
+    convertor->offset = 0;
+    opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
+                                             &datatype->super, count, buffer, 0,
+                                             &convertor->opal_conv);
+    return convertor;
+}
+
+static size_t coll_ucx_generic_datatype_packed_size(void *state)
+{
+    mca_coll_ucx_convertor_t *convertor = state;
+    size_t size;
+
+    opal_convertor_get_packed_size(&convertor->opal_conv, &size);
+    return size;
+}
+
+static size_t coll_ucx_generic_datatype_pack(void *state, size_t offset,
+                                             void *dest, size_t max_length)
+{
+    mca_coll_ucx_convertor_t *convertor = state;
+    uint32_t iov_count;
+    struct iovec iov;
+    size_t length;
+
+    iov_count = 1;
+    iov.iov_base = dest;
+    iov.iov_len = max_length;
+
+    opal_convertor_set_position(&convertor->opal_conv, &offset);
+    length = max_length;
+    opal_convertor_pack(&convertor->opal_conv, &iov, &iov_count, &length);
+    return
length; +} + +static ucs_status_t coll_ucx_generic_datatype_unpack(void *state, size_t offset, + const void *src, size_t length) +{ + mca_coll_ucx_convertor_t *convertor = state; + + uint32_t iov_count; + struct iovec iov; + opal_convertor_t conv; + + iov_count = 1; + iov.iov_base = (void*)src; + iov.iov_len = length; + + /* in case if unordered message arrived - create separate convertor to + * unpack data. */ + if (offset != convertor->offset) { + OBJ_CONSTRUCT(&conv, opal_convertor_t); + opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor, + &convertor->datatype->super, + convertor->opal_conv.count, + convertor->opal_conv.pBaseBuf, 0, + &conv); + opal_convertor_set_position(&conv, &offset); + opal_convertor_unpack(&conv, &iov, &iov_count, &length); + opal_convertor_cleanup(&conv); + OBJ_DESTRUCT(&conv); + /* permanently switch to un-ordered mode */ + convertor->offset = 0; + } else { + opal_convertor_unpack(&convertor->opal_conv, &iov, &iov_count, &length); + convertor->offset += length; + } + return UCS_OK; +} + +static void coll_ucx_generic_datatype_finish(void *state) +{ + mca_coll_ucx_convertor_t *convertor = state; + + opal_convertor_cleanup(&convertor->opal_conv); + OMPI_DATATYPE_RELEASE(convertor->datatype); + COLL_UCX_FREELIST_RETURN(&mca_coll_ucx_component.convs, &convertor->super); +} + +static ucp_generic_dt_ops_t coll_ucx_generic_datatype_ops = { + .start_pack = coll_ucx_generic_datatype_start_pack, + .start_unpack = coll_ucx_generic_datatype_start_unpack, + .packed_size = coll_ucx_generic_datatype_packed_size, + .pack = coll_ucx_generic_datatype_pack, + .unpack = coll_ucx_generic_datatype_unpack, + .finish = coll_ucx_generic_datatype_finish +}; + +int mca_coll_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval, + void *attr_val, void *extra) +{ + ucp_datatype_t ucp_datatype = (ucp_datatype_t)attr_val; + +#ifdef HAVE_UCP_REQUEST_PARAM_T + free((void*)datatype->pml_data); +#else + 
COLL_UCX_ASSERT((uint64_t)ucp_datatype == datatype->pml_data); +#endif + ucp_dt_destroy(ucp_datatype); + datatype->pml_data = COLL_UCX_DATATYPE_INVALID; + return OMPI_SUCCESS; +} + +__opal_attribute_always_inline__ +static inline int mca_coll_ucx_datatype_is_contig(ompi_datatype_t *datatype) +{ + ptrdiff_t lb; + + ompi_datatype_type_lb(datatype, &lb); + + return (datatype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) && + (datatype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) && + (lb == 0); +} + +#ifdef HAVE_UCP_REQUEST_PARAM_T +__opal_attribute_always_inline__ static inline +coll_ucx_datatype_t *mca_coll_ucx_init_nbx_datatype(ompi_datatype_t *datatype, + ucp_datatype_t ucp_datatype, + size_t size) +{ + coll_ucx_datatype_t *pml_datatype; + int is_contig_pow2; + + pml_datatype = malloc(sizeof(*pml_datatype)); + if (pml_datatype == NULL) { + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("Failed to allocate datatype structure"); + /* TODO: this error should return to the caller and invoke an error + * handler from the MPI API call. + * For now, it is fatal. 
*/ + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); + } + + pml_datatype->datatype = ucp_datatype; + + is_contig_pow2 = mca_coll_ucx_datatype_is_contig(datatype) && + (size && !(size & (size - 1))); /* is_pow2(size) */ + if (is_contig_pow2) { + pml_datatype->size_shift = (int)(log(size) / log(2.0)); /* log2(size) */ + } else { + pml_datatype->size_shift = 0; + } + + return pml_datatype; +} +#endif + +ucp_datatype_t mca_coll_ucx_init_datatype(ompi_datatype_t *datatype) +{ + size_t size = 0; /* init to suppress compiler warning */ + ucp_datatype_t ucp_datatype; + ucs_status_t status; + int ret; + + if (mca_coll_ucx_datatype_is_contig(datatype)) { + ompi_datatype_type_size(datatype, &size); + ucp_datatype = ucp_dt_make_contig(size); + goto out; + } + + status = ucp_dt_create_generic(&coll_ucx_generic_datatype_ops, + datatype, &ucp_datatype); + if (status != UCS_OK) { + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name); + /* TODO: this error should return to the caller and invoke an error + * handler from the MPI API call. + * For now, it is fatal. */ + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); + } + + /* Add custom attribute, to clean up UCX resources when OMPI datatype is + * released. + */ + if (ompi_datatype_is_predefined(datatype)) { + COLL_UCX_ASSERT(datatype->id < OMPI_DATATYPE_MAX_PREDEFINED); + mca_coll_ucx_component.predefined_types[datatype->id] = ucp_datatype; + } else { + ret = ompi_attr_set_c(TYPE_ATTR, datatype, &datatype->d_keyhash, + mca_coll_ucx_component.datatype_attr_keyval, + (void*)ucp_datatype, false); + if (ret != OMPI_SUCCESS) { + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("Failed to add UCX datatype attribute for %s (%p): %d", + datatype->name, (void*)datatype, ret); + /* TODO: this error should return to the caller and invoke an error + * handler from the MPI API call. + * For now, it is fatal. 
*/ + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); + } + } +out: + COLL_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype) + +#ifdef HAVE_UCP_REQUEST_PARAM_T + UCS_STATIC_ASSERT(sizeof(datatype->pml_data) >= sizeof(coll_ucx_datatype_t*)); + datatype->pml_data = (uint64_t)mca_coll_ucx_init_nbx_datatype(datatype, + ucp_datatype, + size); +#else + datatype->pml_data = ucp_datatype; +#endif + + return ucp_datatype; +} + +static void mca_coll_ucx_convertor_construct(mca_coll_ucx_convertor_t *convertor) +{ + OBJ_CONSTRUCT(&convertor->opal_conv, opal_convertor_t); +} + +static void mca_coll_ucx_convertor_destruct(mca_coll_ucx_convertor_t *convertor) +{ + OBJ_DESTRUCT(&convertor->opal_conv); +} + +OBJ_CLASS_INSTANCE(mca_coll_ucx_convertor_t, + opal_free_list_item_t, + mca_coll_ucx_convertor_construct, + mca_coll_ucx_convertor_destruct); diff --git a/ompi/mca/coll/ucx/coll_ucx_datatype.h b/ompi/mca/coll/ucx/coll_ucx_datatype.h new file mode 100644 index 00000000000..1966cafea25 --- /dev/null +++ b/ompi/mca/coll/ucx/coll_ucx_datatype.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) Mellanox Technologies Ltd. 2001-2011. ALL RIGHTS RESERVED. + * Copyright (C) Huawei Technologies Co., Ltd. 2020. All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#ifndef COLL_UCX_DATATYPE_H_ +#define COLL_UCX_DATATYPE_H_ + +#include "coll_ucx.h" + + +#define COLL_UCX_DATATYPE_INVALID 0 + +#ifdef HAVE_UCP_REQUEST_PARAM_T +typedef struct { + ucp_datatype_t datatype; + int size_shift; +} coll_ucx_datatype_t; +#endif + +struct coll_ucx_convertor { + opal_free_list_item_t super; + ompi_datatype_t *datatype; + opal_convertor_t opal_conv; + size_t offset; +}; + +ucp_datatype_t mca_coll_ucx_init_datatype(ompi_datatype_t *datatype); + +int mca_coll_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval, + void *attr_val, void *extra); + +OBJ_CLASS_DECLARATION(mca_coll_ucx_convertor_t); + + +__opal_attribute_always_inline__ +static inline ucp_datatype_t mca_coll_ucx_get_datatype(ompi_datatype_t *datatype) +{ +#ifdef HAVE_UCP_REQUEST_PARAM_T + coll_ucx_datatype_t *ucp_type = (coll_ucx_datatype_t*)datatype->pml_data; + + if (OPAL_LIKELY(ucp_type != COLL_UCX_DATATYPE_INVALID)) { + return ucp_type->datatype; + } +#else + ucp_datatype_t ucp_type = datatype->pml_data; + + if (OPAL_LIKELY(ucp_type != COLL_UCX_DATATYPE_INVALID)) { + return ucp_type; + } +#endif + + return mca_coll_ucx_init_datatype(datatype); +} + +#ifdef HAVE_UCP_REQUEST_PARAM_T +__opal_attribute_always_inline__ +static inline coll_ucx_datatype_t* +mca_coll_ucx_get_op_data(ompi_datatype_t *datatype) +{ + coll_ucx_datatype_t *ucp_type = (coll_ucx_datatype_t*)datatype->pml_data; + + if (OPAL_LIKELY(ucp_type != COLL_UCX_DATATYPE_INVALID)) { + return ucp_type; + } + + mca_coll_ucx_init_datatype(datatype); + return (coll_ucx_datatype_t*)datatype->pml_data; +} + +__opal_attribute_always_inline__ +static inline size_t mca_coll_ucx_get_data_size(coll_ucx_datatype_t *op_data, + size_t count) +{ + return count << op_data->size_shift; +} +#endif + +#endif /* COLL_UCX_DATATYPE_H_ */ diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 
3cc84c474e6..0b1baaad8bd 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -18,6 +18,7 @@ #include "coll_ucx.h" #include "coll_ucx_request.h" +#include "coll_ucx_datatype.h" #include #include @@ -223,7 +224,7 @@ static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) goto end; } - /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ + /* Create a topo matrix. As it is Diagonal symmetryן¼� only half of the matrix will be computed. */ ret = mca_coll_ucx_create_topo_map(node_index, topo_info, LOC_SIZE, rank_cnt); if (ret != OMPI_SUCCESS) { status = OMPI_ERROR; @@ -281,7 +282,7 @@ static int mca_coll_ucx_create_comm_topo(ucg_group_params_t *args, struct ompi_c return OMPI_SUCCESS; } - /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ + /* Create a topo matrix. As it is Diagonal symmetryן¼� only half of the matrix will be computed. */ unsigned i; for (i = 0; i < args->member_count; i++) { /* Find the rank in the MPI_COMM_WORLD for rank i in the comm. 
*/ @@ -323,6 +324,13 @@ static void mca_coll_ucg_create_distance_array(struct ompi_communicator_t *comm, } } +static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, + ucp_datatype_t *ucp_dt) +{ + *ucp_dt = mca_coll_ucx_get_datatype(mpi_dt); + return 0; +} + static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) { args->member_count = ompi_comm_size(comm); @@ -332,6 +340,7 @@ static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_ args->release_address_f = mca_coll_ucx_release_address; args->cb_group_obj = comm; args->op_is_commute_f = ompi_op_is_commute; + args->mpi_dt_convert = mca_coll_ucg_datatype_convert; } static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) @@ -459,6 +468,26 @@ static int mca_coll_ucx_module_enable(mca_coll_base_module_t *module, mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*) module; int rc; + if (mca_coll_ucx_component.datatype_attr_keyval == MPI_KEYVAL_INVALID) { + /* Create a key for adding custom attributes to datatypes */ + ompi_attribute_fn_ptr_union_t copy_fn; + ompi_attribute_fn_ptr_union_t del_fn; + copy_fn.attr_datatype_copy_fn = + (MPI_Type_internal_copy_attr_function*)MPI_TYPE_NULL_COPY_FN; + del_fn.attr_datatype_delete_fn = mca_coll_ucx_datatype_attr_del_fn; + rc = ompi_attr_create_keyval(TYPE_ATTR, copy_fn, del_fn, + &mca_coll_ucx_component.datatype_attr_keyval, + NULL, 0, NULL); + if (rc != OMPI_SUCCESS) { + COLL_UCX_ERROR("Failed to create keyval for UCX datatypes: %d", rc); + return rc; + } + + COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.convs, + mca_coll_ucx_convertor_t, + 128, -1, 128); + } + /* prepare the placeholder for the array of request* */ module->base_data = OBJ_NEW(mca_coll_base_comm_t); if (NULL == module->base_data) { @@ -470,9 +499,6 @@ static int mca_coll_ucx_module_enable(mca_coll_base_module_t *module, return rc; } - 
COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.persistent_ops, mca_coll_ucx_persistent_op_t, - 128, -1, 128); - COLL_UCX_VERBOSE(1, "UCX Collectives Module initialized"); return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index 0dfebe42464..d5f2f5e000d 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -17,11 +17,6 @@ #include "ompi/message/message.h" #include -static inline int mca_coll_ucx_is_datatype_supported(struct ompi_datatype_t *dtype, int count) -{ - return ompi_datatype_is_contiguous_memory_layout(dtype, count); -} - int mca_coll_ucx_start(size_t count, ompi_request_t** requests) { mca_coll_ucx_persistent_op_t *preq = NULL; @@ -96,10 +91,6 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, { mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; - if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { - COLL_UCX_ERROR("UCX component does not support discontinuous datatype. Please use other coll component."); - return OMPI_ERR_NOT_SUPPORTED; - } COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); @@ -418,10 +409,6 @@ int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int { mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; - if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { - COLL_UCX_ERROR("UCX component does not support discontinuous datatype. 
Please use other coll component."); - return OMPI_ERR_NOT_SUPPORTED; - } COLL_UCX_TRACE("%s", buff, buff, count, dtype, comm, "bcast"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); From 4b6ac18298679c141dc0c2854260dd0e1c561910 Mon Sep 17 00:00:00 2001 From: zheng871026 <40054765+zheng871026@users.noreply.github.com> Date: Mon, 16 Nov 2020 15:30:39 +0800 Subject: [PATCH 04/20] fix format error --- ompi/mca/coll/ucx/coll_ucx_module.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 0b1baaad8bd..ca11f7ae585 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -224,7 +224,7 @@ static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) goto end; } - /* Create a topo matrix. As it is Diagonal symmetryן¼� only half of the matrix will be computed. */ + /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ ret = mca_coll_ucx_create_topo_map(node_index, topo_info, LOC_SIZE, rank_cnt); if (ret != OMPI_SUCCESS) { status = OMPI_ERROR; @@ -282,7 +282,7 @@ static int mca_coll_ucx_create_comm_topo(ucg_group_params_t *args, struct ompi_c return OMPI_SUCCESS; } - /* Create a topo matrix. As it is Diagonal symmetryן¼� only half of the matrix will be computed. */ + /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ unsigned i; for (i = 0; i < args->member_count; i++) { /* Find the rank in the MPI_COMM_WORLD for rank i in the comm. 
*/ From 16f926be7f7d43b1847b07946f45b30726aa1b07 Mon Sep 17 00:00:00 2001 From: zheng871026 <40054765+zheng871026@users.noreply.github.com> Date: Fri, 20 Nov 2020 21:57:10 +0800 Subject: [PATCH 05/20] fix for allreduce non-contiguous datatypes --- ompi/mca/coll/ucx/coll_ucx_op.c | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index d5f2f5e000d..a32dde25b4d 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -17,6 +17,11 @@ #include "ompi/message/message.h" #include +static inline int mca_coll_ucx_is_datatype_supported(struct ompi_datatype_t *dtype, int count) +{ + return ompi_datatype_is_contiguous_memory_layout(dtype, count); +} + int mca_coll_ucx_start(size_t count, ompi_request_t** requests) { mca_coll_ucx_persistent_op_t *preq = NULL; @@ -91,6 +96,10 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, { mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; + if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { + COLL_UCX_ERROR("UCX component does not support discontinuous datatype. 
Please use other coll component."); + return OMPI_ERR_NOT_SUPPORTED; + } COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); From 9b6d0564292e8f0ac3627de3c4eb4e8711733248 Mon Sep 17 00:00:00 2001 From: zheng871026 <40054765+zheng871026@users.noreply.github.com> Date: Wed, 2 Dec 2020 19:27:20 +0800 Subject: [PATCH 06/20] fix cleancode --- ompi/mca/coll/ucx/coll_ucx_datatype.c | 36 +++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_datatype.c b/ompi/mca/coll/ucx/coll_ucx_datatype.c index 05eb985cf93..e5793c53423 100644 --- a/ompi/mca/coll/ucx/coll_ucx_datatype.c +++ b/ompi/mca/coll/ucx/coll_ucx_datatype.c @@ -72,22 +72,29 @@ static size_t coll_ucx_generic_datatype_pack(void *state, size_t offset, uint32_t iov_count; struct iovec iov; size_t length; - + int rc; + iov_count = 1; iov.iov_base = dest; iov.iov_len = max_length; opal_convertor_set_position(&convertor->opal_conv, &offset); length = max_length; - opal_convertor_pack(&convertor->opal_conv, &iov, &iov_count, &length); - return length; + rc = opal_convertor_pack(&convertor->opal_conv, &iov, &iov_count, &length); + if (OPAL_UNLIKELY(rc < 0)) { + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("Failed to pack datatype structure"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to pack datatype structure"); + } else { + return length; + } } static ucs_status_t coll_ucx_generic_datatype_unpack(void *state, size_t offset, const void *src, size_t length) { mca_coll_ucx_convertor_t *convertor = state; - + int rc; uint32_t iov_count; struct iovec iov; opal_convertor_t conv; @@ -106,16 +113,22 @@ static ucs_status_t coll_ucx_generic_datatype_unpack(void *state, size_t offset, convertor->opal_conv.pBaseBuf, 0, &conv); opal_convertor_set_position(&conv, &offset); - opal_convertor_unpack(&conv, &iov, &iov_count, &length); + rc = opal_convertor_unpack(&conv, &iov, 
&iov_count, &length); opal_convertor_cleanup(&conv); OBJ_DESTRUCT(&conv); /* permanently switch to un-ordered mode */ convertor->offset = 0; } else { - opal_convertor_unpack(&convertor->opal_conv, &iov, &iov_count, &length); + rc = opal_convertor_unpack(&convertor->opal_conv, &iov, &iov_count, &length); convertor->offset += length; } - return UCS_OK; + if (OPAL_UNLIKELY(rc < 0)) { + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("Failed to unpack datatype structure"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to unpack datatype structure"); + } else { + return UCS_OK; + } } static void coll_ucx_generic_datatype_finish(void *state) @@ -176,9 +189,6 @@ coll_ucx_datatype_t *mca_coll_ucx_init_nbx_datatype(ompi_datatype_t *datatype, if (pml_datatype == NULL) { int err = MPI_ERR_INTERN; COLL_UCX_ERROR("Failed to allocate datatype structure"); - /* TODO: this error should return to the caller and invoke an error - * handler from the MPI API call. - * For now, it is fatal. */ ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); } @@ -214,9 +224,6 @@ ucp_datatype_t mca_coll_ucx_init_datatype(ompi_datatype_t *datatype) if (status != UCS_OK) { int err = MPI_ERR_INTERN; COLL_UCX_ERROR("Failed to create UCX datatype for %s", datatype->name); - /* TODO: this error should return to the caller and invoke an error - * handler from the MPI API call. - * For now, it is fatal. */ ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); } @@ -234,9 +241,6 @@ ucp_datatype_t mca_coll_ucx_init_datatype(ompi_datatype_t *datatype) int err = MPI_ERR_INTERN; COLL_UCX_ERROR("Failed to add UCX datatype attribute for %s (%p): %d", datatype->name, (void*)datatype, ret); - /* TODO: this error should return to the caller and invoke an error - * handler from the MPI API call. - * For now, it is fatal. 
*/ ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to allocate datatype structure"); } } From 961eda695b145dc2e32a4eeb7a0a807d4e3ac423 Mon Sep 17 00:00:00 2001 From: nsos Date: Fri, 4 Dec 2020 09:45:04 +0800 Subject: [PATCH 07/20] FIX: REMOVE TRAILING SPACES remove trailing spaces. --- ompi/mca/coll/ucx/coll_ucx_datatype.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_datatype.c b/ompi/mca/coll/ucx/coll_ucx_datatype.c index e5793c53423..7b632deff30 100644 --- a/ompi/mca/coll/ucx/coll_ucx_datatype.c +++ b/ompi/mca/coll/ucx/coll_ucx_datatype.c @@ -73,7 +73,7 @@ static size_t coll_ucx_generic_datatype_pack(void *state, size_t offset, struct iovec iov; size_t length; int rc; - + iov_count = 1; iov.iov_base = dest; iov.iov_len = max_length; From 67e322faf954dd7d27bc8c8aaaa4d4ec68d879db Mon Sep 17 00:00:00 2001 From: shizhibao Date: Sun, 6 Dec 2020 21:50:04 +0800 Subject: [PATCH 08/20] Allreduce support non-contiguous datatype --- ompi/mca/coll/ucx/coll_ucx_op.c | 48 ++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index a32dde25b4d..71978626275 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -94,40 +94,62 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_communicator_t *comm, mca_coll_base_module_t *module) { - mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t*)module; - - if (ucs_unlikely(!mca_coll_ucx_is_datatype_supported(dtype, count))) { - COLL_UCX_ERROR("UCX component does not support discontinuous datatype. 
Please use other coll component."); - return OMPI_ERR_NOT_SUPPORTED; + mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t *)module; + char *inplace_buff = NULL; + ucg_coll_h coll = NULL; + ptrdiff_t extent, dsize, gap = 0; + int err; + + dsize = opal_datatype_span(&dtype->super, count, &gap); + if (sbuf == MPI_IN_PLACE && dsize != 0) { + inplace_buff = (char *)malloc(dsize); + if (inplace_buff == NULL) { + return OMPI_ERR_OUT_OF_RESOURCE; + } + sbuf = inplace_buff - gap; + err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)sbuf, (char *)rbuf); + } else { + err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf); } + if (err != MPI_SUCCESS) { + if (inplace_buff != NULL) { + free(inplace_buff); + } + return err; + } + COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); - ptrdiff_t dtype_size; - ucg_coll_h coll = NULL; - ompi_datatype_type_extent(dtype, &dtype_size); - ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, + + ompi_datatype_type_extent(dtype, &extent); + ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, op, 0, 0, &coll); if (OPAL_UNLIKELY(ret != UCS_OK)) { COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); - return OMPI_ERROR; + goto exit; } ret = ucg_collective_start_nbr(coll, req); if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { COLL_UCX_ERROR("ucx allreduce start failed: %s", ucs_status_string(ret)); - return OMPI_ERROR; + goto exit; } if (ucs_unlikely(ret == UCS_OK)) { - return OMPI_SUCCESS; + goto exit; } ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx allreduce", (void)0); COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce END"); - return OMPI_SUCCESS; +exit: + if (inplace_buff 
!= NULL) { + free(inplace_buff); + } + + return (ret == UCS_OK) ? OMPI_SUCCESS : OMPI_ERROR; } int mca_coll_ucx_iallreduce(const void *sbuf, void *rbuf, int count, From 9d68b69e72588dc76b5eaa0a3cbb12717cbe67ae Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Tue, 8 Dec 2020 19:10:58 +0800 Subject: [PATCH 09/20] increase the check of data size --- ompi/mca/coll/ucx/coll_ucx_op.c | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index 71978626275..71282fc5c7f 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -22,6 +22,13 @@ static inline int mca_coll_ucx_is_datatype_supported(struct ompi_datatype_t *dty return ompi_datatype_is_contiguous_memory_layout(dtype, count); } +static ucs_status_t mca_coll_ucx_check_total_data_size(size_t dtype_size, int count) +{ + static const uint64_t max_size = 4294967296; + uint64_t total_size = dtype_size * count; + return (total_size <= max_size) ? UCS_OK : UCS_ERR_OUT_OF_RANGE; +} + int mca_coll_ucx_start(size_t count, ompi_request_t** requests) { mca_coll_ucx_persistent_op_t *preq = NULL; @@ -100,6 +107,13 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, ptrdiff_t extent, dsize, gap = 0; int err; + ompi_datatype_type_extent(dtype, &extent); + ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)extent, count); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx component only support data size <= 2^32 bytes. 
please use other component."); + return OMPI_ERROR; + } + dsize = opal_datatype_span(&dtype->super, count, &gap); if (sbuf == MPI_IN_PLACE && dsize != 0) { inplace_buff = (char *)malloc(dsize); @@ -122,8 +136,7 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); - ompi_datatype_type_extent(dtype, &extent); - ucs_status_t ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, + ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, op, 0, 0, &coll); if (OPAL_UNLIKELY(ret != UCS_OK)) { COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); @@ -446,7 +459,12 @@ int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int ptrdiff_t dtype_size; ucg_coll_h coll = NULL; ompi_datatype_type_extent(dtype, &dtype_size); - ucs_status_t ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, + ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)dtype_size, count); + if (OPAL_UNLIKELY(ret != UCS_OK)) { + COLL_UCX_ERROR("ucx component only support data size <= 2^32 bytes. 
please use other component."); + return OMPI_ERROR; + } + ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, 0, root, 0, &coll); if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { COLL_UCX_ERROR("ucx bcast init failed: %s", ucs_status_string(ret)); From a8fca2d616d44fbd5a08927d9de2ed1d38708cab Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Wed, 9 Dec 2020 10:52:50 +0800 Subject: [PATCH 10/20] modify the format and add some notes --- ompi/mca/coll/ucx/coll_ucx_component.c | 9 ++++++--- ompi/mca/coll/ucx/coll_ucx_op.c | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_component.c b/ompi/mca/coll/ucx/coll_ucx_component.c index ee73c8fee5f..dbee7c4ec8f 100644 --- a/ompi/mca/coll/ucx/coll_ucx_component.c +++ b/ompi/mca/coll/ucx/coll_ucx_component.c @@ -357,14 +357,16 @@ int mca_coll_ucx_init(void) status = ucg_worker_create(mca_coll_ucx_component.ucg_context, ¶ms, &mca_coll_ucx_component.ucg_worker); if (UCS_OK != status) { - COLL_UCX_WARN("Failed to create UCG worker"); + COLL_UCX_WARN("Failed to create UCG worker, automatically select other available and highest " + "priority collective component."); rc = OMPI_ERROR; goto err; } status = mca_coll_ucx_init_worker(); if (UCS_OK != status) { - COLL_UCX_WARN("Failed to init UCG worker"); + COLL_UCX_WARN("Failed to init UCG worker, automatically select other available and highest " + "priority collective component."); rc = OMPI_ERROR; goto err_destroy_worker; } @@ -377,7 +379,8 @@ int mca_coll_ucx_init(void) rc = opal_progress_register(mca_coll_ucx_progress); if (OPAL_SUCCESS != rc) { - COLL_UCX_ERROR("Failed to progress register"); + COLL_UCX_ERROR("Failed to progress register, automatically select other available and highest " + "priority collective component."); goto err_destroy_worker; } diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index 71282fc5c7f..a6d2ab4cfb3 
100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -137,7 +137,7 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, - op, 0, 0, &coll); + op, 0, 0, &coll); if (OPAL_UNLIKELY(ret != UCS_OK)) { COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); goto exit; @@ -465,7 +465,7 @@ int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int return OMPI_ERROR; } ret = ucg_coll_bcast_init(buff, buff, count, (size_t)dtype_size, dtype, ucx_module->ucg_group, 0, - 0, root, 0, &coll); + 0, root, 0, &coll); if (OPAL_UNLIKELY(UCS_STATUS_IS_ERR(ret))) { COLL_UCX_ERROR("ucx bcast init failed: %s", ucs_status_string(ret)); return OMPI_ERROR; From 3409554f939fe919a221993ee0c41308bcce782b Mon Sep 17 00:00:00 2001 From: shizhibao Date: Thu, 10 Dec 2020 10:52:20 +0800 Subject: [PATCH 11/20] Support allreduce non-contiguous datatype --- ompi/mca/coll/ucx/coll_ucx_module.c | 11 ++++++ ompi/mca/coll/ucx/coll_ucx_op.c | 56 +++++++++++++++++++---------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index ca11f7ae585..09c996b1974 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -331,6 +331,16 @@ static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, return 0; } +static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) +{ + struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext; + ptrdiff_t dsize, gp= 0; + + dsize = opal_datatype_span(&dtype->super, count, &gp); + *gap = gp; + return dsize; +} + static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) { args->member_count = ompi_comm_size(comm); @@ -341,6 +351,7 @@ static 
void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_ args->cb_group_obj = comm; args->op_is_commute_f = ompi_op_is_commute; args->mpi_dt_convert = mca_coll_ucg_datatype_convert; + args->mpi_datatype_span = coll_ucx_datatype_span; } static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index a6d2ab4cfb3..f14c4481f34 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -96,6 +96,36 @@ int mca_coll_ucx_start(size_t count, ompi_request_t** requests) ((char *)alloca(mca_coll_ucx_component.request_size) + \ mca_coll_ucx_component.request_size); +static int coll_ucx_allreduce_pre_init(struct ompi_datatype_t *dtype, int count, void *sbuf, + void *rbuf, char **inplace_buff, ptrdiff_t *gap) +{ + ptrdiff_t dsize, gp, lb = 0; + char *inpbuf = NULL; + int err; + + ompi_datatype_type_lb(dtype, &lb); + if ((dtype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) && + (dtype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) && + (lb == 0)) { + return UCS_OK; + } + + dsize = opal_datatype_span(&dtype->super, count, &gp); + if (sbuf == MPI_IN_PLACE && dsize != 0) { + inpbuf = (char *)malloc(dsize); + if (inpbuf == NULL) { + return UCS_ERR_NO_MEMORY; + } + *inplace_buff = inpbuf; + *gap = gp; + err = ompi_datatype_copy_content_same_ddt(dtype, count, inpbuf - gp, (char *)rbuf); + } else { + err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf); + } + + return (err == MPI_SUCCESS) ? 
UCS_OK : UCS_ERR_INVALID_PARAM; +} + int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, struct ompi_datatype_t *dtype, struct ompi_op_t *op, struct ompi_communicator_t *comm, @@ -104,8 +134,8 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, mca_coll_ucx_module_t *ucx_module = (mca_coll_ucx_module_t *)module; char *inplace_buff = NULL; ucg_coll_h coll = NULL; - ptrdiff_t extent, dsize, gap = 0; - int err; + ptrdiff_t extent, gap = 0; + char *sbuf_rel = NULL; ompi_datatype_type_extent(dtype, &extent); ucs_status_t ret = mca_coll_ucx_check_total_data_size((size_t)extent, count); @@ -114,29 +144,17 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, return OMPI_ERROR; } - dsize = opal_datatype_span(&dtype->super, count, &gap); - if (sbuf == MPI_IN_PLACE && dsize != 0) { - inplace_buff = (char *)malloc(dsize); - if (inplace_buff == NULL) { - return OMPI_ERR_OUT_OF_RESOURCE; - } - sbuf = inplace_buff - gap; - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)sbuf, (char *)rbuf); - } else { - err = ompi_datatype_copy_content_same_ddt(dtype, count, (char *)rbuf, (char *)sbuf); - } - if (err != MPI_SUCCESS) { - if (inplace_buff != NULL) { - free(inplace_buff); - } - return err; + ret = coll_ucx_allreduce_pre_init(dtype, count, sbuf, rbuf, &inplace_buff, &gap); + if (ret != UCS_OK) { + goto exit; } + sbuf_rel = (inplace_buff == NULL) ? 
sbuf : inplace_buff - gap; COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); ucs_status_ptr_t req = COLL_UCX_REQ_ALLOCA(ucx_module); - ret = ucg_coll_allreduce_init(sbuf, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, + ret = ucg_coll_allreduce_init(sbuf_rel, rbuf, count, (size_t)extent, dtype, ucx_module->ucg_group, 0, op, 0, 0, &coll); if (OPAL_UNLIKELY(ret != UCS_OK)) { COLL_UCX_ERROR("ucx allreduce init failed: %s", ucs_status_string(ret)); From db8641145a1db47c6fda39133ef345d73516598e Mon Sep 17 00:00:00 2001 From: shizhibao Date: Wed, 16 Dec 2020 20:18:36 +0800 Subject: [PATCH 12/20] Fix allreduce empty datatype bug --- ompi/mca/coll/ucx/coll_ucx_op.c | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index f14c4481f34..b1867446b7e 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -144,10 +144,13 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, return OMPI_ERROR; } - ret = coll_ucx_allreduce_pre_init(dtype, count, sbuf, rbuf, &inplace_buff, &gap); - if (ret != UCS_OK) { - goto exit; + if (count > 0 && extent > 0) { + ret = coll_ucx_allreduce_pre_init(dtype, count, sbuf, rbuf, &inplace_buff, &gap); + if (ret != UCS_OK) { + goto exit; + } } + sbuf_rel = (inplace_buff == NULL) ? 
sbuf : inplace_buff - gap; COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce START"); From 3147facba1653fe30899825e8e8ba081f6cc408d Mon Sep 17 00:00:00 2001 From: zheng871026 <190974948@qq.com> Date: Thu, 17 Dec 2020 19:25:09 +0800 Subject: [PATCH 13/20] solve the init topo map fault --- ompi/mca/coll/ucx/coll_ucx_module.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 09c996b1974..5d26337096a 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -206,8 +206,9 @@ static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) ret = ompi_coll_base_allgather_intra_bruck(val, LOC_SIZE, MPI_CHAR, topo_info, LOC_SIZE, MPI_CHAR, MPI_COMM_WORLD, &module->super); if (ret != OMPI_SUCCESS) { - status = OMPI_ERROR; - goto end; + int err = MPI_ERR_INTERN; + COLL_UCX_ERROR("ompi_coll_base_allgather_intra_bruck failed"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to init topo map"); } /* Obtain node index to indicate each 'loc' belongs to which node, From 7bb2bb63219f42a3f917b0ec8d3ae469a48c4deb Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Fri, 18 Dec 2020 21:21:18 +0800 Subject: [PATCH 14/20] Solve the problem of program jam when one socket is balanced and the other is unbalanced --- ompi/mca/coll/ucx/coll_ucx_module.c | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 5d26337096a..aa9fc59b01b 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -381,6 +381,23 @@ static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_pa } } +static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params, mca_coll_ucx_module_t *module) +{ + unsigned pps = ucg_builtin_calculate_ppx(group_params, 
UCG_GROUP_MEMBER_DISTANCE_SOCKET); + unsigned ppn = ucg_builtin_calculate_ppx(group_params, UCG_GROUP_MEMBER_DISTANCE_HOST); + char is_socket_balance = (pps == (ppn - pps) || pps == ppn); + char result = is_socket_balance; + int status = ompi_coll_base_allreduce_intra_basic_linear(&is_socket_balance, &result, 1, MPI_CHAR, MPI_MIN, + MPI_COMM_WORLD, &module->super); + if (status != OMPI_SUCCESS) { + int error = MPI_ERR_INTERN; + COLL_UCX_ERROR("ompi_coll_base_allreduce_intra_basic_linear failed"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &error, "Failed to init is_socket_balance"); + } + group_params->is_socket_balance = result; + return ; +} + static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) { ucs_status_t error; @@ -454,6 +471,7 @@ static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_commun goto out; } + mca_coll_ucg_init_is_socket_balance(&args, module); error = ucg_group_create(mca_coll_ucx_component.ucg_worker, &args, &module->ucg_group); /* Examine comm_new return value */ From 31207659637db5b4f3f408cbb6c394659a7554e9 Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Mon, 21 Dec 2020 14:42:33 +0800 Subject: [PATCH 15/20] Fixed a bug that failed to call allreduce --- ompi/mca/coll/ucx/coll_ucx_module.c | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index aa9fc59b01b..04f1ec08b30 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -381,21 +381,22 @@ static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_pa } } -static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params, mca_coll_ucx_module_t *module) +static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params, mca_coll_ucx_module_t *module, + struct ompi_communicator_t *comm) { unsigned pps = 
ucg_builtin_calculate_ppx(group_params, UCG_GROUP_MEMBER_DISTANCE_SOCKET); unsigned ppn = ucg_builtin_calculate_ppx(group_params, UCG_GROUP_MEMBER_DISTANCE_HOST); char is_socket_balance = (pps == (ppn - pps) || pps == ppn); char result = is_socket_balance; int status = ompi_coll_base_allreduce_intra_basic_linear(&is_socket_balance, &result, 1, MPI_CHAR, MPI_MIN, - MPI_COMM_WORLD, &module->super); + comm, &module->super); if (status != OMPI_SUCCESS) { int error = MPI_ERR_INTERN; COLL_UCX_ERROR("ompi_coll_base_allreduce_intra_basic_linear failed"); ompi_mpi_errors_are_fatal_comm_handler(NULL, &error, "Failed to init is_socket_balance"); } group_params->is_socket_balance = result; - return ; + return; } static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) @@ -471,7 +472,7 @@ static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_commun goto out; } - mca_coll_ucg_init_is_socket_balance(&args, module); + mca_coll_ucg_init_is_socket_balance(&args, module, comm); error = ucg_group_create(mca_coll_ucx_component.ucg_worker, &args, &module->ucg_group); /* Examine comm_new return value */ From 1f81ff8e6f56453a3695c60d425c59fc0f003277 Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Tue, 29 Dec 2020 00:13:18 +0800 Subject: [PATCH 16/20] Fix bug in non-contiguous datatype --- ompi/mca/coll/ucx/coll_ucx_module.c | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 04f1ec08b30..99acd98191f 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -342,6 +342,13 @@ static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) return dsize; } +static ucg_group_member_index_t mca_coll_ucx_get_global_member_idx(void *cb_group_obj, + ucg_group_member_index_t index) +{ + ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; + return 
(ucg_group_member_index_t)mcacoll_ucx_find_rank_in_comm_world(comm, (int)index); +} + static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) { args->member_count = ompi_comm_size(comm); @@ -353,6 +360,7 @@ static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_ args->op_is_commute_f = ompi_op_is_commute; args->mpi_dt_convert = mca_coll_ucg_datatype_convert; args->mpi_datatype_span = coll_ucx_datatype_span; + args->mpi_global_idx_f = mca_coll_ucx_get_global_member_idx; } static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) @@ -527,6 +535,7 @@ static int mca_coll_ucx_module_enable(mca_coll_base_module_t *module, rc = mca_coll_ucg_create(ucx_module, comm); if (rc != OMPI_SUCCESS) { + OBJ_REALEASE(module->base_data); return rc; } From 18e9d1c8e83a609a52404b6a788297aa45550a1a Mon Sep 17 00:00:00 2001 From: "public (843ed2ad0e21)" <993835762@qq.com> Date: Wed, 30 Dec 2020 10:25:50 +0800 Subject: [PATCH 17/20] Fix bug in discontig datatype --- ompi/mca/coll/ucx/coll_ucx_module.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 99acd98191f..c945b6dbdfb 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -346,7 +346,7 @@ static ucg_group_member_index_t mca_coll_ucx_get_global_member_idx(void *cb_grou ucg_group_member_index_t index) { ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; - return (ucg_group_member_index_t)mcacoll_ucx_find_rank_in_comm_world(comm, (int)index); + return (ucg_group_member_index_t)mca_coll_ucx_find_rank_in_comm_world(comm, (int)index); } static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) @@ -535,7 +535,7 @@ static int mca_coll_ucx_module_enable(mca_coll_base_module_t *module, rc = mca_coll_ucg_create(ucx_module, comm); if (rc != 
OMPI_SUCCESS) { - OBJ_REALEASE(module->base_data); + OBJ_RELEASE(module->base_data); return rc; } From 5c6dbeca07f6e351cf1e9feba57fc919efa030e5 Mon Sep 17 00:00:00 2001 From: RainybIue <993835762@qq.com> Date: Sun, 7 Feb 2021 18:09:21 +0800 Subject: [PATCH 18/20] MCA/COLL/UCX: Change UCG API, Improve code structure --- ompi/mca/coll/ucx/coll_ucx.h | 27 +- ompi/mca/coll/ucx/coll_ucx_component.c | 569 ++++++++++++++--------- ompi/mca/coll/ucx/coll_ucx_module.c | 612 ++++++++++++------------- ompi/mca/coll/ucx/coll_ucx_op.c | 16 +- ompi/mca/coll/ucx/coll_ucx_request.c | 6 +- ompi/mca/coll/ucx/coll_ucx_request.h | 2 +- 6 files changed, 658 insertions(+), 574 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx.h b/ompi/mca/coll/ucx/coll_ucx.h index 9e3bf616651..b2f5381d544 100644 --- a/ompi/mca/coll/ucx/coll_ucx.h +++ b/ompi/mca/coll/ucx/coll_ucx.h @@ -60,8 +60,10 @@ typedef struct mca_coll_ucx_component { bool enable_topo_map; /* UCX global objects */ + ucp_context_h ucp_context; + ucp_worker_h ucp_worker; ucg_context_h ucg_context; - ucg_worker_h ucg_worker; + ucg_group_h ucg_group; int output; ucs_list_link_t group_head; char **topo_map; @@ -92,37 +94,16 @@ typedef struct mca_coll_ucx_module { } mca_coll_ucx_module_t; OBJ_CLASS_DECLARATION(mca_coll_ucx_module_t); -/* - * Component-oriented functions for using UCX collectives. - */ -int mca_coll_ucx_open(void); -int mca_coll_ucx_close(void); -int mca_coll_ucx_init(void); -void mca_coll_ucx_cleanup(void); -int mca_coll_ucx_enable(bool enable); -int mca_coll_ucx_progress(void); - /* * TESTING PURPOSES: get the worker from the module. */ -ucg_worker_h mca_coll_ucx_get_component_worker(void); +ucp_worker_h mca_coll_ucx_get_component_worker(void); /* * Start persistent collectives from an array of requests. */ int mca_coll_ucx_start(size_t count, ompi_request_t** requests); -/* - * Obtain the address for a remote node. 
- */ -ucs_status_t mca_coll_ucx_resolve_address(void *cb_group_obj, ucg_group_member_index_t idx, ucg_address_t **addr, - size_t *addr_len); - -/* - * Release an obtained address for a remote node. - */ -void mca_coll_ucx_release_address(ucg_address_t *addr); - /* * The collective operations themselves. */ diff --git a/ompi/mca/coll/ucx/coll_ucx_component.c b/ompi/mca/coll/ucx/coll_ucx_component.c index dbee7c4ec8f..ec4827f540d 100644 --- a/ompi/mca/coll/ucx/coll_ucx_component.c +++ b/ompi/mca/coll/ucx/coll_ucx_component.c @@ -10,6 +10,7 @@ * * $HEADER$ */ + #include "ompi_config.h" #include @@ -18,7 +19,7 @@ #include "opal/mca/common/ucx/common_ucx.h" #include "opal/mca/installdirs/installdirs.h" - +#include "ompi/op/op.h" #include "coll_ucx.h" #include "coll_ucx_request.h" #include "coll_ucx_datatype.h" @@ -34,9 +35,9 @@ const char *mca_coll_ucx_component_version_string = static int ucx_open(void); static int ucx_close(void); static int ucx_register(void); -int mca_coll_ucx_init_query(bool enable_progress_threads, - bool enable_mpi_threads); -mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, int *priority); +static int mca_coll_ucx_init_query(bool enable_progress_threads, bool enable_mpi_threads); +static mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, + int *priority); int mca_coll_ucx_output = -1; mca_coll_ucx_component_t mca_coll_ucx_component = { @@ -73,156 +74,147 @@ mca_coll_ucx_component_t mca_coll_ucx_component = { .topo_map = NULL }; -int mca_coll_ucx_init_query(bool enable_progress_threads, - bool enable_mpi_threads) +static int mca_coll_ucx_send_worker_address(void) { - return OMPI_SUCCESS; -} + ucp_address_t *address = NULL; + ucs_status_t status; + size_t addrlen; + int rc; -mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, int *priority) -{ - /* basic checks */ - if ((OMPI_COMM_IS_INTER(comm)) || (ompi_comm_size(comm) < 2)) { - return NULL; + status = 
ucp_worker_get_address(mca_coll_ucx_component.ucp_worker, &address, &addrlen); + if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to get worker address, %s", ucs_status_string(status)); + return OMPI_ERROR; } - /* create a new module for this communicator */ - COLL_UCX_VERBOSE(10, "Creating ucx_context for comm %p, comm_id %d, comm_size %d", - (void*)comm, comm->c_contextid, ompi_comm_size(comm)); - mca_coll_ucx_module_t *ucx_module = OBJ_NEW(mca_coll_ucx_module_t); - if (!ucx_module) { - return NULL; + OPAL_MODEX_SEND(rc, OPAL_PMIX_GLOBAL, &mca_coll_ucx_component.super.collm_version, + (void*)address, addrlen); + if (OPAL_SUCCESS != rc) { + COLL_UCX_ERROR("Open MPI couldn't distribute EP connection details"); + return OMPI_ERROR; } - *priority = mca_coll_ucx_component.priority; - return &(ucx_module->super); + ucp_worker_release_address(mca_coll_ucx_component.ucp_worker, address); + + return OMPI_SUCCESS; } -static int ucx_register(void) +static int mca_coll_ucx_recv_worker_address(ompi_proc_t *proc, + ucp_address_t **address_p, + size_t *addrlen_p) { - int status; - mca_coll_ucx_component.verbose = 0; - status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "verbosity", - "Verbosity of the UCX component", - MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, - OPAL_INFO_LVL_3, - MCA_BASE_VAR_SCOPE_LOCAL, - &mca_coll_ucx_component.verbose); - if (status < OPAL_SUCCESS) { - return OMPI_ERROR; + int ret; + + *address_p = NULL; + OPAL_MODEX_RECV(ret, &mca_coll_ucx_component.super.collm_version, + &proc->super.proc_name, (void**)address_p, addrlen_p); + if (ret != OPAL_SUCCESS) { + COLL_UCX_ERROR("Failed to receive UCX worker address: %s (%d)", opal_strerror(ret), ret); } + return ret; +} - mca_coll_ucx_component.priority = 91; - status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "priority", - "Priority of the UCX component", - MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, - OPAL_INFO_LVL_3, - MCA_BASE_VAR_SCOPE_LOCAL, - 
&mca_coll_ucx_component.priority); - if (status < OPAL_SUCCESS) { - return OMPI_ERROR; +static ucs_status_t mca_coll_ucx_resolve_address(void *cb_group_obj, + ucg_group_member_index_t rank, + ucp_address_t **addr, + size_t *addr_len) +{ + /* Sanity checks */ + ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; + if (rank == (ucg_group_member_index_t)comm->c_my_rank) { + return UCS_ERR_UNSUPPORTED; } - mca_coll_ucx_component.num_disconnect = 1; - status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "num_disconnect", - "How may disconnects go in parallel", - MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, - OPAL_INFO_LVL_3, - MCA_BASE_VAR_SCOPE_LOCAL, - &mca_coll_ucx_component.num_disconnect); - if (status < OPAL_SUCCESS) { - return OMPI_ERROR; + /* Check the cache for a previously established connection to that rank */ + ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, rank); + *addr = proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL]; + *addr_len = 1; + if (*addr != NULL) { + return UCS_OK; } - mca_coll_ucx_component.enable_topo_map = 1; - status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "enable_topo_map", - "Enable global topology map for ucg", - MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, - OPAL_INFO_LVL_3, - MCA_BASE_VAR_SCOPE_LOCAL, - &mca_coll_ucx_component.enable_topo_map); - if (status < OPAL_SUCCESS) { - return OMPI_ERROR; + /* Obtain the UCP address of the remote */ + int ret = mca_coll_ucx_recv_worker_address(proc_peer, addr, addr_len); + if (ret < 0) { + COLL_UCX_ERROR("mca_coll_ucx_recv_worker_address(proc=%d rank=%lu) failed", + proc_peer->super.proc_name.vpid, rank); + return UCS_ERR_INVALID_ADDR; } - opal_common_ucx_mca_var_register(&mca_coll_ucx_component.super.collm_version); - return OMPI_SUCCESS; + /* Cache the connection for future invocations with this rank */ + proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = *addr; + return UCS_OK; } -static int ucx_open(void) 
+static void mca_coll_ucx_release_address(ucp_address_t *addr) { - mca_coll_ucx_component.output = opal_output_open(NULL); - opal_output_set_verbosity(mca_coll_ucx_component.output, mca_coll_ucx_component.verbose); - - opal_common_ucx_mca_register(); - - return mca_coll_ucx_open(); + /* no need to free - the address is stored in proc_peer->proc_endpoints */ + return; } -static int ucx_close(void) +static void mca_coll_ucx_mpi_reduce(void *mpi_op, + void *src, + void *dst, + int count, + void *mpi_dtype) { - if (mca_coll_ucx_component.ucg_worker == NULL) { - return OMPI_ERROR; - } - mca_coll_ucx_cleanup(); - - opal_common_ucx_mca_deregister(); - - return mca_coll_ucx_close(); + return ompi_op_reduce((ompi_op_t*)mpi_op, src, dst, count, (ompi_datatype_t*)mpi_dtype); } -static int mca_coll_ucx_send_worker_address(void) +static int mca_coll_ucx_mpi_op_is_commute(void *mpi_op) { - ucg_address_t *address = NULL; - ucs_status_t status; - size_t addrlen; - int rc; + return ompi_op_is_commute((ompi_op_t*)mpi_op); +} - status = ucg_worker_get_address(mca_coll_ucx_component.ucg_worker, &address, &addrlen); - if (UCS_OK != status) { - COLL_UCX_ERROR("Failed to get worker address"); - return OMPI_ERROR; +static ucg_group_member_index_t mca_coll_ucx_get_global_member_idx(void *cb_group_obj, + ucg_group_member_index_t index) +{ + ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; + struct ompi_proc_t *proc = ompi_comm_peer_lookup(comm, (int)index); + if (proc == NULL) { + return (ucg_group_member_index_t)-1; } - OPAL_MODEX_SEND(rc, OPAL_PMIX_GLOBAL, &mca_coll_ucx_component.super.collm_version, - (void*)address, addrlen); - if (OPAL_SUCCESS != rc) { - COLL_UCX_ERROR("Open MPI couldn't distribute EP connection details"); - return OMPI_ERROR; + unsigned i; + unsigned member_count = ompi_comm_size(MPI_COMM_WORLD); + for (i = 0; i < member_count; ++i) { + struct ompi_proc_t *global_proc = ompi_comm_peer_lookup(MPI_COMM_WORLD, i); + if (global_proc == proc) { + return i; + 
} } - ucg_worker_release_address(mca_coll_ucx_component.ucg_worker, address); + return (ucg_group_member_index_t)-1; +} - return OMPI_SUCCESS; +static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, + ucp_datatype_t *ucp_dt) +{ + *ucp_dt = mca_coll_ucx_get_datatype(mpi_dt); + return 0; } -static int mca_coll_ucx_recv_worker_address(ompi_proc_t *proc, - ucg_address_t **address_p, - size_t *addrlen_p) +static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) { - int ret; + struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext; + ptrdiff_t dsize, gp= 0; - *address_p = NULL; - OPAL_MODEX_RECV(ret, &mca_coll_ucx_component.super.collm_version, - &proc->super.proc_name, (void**)address_p, addrlen_p); - if (ret != OPAL_SUCCESS) { - COLL_UCX_ERROR("Failed to receive UCX worker address: %s (%d)", opal_strerror(ret), ret); - } - return ret; + dsize = opal_datatype_span(&dtype->super, count, &gp); + *gap = gp; + return dsize; } -int mca_coll_ucx_open(void) +static int mca_coll_ucx_init_ucp() { - ucg_context_attr_t attr; - ucg_params_t params; - ucg_config_t *config = NULL; + ucp_context_attr_t attr; + ucp_params_t params; + ucp_config_t *config = NULL; ucs_status_t status; - COLL_UCX_VERBOSE(1, "mca_coll_ucx_open"); - /* Read options */ - status = ucg_config_read("MPI", NULL, &config); + status = ucp_config_read("MPI", NULL, &config); if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to read ucp config, status %d", status); return OMPI_ERROR; } @@ -231,14 +223,13 @@ int mca_coll_ucx_open(void) UCP_PARAM_FIELD_REQUEST_SIZE | UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_CLEANUP | - // UCP_PARAM_FIELD_TAG_SENDER_MASK | UCP_PARAM_FIELD_MT_WORKERS_SHARED | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS; params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64 | - UCP_FEATURE_GROUPS; + UCP_FEATURE_AM; params.request_size = sizeof(ompi_request_t); params.request_init = mca_coll_ucx_request_init; 
params.request_cleanup = mca_coll_ucx_request_cleanup; @@ -246,167 +237,222 @@ int mca_coll_ucx_open(void) since it will be protected by worker */ params.estimated_num_eps = ompi_proc_world_size(); - status = ucg_init(¶ms, config, &mca_coll_ucx_component.ucg_context); - ucg_config_release(config); - config = NULL; + status = ucp_init(¶ms, config, &mca_coll_ucx_component.ucp_context); + ucp_config_release(config); if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to init ucp context, %s", ucs_status_string(status)); return OMPI_ERROR; } - /* Query UCX attributes */ attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE; - status = ucg_context_query(mca_coll_ucx_component.ucg_context, &attr); + status = ucp_context_query(mca_coll_ucx_component.ucp_context, &attr); if (UCS_OK != status) { - goto out; + COLL_UCX_ERROR("Failed to query ucp context, %s", ucs_status_string(status)); + goto err_cleanup_ucp_context; } - mca_coll_ucx_component.request_size = attr.request_size; - - /* Initialize UCX worker */ - if (OMPI_SUCCESS != mca_coll_ucx_init()) { - goto out; - } - int i; mca_coll_ucx_component.datatype_attr_keyval = MPI_KEYVAL_INVALID; for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID; } - - ucs_list_head_init(&mca_coll_ucx_component.group_head); return OMPI_SUCCESS; -out: - ucg_cleanup(mca_coll_ucx_component.ucg_context); - mca_coll_ucx_component.ucg_context = NULL; +err_cleanup_ucp_context: + ucp_cleanup(mca_coll_ucx_component.ucp_context); + mca_coll_ucx_component.ucp_context = NULL; return OMPI_ERROR; } -int mca_coll_ucx_close(void) +static int mca_coll_ucx_init_ucg() { - COLL_UCX_VERBOSE(1, "mca_coll_ucx_close"); - - int i; - for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { - if (mca_coll_ucx_component.predefined_types[i] != COLL_UCX_DATATYPE_INVALID) { - ucp_dt_destroy(mca_coll_ucx_component.predefined_types[i]); - mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID; - } - } + 
ucs_status_t status; + ucg_params_t params; - if (mca_coll_ucx_component.ucg_worker != NULL) { - mca_coll_ucx_cleanup(); - mca_coll_ucx_component.ucg_worker = NULL; + params.field_mask = UCG_PARAM_FIELD_ADDRESS_CB | + UCG_PARAM_FIELD_REDUCE_CB | + UCG_PARAM_FIELD_COMMUTE_CB | + UCG_PARAM_FIELD_GLOBALIDX_CB | + UCG_PARAM_FIELD_DTCONVERT_CB | + UCG_PARAM_FIELD_DATATYPESPAN_CB | + UCG_PARAM_FIELD_MPI_IN_PLACE; + params.address.lookup_f = mca_coll_ucx_resolve_address; + params.address.release_f = mca_coll_ucx_release_address; + params.mpi_reduce_f = mca_coll_ucx_mpi_reduce; + params.op_is_commute_f = mca_coll_ucx_mpi_op_is_commute; + params.mpi_global_idx_f = mca_coll_ucx_get_global_member_idx; + params.mpi_dt_convert = mca_coll_ucg_datatype_convert; + params.mpi_datatype_span = coll_ucx_datatype_span; + params.mpi_in_place = (void*)MPI_IN_PLACE; + + status = ucg_init(¶ms, NULL, &mca_coll_ucx_component.ucg_context); + if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to init ucg config, %s", ucs_status_string(status)); + return OMPI_ERROR; } - if (mca_coll_ucx_component.ucg_context != NULL) { - ucg_cleanup(mca_coll_ucx_component.ucg_context); - mca_coll_ucx_component.ucg_context = NULL; - } return OMPI_SUCCESS; } -int mca_coll_ucx_progress(void) +static int mca_coll_ucx_create_worker() { - mca_coll_ucx_module_t *module = NULL; - ucs_list_for_each(module, &mca_coll_ucx_component.group_head, ucs_list) { - ucg_group_progress(module->ucg_group); + ucp_worker_params_t params; + ucs_status_t status; + + params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + if (ompi_mpi_thread_multiple) { + params.thread_mode = UCS_THREAD_MODE_MULTI; + } else { + params.thread_mode = UCS_THREAD_MODE_SINGLE; + } + status = ucp_worker_create(mca_coll_ucx_component.ucp_context, + ¶ms, + &mca_coll_ucx_component.ucp_worker); + if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to create ucp worker, %s", ucs_status_string(status)); + return OMPI_ERROR; } + return OMPI_SUCCESS; } -int 
mca_coll_ucx_init_worker(void) +static int mca_coll_ucx_check_worker() { - int rc; - ucg_worker_attr_t attr; + ucs_status_t status; + ucp_worker_attr_t attr; attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; - rc = ucg_worker_query(mca_coll_ucx_component.ucg_worker, &attr); - if (UCS_OK != rc) { - COLL_UCX_ERROR("Failed to query UCP worker thread level"); - rc = OMPI_ERROR; - return rc; - } + status = ucp_worker_query(mca_coll_ucx_component.ucp_worker, &attr); + if (UCS_OK != status) { + COLL_UCX_ERROR("Failed to query UCP worker thread mode, %s", ucs_status_string(status)); + return OMPI_ERROR; + } /* UCX does not support multithreading, disqualify current PML for now */ if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) { COLL_UCX_ERROR("UCP worker does not support MPI_THREAD_MULTIPLE"); - rc = OMPI_ERR_NOT_SUPPORTED; - return rc; + return OMPI_ERR_NOT_SUPPORTED; } - /* Share my UCP address, so it could be later obtained via @ref mca_coll_ucx_resolve_address */ - rc = mca_coll_ucx_send_worker_address(); - return rc; + return OMPI_SUCCESS; } -int mca_coll_ucx_init(void) +static int mca_coll_ucx_progress(void) { - ucg_worker_params_t params; - ucs_status_t status; - int rc; - - COLL_UCX_VERBOSE(1, "mca_coll_ucx_init"); - params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - params.thread_mode = UCS_THREAD_MODE_SINGLE; - if (ompi_mpi_thread_multiple) { - params.thread_mode = UCS_THREAD_MODE_MULTI; - } else { - params.thread_mode = UCS_THREAD_MODE_SINGLE; + mca_coll_ucx_module_t *module = NULL; + ucs_list_for_each(module, &mca_coll_ucx_component.group_head, ucs_list) { + ucg_group_progress(module->ucg_group); } + return OMPI_SUCCESS; +} - status = ucg_worker_create(mca_coll_ucx_component.ucg_context, ¶ms, &mca_coll_ucx_component.ucg_worker); - if (UCS_OK != status) { - COLL_UCX_WARN("Failed to create UCG worker, automatically select other available and highest " - "priority collective component."); - rc = OMPI_ERROR; - goto err; 
+static int mca_coll_ucx_init_worker(void) +{ + COLL_UCX_VERBOSE(1, "mca_coll_ucx_init_worker"); + if (OMPI_SUCCESS != mca_coll_ucx_create_worker()) { + return OMPI_ERROR; } - status = mca_coll_ucx_init_worker(); - if (UCS_OK != status) { - COLL_UCX_WARN("Failed to init UCG worker, automatically select other available and highest " - "priority collective component."); - rc = OMPI_ERROR; + if (OMPI_SUCCESS != mca_coll_ucx_check_worker()) { goto err_destroy_worker; } + /* Share my UCP address, so it could be later obtained via @ref mca_coll_ucx_resolve_address */ + if (OMPI_SUCCESS != mca_coll_ucx_send_worker_address()) { + goto err_destroy_worker; + } + /* Initialize the free lists */ OBJ_CONSTRUCT(&mca_coll_ucx_component.convs, mca_coll_ucx_freelist_t); COLL_UCX_FREELIST_INIT(&mca_coll_ucx_component.convs, mca_coll_ucx_convertor_t, 128, -1, 128); - rc = opal_progress_register(mca_coll_ucx_progress); - if (OPAL_SUCCESS != rc) { + if (OPAL_SUCCESS != opal_progress_register(mca_coll_ucx_progress)) { COLL_UCX_ERROR("Failed to progress register, automatically select other available and highest " "priority collective component."); goto err_destroy_worker; } - COLL_UCX_VERBOSE(2, "created ucp context %p, worker %p", (void *)mca_coll_ucx_component.ucg_context, - (void *)mca_coll_ucx_component.ucg_worker); - return rc; + return OMPI_SUCCESS; err_destroy_worker: - ucg_worker_destroy(mca_coll_ucx_component.ucg_worker); - mca_coll_ucx_component.ucg_worker = NULL; -err: - return rc; + ucp_worker_destroy(mca_coll_ucx_component.ucp_worker); + mca_coll_ucx_component.ucp_worker = NULL; + return OMPI_ERROR; } -void mca_coll_ucx_cleanup(void) +void mca_coll_ucx_cleanup_worker(void) { - COLL_UCX_VERBOSE(1, "mca_coll_ucx_cleanup"); + COLL_UCX_VERBOSE(1, "mca_coll_ucx_cleanup_worker"); opal_progress_unregister(mca_coll_ucx_progress); OBJ_DESTRUCT(&mca_coll_ucx_component.convs); - if (mca_coll_ucx_component.ucg_worker) { - ucg_worker_destroy(mca_coll_ucx_component.ucg_worker); - 
mca_coll_ucx_component.ucg_worker = NULL; + if (mca_coll_ucx_component.ucp_worker != NULL) { + ucp_worker_destroy(mca_coll_ucx_component.ucp_worker); + mca_coll_ucx_component.ucp_worker = NULL; + } + + return; +} + +static int mca_coll_ucx_open(void) +{ + COLL_UCX_VERBOSE(1, "mca_coll_ucx_open"); + + /* Initialize UCP context */ + if (OMPI_SUCCESS != mca_coll_ucx_init_ucp()) { + return OMPI_ERROR; + } + + /* Initialize UCG context */ + if (OMPI_SUCCESS != mca_coll_ucx_init_ucg()) { + goto err_ucp_cleanup; } + + /* Initialize UCP worker */ + if (OMPI_SUCCESS != mca_coll_ucx_init_worker()) { + goto err_ucg_cleaup; + } + + ucs_list_head_init(&mca_coll_ucx_component.group_head); + return OMPI_SUCCESS; + +err_ucg_cleaup: + ucg_cleanup(mca_coll_ucx_component.ucg_context); + mca_coll_ucx_component.ucg_context = NULL; +err_ucp_cleanup: + ucp_cleanup(mca_coll_ucx_component.ucp_context); + mca_coll_ucx_component.ucp_context = NULL; + + return OMPI_ERROR; +} + +static int mca_coll_ucx_close(void) +{ + COLL_UCX_VERBOSE(1, "mca_coll_ucx_close"); + + int i; + for (i = 0; i < OMPI_DATATYPE_MAX_PREDEFINED; ++i) { + if (mca_coll_ucx_component.predefined_types[i] != COLL_UCX_DATATYPE_INVALID) { + ucp_dt_destroy(mca_coll_ucx_component.predefined_types[i]); + mca_coll_ucx_component.predefined_types[i] = COLL_UCX_DATATYPE_INVALID; + } + } + + mca_coll_ucx_cleanup_worker(); + if (mca_coll_ucx_component.ucg_context != NULL) { + ucg_cleanup(mca_coll_ucx_component.ucg_context); + mca_coll_ucx_component.ucg_context = NULL; + } + + if (mca_coll_ucx_component.ucp_context != NULL) { + ucp_cleanup(mca_coll_ucx_component.ucp_context); + mca_coll_ucx_component.ucp_context = NULL; + } + if (mca_coll_ucx_component.topo_map) { for (unsigned i = 0; i < mca_coll_ucx_component.world_member_count; i++) { free(mca_coll_ucx_component.topo_map[i]); @@ -415,45 +461,110 @@ void mca_coll_ucx_cleanup(void) free(mca_coll_ucx_component.topo_map); mca_coll_ucx_component.topo_map = NULL; } + + return OMPI_SUCCESS; 
} -ucs_status_t mca_coll_ucx_resolve_address(void *cb_group_obj, ucg_group_member_index_t rank, ucg_address_t **addr, - size_t *addr_len) +static int ucx_open(void) { - /* Sanity checks */ - ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; - if (rank == (ucg_group_member_index_t)comm->c_my_rank) { - return UCS_ERR_UNSUPPORTED; + mca_coll_ucx_component.output = opal_output_open(NULL); + opal_output_set_verbosity(mca_coll_ucx_component.output, mca_coll_ucx_component.verbose); + + if (OMPI_SUCCESS != mca_coll_ucx_open()) { + return OMPI_ERROR; } + opal_common_ucx_mca_register(); - /* Check the cache for a previously established connection to that rank */ - ompi_proc_t *proc_peer = - (struct ompi_proc_t*)ompi_comm_peer_lookup((ompi_communicator_t*)cb_group_obj, rank); - *addr = proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL]; - *addr_len = 1; - if (*addr) { - return UCS_OK; + return OMPI_SUCCESS; +} + +static int ucx_close(void) +{ + if (mca_coll_ucx_component.ucp_worker == NULL) { + return OMPI_ERROR; } + opal_common_ucx_mca_deregister(); - /* Obtain the UCP address of the remote */ - int ret = mca_coll_ucx_recv_worker_address(proc_peer, addr, addr_len); - if (ret < 0) { - COLL_UCX_ERROR("mca_coll_ucx_recv_worker_address(proc=%d rank=%lu) failed", - proc_peer->super.proc_name.vpid, rank); - return UCS_ERR_INVALID_ADDR; + return mca_coll_ucx_close(); +} + +static int ucx_register(void) +{ + int status; + mca_coll_ucx_component.verbose = 0; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "verbosity", + "Verbosity of the UCX component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.verbose); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; } - /* Cache the connection for future invocations with this rank */ - proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = *addr; - return UCS_OK; + mca_coll_ucx_component.priority = 91; + status = 
mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "priority", + "Priority of the UCX component", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.priority); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + mca_coll_ucx_component.num_disconnect = 1; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "num_disconnect", + "How may disconnects go in parallel", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.num_disconnect); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + mca_coll_ucx_component.enable_topo_map = 1; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "enable_topo_map", + "Enable global topology map for ucg", + MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, + OPAL_INFO_LVL_3, + MCA_BASE_VAR_SCOPE_LOCAL, + &mca_coll_ucx_component.enable_topo_map); + if (status < OPAL_SUCCESS) { + return OMPI_ERROR; + } + + opal_common_ucx_mca_var_register(&mca_coll_ucx_component.super.collm_version); + return OMPI_SUCCESS; } -void mca_coll_ucx_release_address(ucg_address_t *addr) +static int mca_coll_ucx_init_query(bool enable_progress_threads, + bool enable_mpi_threads) { - /* no need to free - the address is stored in proc_peer->proc_endpoints */ + return OMPI_SUCCESS; +} + +static mca_coll_base_module_t *mca_coll_ucx_comm_query(struct ompi_communicator_t *comm, int *priority) +{ + /* basic checks */ + if ((OMPI_COMM_IS_INTER(comm)) || (ompi_comm_size(comm) < 2)) { + return NULL; + } + + /* create a new module for this communicator */ + COLL_UCX_VERBOSE(10, "Creating ucx_context for comm %p, comm_id %d, comm_size %d", + (void*)comm, comm->c_contextid, ompi_comm_size(comm)); + mca_coll_ucx_module_t *ucx_module = OBJ_NEW(mca_coll_ucx_module_t); + if (!ucx_module) { + return NULL; + } + + *priority = mca_coll_ucx_component.priority; + return 
&(ucx_module->super); } -ucg_worker_h mca_coll_ucx_get_component_worker() +ucp_worker_h mca_coll_ucx_get_component_worker() { - return mca_coll_ucx_component.ucg_worker; + return mca_coll_ucx_component.ucp_worker; } diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index c945b6dbdfb..662338a7f2a 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -28,7 +28,8 @@ #include #include -static int mca_coll_ucg_obtain_addr_from_hostname(const char *hostname, struct in_addr *ip_addr) +static int mca_coll_ucx_obtain_addr_from_hostname(const char *hostname, + struct in_addr *ip_addr) { struct addrinfo hints; struct addrinfo *res = NULL, *cur = NULL; @@ -54,89 +55,152 @@ static int mca_coll_ucg_obtain_addr_from_hostname(const char *hostname, struct i return OMPI_SUCCESS; } -static int mca_coll_ucg_obtain_node_index(unsigned member_count, struct ompi_communicator_t *comm, uint16_t *node_index) +static uint16_t* mca_coll_ucx_obtain_node_index(struct ompi_communicator_t *comm) { - ucg_group_member_index_t rank_idx, rank2_idx; - uint16_t same_node_flag; - uint16_t node_idx = 0; - uint16_t init_node_idx = (uint16_t) - 1; - int status, status2; - struct in_addr ip_address, ip_address2; + int status; + unsigned member_count = ompi_comm_size(comm); + uint16_t* node_idx = malloc(sizeof(uint16_t) * member_count); + if (node_idx == NULL) { + return NULL; + } + uint16_t invalid_node_idx = (uint16_t)-1; + for(unsigned i = 0; i < member_count; ++i) { + node_idx[i] = invalid_node_idx; + } - /* initialize: -1: unnumbering flag */ - for (rank_idx = 0; rank_idx < member_count; rank_idx++) { - node_index[rank_idx] = init_node_idx; + /*get ip address */ + struct in_addr *ip_address = malloc(sizeof(struct in_addr) * member_count); + if (ip_address == NULL) { + goto err_free_node_idx; + } + for(unsigned i = 0; i < member_count; ++i) { + ompi_proc_t *rank = ompi_comm_peer_lookup(comm, i); + status = 
mca_coll_ucx_obtain_addr_from_hostname(rank->super.proc_hostname, + ip_address + i); + if (status != OMPI_SUCCESS) { + goto err_free_ip_addr; + } } - - for (rank_idx = 0; rank_idx < member_count; rank_idx++) { - if (node_index[rank_idx] == init_node_idx) { - struct ompi_proc_t *rank_iter = - (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank_idx); - /* super.proc_hostname give IP address or real hostname */ - /* transform hostname to IP address for uniform format */ - status = mca_coll_ucg_obtain_addr_from_hostname(rank_iter->super.proc_hostname, &ip_address); - for (rank2_idx = rank_idx; rank2_idx < member_count; rank2_idx++) { - struct ompi_proc_t *rank2_iter = - (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank2_idx); - - status2 = mca_coll_ucg_obtain_addr_from_hostname(rank2_iter->super.proc_hostname, &ip_address2); - if (status != OMPI_SUCCESS || status2 != OMPI_SUCCESS) { - return OMPI_ERROR; - } - /* if rank_idx and rank2_idx in same node, same_flag = 1 */ - same_node_flag = (memcmp(&ip_address, &ip_address2, sizeof(ip_address))) ? 
0 : 1; - if (same_node_flag == 1 && node_index[rank2_idx] == init_node_idx) { - node_index[rank2_idx] = node_idx; + /* assign node index, starts from 0 */ + uint16_t last_node_idx = 0; + for (unsigned i = 0; i < member_count; ++i) { + if (node_idx[i] == invalid_node_idx) { + node_idx[i] = last_node_idx; + /* find the node with same ipaddr, assign the same node idx */ + for (unsigned j = i+1; j < member_count; ++j) { + if (0 == memcmp(&ip_address[i], &ip_address[j], sizeof(struct in_addr))) { + node_idx[j] = last_node_idx; } } - node_idx++; + ++last_node_idx; } } - - /* make sure every rank has its node_index */ - for (rank_idx = 0; rank_idx < member_count; rank_idx++) { - /* some rank do NOT have node_index */ - if (node_index[rank_idx] == init_node_idx) { - return OMPI_ERROR; + free(ip_address); + return node_idx; + +err_free_ip_addr: + free(ip_address); +err_free_node_idx: + free(node_idx); + return NULL; +} + +static enum ucg_group_member_distance* mca_coll_ucx_obtain_distance(struct ompi_communicator_t *comm) +{ + int my_idx = ompi_comm_rank(comm); + int member_cnt = ompi_comm_size(comm); + enum ucg_group_member_distance *distance = malloc(member_cnt * sizeof(enum ucg_group_member_distance)); + if (distance == NULL) { + return NULL; + } + + struct ompi_proc_t *rank_iter; + for (int rank_idx = 0; rank_idx < member_cnt; ++rank_idx) { + rank_iter = ompi_comm_peer_lookup(comm, rank_idx); + rank_iter->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = NULL; + if (rank_idx == my_idx) { + distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SELF; + } else if (OPAL_PROC_ON_LOCAL_L3CACHE(rank_iter->super.proc_flags)) { + distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; + } else if (OPAL_PROC_ON_LOCAL_SOCKET(rank_iter->super.proc_flags)) { + distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SOCKET; + } else if (OPAL_PROC_ON_LOCAL_HOST(rank_iter->super.proc_flags)) { + distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_HOST; + } else { + distance[rank_idx] = 
UCG_GROUP_MEMBER_DISTANCE_NET; } } - return OMPI_SUCCESS; + return distance; } -static int mca_coll_ucx_create_topo_map(const uint16_t *node_index, const char *topo_info, unsigned loc_size, unsigned rank_cnt) +static void mca_coll_ucx_deallocate_topo_map(char **topo_map, unsigned member_count) { - mca_coll_ucx_component.topo_map = (char**)malloc(sizeof(char*) * rank_cnt); - if (mca_coll_ucx_component.topo_map == NULL) { - return OMPI_ERROR; + if (topo_map == NULL) { + return; + } + for (unsigned i = 0; i < member_count; ++i) { + if (topo_map[i] == NULL) { + /* The following are NULL too, so break */ + break; + } + free(topo_map[i]); + topo_map[i] = NULL; } + free(topo_map); + topo_map = NULL; + return; +} - unsigned i, j; - for (i = 0; i < rank_cnt; i++) { - mca_coll_ucx_component.topo_map[i] = (char*)malloc(sizeof(char) * rank_cnt); - if (mca_coll_ucx_component.topo_map[i] == NULL) { - for (j = 0; j < i; j++) { - free(mca_coll_ucx_component.topo_map[j]); - mca_coll_ucx_component.topo_map[j] = NULL; - } - free(mca_coll_ucx_component.topo_map); - mca_coll_ucx_component.topo_map = NULL; - return OMPI_ERROR; +static char** mca_coll_ucx_allocate_topo_map(unsigned member_count) +{ + char **topo_map = malloc(sizeof(char*) * member_count); + if (topo_map == NULL) { + return NULL; + } + memset(topo_map, 0, sizeof(char*) * member_count); + + for (unsigned i = 0; i < member_count; ++i) { + topo_map[i] = malloc(sizeof(char) * member_count); + if (topo_map[i] == NULL) { + goto err; } + } + + return topo_map; +err: + mca_coll_ucx_deallocate_topo_map(topo_map, member_count); + return NULL; +} + +static char** mca_coll_ucx_create_topo_map(const uint16_t *node_index, + char *localities, + unsigned loc_size, + unsigned member_count) +{ + char **topo_map = mca_coll_ucx_allocate_topo_map(member_count); + if (topo_map == NULL) { + return NULL; + } + + unsigned i, j; + enum ucg_group_member_distance distance; + opal_hwloc_locality_t rel_loc; + for (i = 0; i < member_count; ++i) { for 
(j = 0; j <= i; j++) { if (i == j) { - mca_coll_ucx_component.topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_SELF; + topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_SELF; continue; } if (node_index[i] != node_index[j]) { - mca_coll_ucx_component.topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; - mca_coll_ucx_component.topo_map[j][i] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; + topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; + topo_map[j][i] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; continue; } - opal_hwloc_locality_t rel_loc = opal_hwloc_compute_relative_locality(topo_info + i * loc_size, topo_info + j * loc_size); - enum ucg_group_member_distance distance; + rel_loc = opal_hwloc_compute_relative_locality(localities + i * loc_size, + localities + j * loc_size); if (OPAL_PROC_ON_LOCAL_L3CACHE(rel_loc)) { distance = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; } else if (OPAL_PROC_ON_LOCAL_SOCKET(rel_loc)) { @@ -146,11 +210,11 @@ static int mca_coll_ucx_create_topo_map(const uint16_t *node_index, const char * } else { distance = UCG_GROUP_MEMBER_DISTANCE_NET; } - mca_coll_ucx_component.topo_map[i][j] = (char)distance; - mca_coll_ucx_component.topo_map[j][i] = (char)distance; + topo_map[i][j] = (char)distance; + topo_map[j][i] = (char)distance; } } - return OMPI_SUCCESS; + return topo_map; } static int mca_coll_ucx_print_topo_map(unsigned rank_cnt, char **topo_map) @@ -178,215 +242,146 @@ static int mca_coll_ucx_print_topo_map(unsigned rank_cnt, char **topo_map) return status; } -static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) +static int mca_coll_ucx_convert_to_global_rank(struct ompi_communicator_t *comm, int rank) +{ + struct ompi_proc_t *proc = ompi_comm_peer_lookup(comm, rank); + if (proc == NULL) { + return -1; + } + + unsigned i; + unsigned member_count = ompi_comm_size(MPI_COMM_WORLD); + for (i = 0; i < member_count; ++i) { + struct ompi_proc_t *global_proc = ompi_comm_peer_lookup(MPI_COMM_WORLD, i); + if (global_proc == proc) { + return 
i; + } + } + + return -1; +} + +static int mca_coll_ucx_create_global_topo_map(mca_coll_ucx_module_t *module, + struct ompi_communicator_t *comm) { if (mca_coll_ucx_component.topo_map != NULL) { return OMPI_SUCCESS; } - - /* Derive the 'loc' string from pmix and gather all 'loc' string from all the ranks in the world. */ - int status = OMPI_SUCCESS; - uint16_t *node_index = NULL; - unsigned LOC_SIZE = 64; - unsigned rank_cnt = mca_coll_ucx_component.world_member_count = ompi_comm_size(MPI_COMM_WORLD); - char *topo_info = (char*)malloc(sizeof(char) * LOC_SIZE * rank_cnt); - if (topo_info == NULL) { - status = OMPI_ERROR; - goto end; - } - memset(topo_info, 0, sizeof(char) * LOC_SIZE * rank_cnt); + /*get my locality string*/ int ret; - char *val = NULL; + char *locality = NULL; OPAL_MODEX_RECV_VALUE_OPTIONAL(ret, OPAL_PMIX_LOCALITY_STRING, - &opal_proc_local_get()->proc_name, &val, OPAL_STRING); - if (val == NULL || ret != OMPI_SUCCESS) { - status = OMPI_ERROR; - goto end; + &opal_proc_local_get()->proc_name, &locality, OPAL_STRING); + if (locality == NULL || ret != OMPI_SUCCESS) { + free(locality); + return OMPI_ERROR; } - - ret = ompi_coll_base_allgather_intra_bruck(val, LOC_SIZE, MPI_CHAR, topo_info, LOC_SIZE, MPI_CHAR, MPI_COMM_WORLD, &module->super); + int locality_size = strlen(locality); + + /* gather all members locality */ + int member_count = ompi_comm_size(comm); + COLL_UCX_ASSERT(locality_size <= 64); + unsigned one_locality_size = 64 * sizeof(char); + unsigned total_locality_size = one_locality_size * member_count; + char *localities = (char*)malloc(total_locality_size); + if (localities == NULL) { + ret = OMPI_ERROR; + goto err_free_locality; + } + memset(localities, 0, total_locality_size); + ret = ompi_coll_base_allgather_intra_bruck(locality, locality_size, MPI_CHAR, + localities, one_locality_size, MPI_CHAR, + MPI_COMM_WORLD, &module->super); if (ret != OMPI_SUCCESS) { int err = MPI_ERR_INTERN; COLL_UCX_ERROR("ompi_coll_base_allgather_intra_bruck 
failed"); ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to init topo map"); } - - /* Obtain node index to indicate each 'loc' belongs to which node, - as 'loc' only has info of local machine and contains no network info. */ - node_index = (uint16_t*)malloc(rank_cnt * sizeof(uint16_t)); - if (node_index == NULL) { - status = OMPI_ERROR; - goto end; + /* get node index */ + uint16_t* node_idx = mca_coll_ucx_obtain_node_index(comm); + if (node_idx == NULL) { + ret = OMPI_ERROR; + goto err_free_localities; } - ret = mca_coll_ucg_obtain_node_index(rank_cnt, MPI_COMM_WORLD, node_index); - if (ret != OMPI_SUCCESS) { - status = OMPI_ERROR; - goto end; + /* create topology map */ + char **topo_map = mca_coll_ucx_create_topo_map(node_idx, + localities, + one_locality_size, + member_count); + if (topo_map == NULL) { + ret = OMPI_ERROR; + goto err_free_node_idx; } - /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ - ret = mca_coll_ucx_create_topo_map(node_index, topo_info, LOC_SIZE, rank_cnt); - if (ret != OMPI_SUCCESS) { - status = OMPI_ERROR; - goto end; - } - - ret = mca_coll_ucx_print_topo_map(rank_cnt, mca_coll_ucx_component.topo_map); - if (ret != OMPI_SUCCESS) { - status = OMPI_ERROR; - } + /* save to global variable */ + mca_coll_ucx_component.topo_map = topo_map; + mca_coll_ucx_component.world_member_count = member_count; + ret = OMPI_SUCCESS; -end: - if (val) { - free(val); - val = NULL; - } +err_free_node_idx: + free(node_idx); +err_free_localities: + free(localities); +err_free_locality: + free(locality); - if (node_index) { - free(node_index); - node_index = NULL; - } - if (topo_info) { - free(topo_info); - topo_info = NULL; - } - return status; + return ret; } -static int mca_coll_ucx_find_rank_in_comm_world(struct ompi_communicator_t *comm, int comm_rank) +static char** mca_coll_ucx_obtain_topo_map(mca_coll_ucx_module_t *module, + struct ompi_communicator_t *comm) { - struct ompi_proc_t *proc = 
(struct ompi_proc_t*)ompi_comm_peer_lookup(comm, comm_rank); - if (proc == NULL) { - return -1; - } - - unsigned i; - for (i = 0; i < ompi_comm_size(MPI_COMM_WORLD); i++) { - struct ompi_proc_t *rank_iter = (struct ompi_proc_t*)ompi_comm_peer_lookup(MPI_COMM_WORLD, i); - if (rank_iter == proc) { - return i; + if (mca_coll_ucx_component.topo_map == NULL) { + /* global topo map is always needed. */ + if (OMPI_SUCCESS != mca_coll_ucx_create_global_topo_map(module, comm)) { + return NULL; } } - - return -1; -} -static int mca_coll_ucx_create_comm_topo(ucg_group_params_t *args, struct ompi_communicator_t *comm) -{ - int status; if (comm == MPI_COMM_WORLD) { - if (args->topo_map != NULL) { - free(args->topo_map); - } - args->topo_map = mca_coll_ucx_component.topo_map; - return OMPI_SUCCESS; + return mca_coll_ucx_component.topo_map; } - /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ - unsigned i; - for (i = 0; i < args->member_count; i++) { + unsigned member_count = ompi_comm_size(comm); + char **topo_map = mca_coll_ucx_allocate_topo_map(member_count); + if (topo_map == NULL) { + return NULL; + } + /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ + for (unsigned i = 0; i < member_count; ++i) { /* Find the rank in the MPI_COMM_WORLD for rank i in the comm. 
*/ - int world_rank_i = mca_coll_ucx_find_rank_in_comm_world(comm, i); - if (world_rank_i == -1) { - return OMPI_ERROR; + int i_global_rank = mca_coll_ucx_convert_to_global_rank(comm, i); + if (i_global_rank == -1) { + goto err_free_topo_map; } - for (unsigned j = 0; j <= i; j++) { - int world_rank_j = mca_coll_ucx_find_rank_in_comm_world(comm, j); - if (world_rank_j == -1) { - return OMPI_ERROR; + for (unsigned j = 0; j <= i; ++j) { + int j_global_rank = mca_coll_ucx_convert_to_global_rank(comm, j); + if (j_global_rank == -1) { + goto err_free_topo_map; } - args->topo_map[i][j] = mca_coll_ucx_component.topo_map[world_rank_i][world_rank_j]; - args->topo_map[j][i] = mca_coll_ucx_component.topo_map[world_rank_j][world_rank_i]; + topo_map[i][j] = mca_coll_ucx_component.topo_map[i_global_rank][j_global_rank]; + topo_map[j][i] = mca_coll_ucx_component.topo_map[j_global_rank][i_global_rank]; } } - status = mca_coll_ucx_print_topo_map(args->member_count, args->topo_map); - return status; -} + mca_coll_ucx_print_topo_map(member_count, topo_map); -static void mca_coll_ucg_create_distance_array(struct ompi_communicator_t *comm, ucg_group_member_index_t my_idx, ucg_group_params_t *args) -{ - ucg_group_member_index_t rank_idx; - for (rank_idx = 0; rank_idx < args->member_count; rank_idx++) { - struct ompi_proc_t *rank_iter = (struct ompi_proc_t*)ompi_comm_peer_lookup(comm, rank_idx); - rank_iter->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_COLL] = NULL; - if (rank_idx == my_idx) { - args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SELF; - } else if (OPAL_PROC_ON_LOCAL_L3CACHE(rank_iter->super.proc_flags)) { - args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; - } else if (OPAL_PROC_ON_LOCAL_SOCKET(rank_iter->super.proc_flags)) { - args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_SOCKET; - } else if (OPAL_PROC_ON_LOCAL_HOST(rank_iter->super.proc_flags)) { - args->distance[rank_idx] = UCG_GROUP_MEMBER_DISTANCE_HOST; - } else { - args->distance[rank_idx] = 
UCG_GROUP_MEMBER_DISTANCE_NET; - } - } -} + return topo_map; -static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, - ucp_datatype_t *ucp_dt) -{ - *ucp_dt = mca_coll_ucx_get_datatype(mpi_dt); - return 0; +err_free_topo_map: + mca_coll_ucx_deallocate_topo_map(topo_map, member_count); + return NULL; } -static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) +static void mca_coll_ucx_free_topo_map(char **topo_map, unsigned member_count) { - struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext; - ptrdiff_t dsize, gp= 0; - - dsize = opal_datatype_span(&dtype->super, count, &gp); - *gap = gp; - return dsize; -} - -static ucg_group_member_index_t mca_coll_ucx_get_global_member_idx(void *cb_group_obj, - ucg_group_member_index_t index) -{ - ompi_communicator_t* comm = (ompi_communicator_t*)cb_group_obj; - return (ucg_group_member_index_t)mca_coll_ucx_find_rank_in_comm_world(comm, (int)index); -} - -static void mca_coll_ucg_init_group_param(struct ompi_communicator_t *comm, ucg_group_params_t *args) -{ - args->member_count = ompi_comm_size(comm); - args->cid = ompi_comm_get_cid(comm); - args->mpi_reduce_f = ompi_op_reduce; - args->resolve_address_f = mca_coll_ucx_resolve_address; - args->release_address_f = mca_coll_ucx_release_address; - args->cb_group_obj = comm; - args->op_is_commute_f = ompi_op_is_commute; - args->mpi_dt_convert = mca_coll_ucg_datatype_convert; - args->mpi_datatype_span = coll_ucx_datatype_span; - args->mpi_global_idx_f = mca_coll_ucx_get_global_member_idx; -} - -static void mca_coll_ucg_arg_free(struct ompi_communicator_t *comm, ucg_group_params_t *args) -{ - unsigned i; - - if (args->distance != NULL) { - free(args->distance); - args->distance = NULL; + /* mca_coll_ucx_component.topo_map will be freed in mca_coll_ucx_module_destruct() */ + if (topo_map != mca_coll_ucx_component.topo_map) { + mca_coll_ucx_deallocate_topo_map(topo_map, member_count); } - if (args->node_index != NULL) { - 
free(args->node_index); - args->node_index = NULL; - } - - if (comm != MPI_COMM_WORLD && args->topo_map != NULL) { - for (i = 0; i < args->member_count; i++) { - if (args->topo_map[i] != NULL) { - free(args->topo_map[i]); - args->topo_map[i] = NULL; - } - } - free(args->topo_map); - args->topo_map = NULL; - } + return; } static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params, mca_coll_ucx_module_t *module, @@ -407,95 +402,93 @@ static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params return; } -static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) +static int mca_coll_ucx_init_ucg_group_params(mca_coll_ucx_module_t *module, + struct ompi_communicator_t *comm, + ucg_group_params_t *params) { - ucs_status_t error; - ucg_group_params_t args; - ucg_group_member_index_t my_idx; - int status = OMPI_SUCCESS; - unsigned i; + memset(params, 0, sizeof(*params)); + uint16_t binding_policy = OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy); + params->field_mask = UCG_GROUP_PARAM_FIELD_UCP_WORKER | + UCG_GROUP_PARAM_FIELD_ID | + UCG_GROUP_PARAM_FIELD_MEMBER_COUNT | + UCG_GROUP_PARAM_FIELD_DISTANCE | + UCG_GROUP_PARAM_FIELD_NODE_INDEX | + UCG_GROUP_PARAM_FIELD_BIND_TO_NONE | + UCG_GROUP_PARAM_FIELD_CB_GROUP_IBJ | + UCG_GROUP_PARAM_FIELD_IS_SOCKET_BALANCE; + params->ucp_worker = mca_coll_ucx_component.ucp_worker; + params->group_id = ompi_comm_get_cid(comm); + params->member_count = ompi_comm_size(comm); + params->distance = mca_coll_ucx_obtain_distance(comm); + if (params->distance == NULL) { + return OMPI_ERROR; + } + params->node_index = mca_coll_ucx_obtain_node_index(comm); + if (params->node_index == NULL) { + goto err_free_distane; + } + params->is_bind_to_none = binding_policy == OPAL_BIND_TO_NONE; + params->cb_group_obj = comm; + mca_coll_ucg_init_is_socket_balance(params, module, comm); + if (mca_coll_ucx_component.enable_topo_map && binding_policy == OPAL_BIND_TO_CORE) { + 
params->field_mask |= UCG_GROUP_PARAM_FIELD_TOPO_MAP; + params->topo_map = mca_coll_ucx_obtain_topo_map(module, comm); + if (params->topo_map == NULL) { + goto err_node_idx; + } + } + return OMPI_SUCCESS; +err_node_idx: + free(params->node_index); + params->node_index = NULL; +err_free_distane: + free(params->distance); + params->distance = NULL; + return OMPI_ERROR; +} + +static void mca_coll_ucx_cleanup_group_params(ucg_group_params_t *params) +{ + if (params->topo_map != NULL) { + mca_coll_ucx_free_topo_map(params->topo_map, params->member_count); + params->topo_map = NULL; + } + if (params->node_index != NULL) { + free(params->node_index); + params->node_index = NULL; + } + if (params->distance != NULL) { + free(params->distance); + params->distance = NULL; + } + return; +} +static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) +{ #if OMPI_GROUP_SPARSE COLL_UCX_ERROR("Sparse process groups are not supported"); return UCS_ERR_UNSUPPORTED; #endif - - /* Fill in group initialization parameters */ - my_idx = ompi_comm_rank(comm); - mca_coll_ucg_init_group_param(comm, &args); - args.distance = malloc(args.member_count * sizeof(*args.distance)); - args.node_index = malloc(args.member_count * sizeof(*args.node_index)); - args.is_bind_to_none = (OPAL_BIND_TO_NONE == OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy)); - args.topo_map = NULL; - - if (args.distance == NULL || args.node_index == NULL) { - MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); - status = OMPI_ERROR; - goto out; - } - - if (mca_coll_ucx_component.enable_topo_map && (OPAL_BIND_TO_CORE == OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy))) { - /* Initialize global topology map. 
*/ - args.topo_map = (char**)malloc(sizeof(char*) * args.member_count); - if (args.topo_map == NULL) { - MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); - status = OMPI_ERROR; - goto out; - } - - for (i = 0; i < args.member_count; i++) { - args.topo_map[i] = (char*)malloc(sizeof(char) * args.member_count); - if (args.topo_map[i] == NULL) { - MCA_COMMON_UCX_WARN("Failed to allocate memory for %lu local ranks", args.member_count); - status = OMPI_ERROR; - goto out; - } - } - - status = mca_coll_ucx_init_global_topo(module); - if (status != OMPI_SUCCESS) { - MCA_COMMON_UCX_WARN("Failed to create global topology."); - status = OMPI_ERROR; - goto out; - } - - if (status == OMPI_SUCCESS) { - status = mca_coll_ucx_create_comm_topo(&args, comm); - if (status != OMPI_SUCCESS) { - MCA_COMMON_UCX_WARN("Failed to create communicator topology."); - status = OMPI_ERROR; - goto out; - } - } - } - - /* Generate (temporary) rank-distance array */ - mca_coll_ucg_create_distance_array(comm, my_idx, &args); - - /* Generate node_index for each process */ - status = mca_coll_ucg_obtain_node_index(args.member_count, comm, args.node_index); - - if (status != OMPI_SUCCESS) { - status = OMPI_ERROR; - goto out; + ucg_group_params_t params; + if (OMPI_SUCCESS != mca_coll_ucx_init_ucg_group_params(module, comm, ¶ms)) { + return OMPI_ERROR; } - - mca_coll_ucg_init_is_socket_balance(&args, module, comm); - error = ucg_group_create(mca_coll_ucx_component.ucg_worker, &args, &module->ucg_group); - /* Examine comm_new return value */ - if (error != UCS_OK) { - MCA_COMMON_UCX_WARN("ucg_new failed: %s", ucs_status_string(error)); - status = OMPI_ERROR; - goto out; + ucs_status_t status = ucg_group_create(mca_coll_ucx_component.ucg_context, + ¶ms, + &module->ucg_group); + if (status != UCS_OK) { + COLL_UCX_ERROR("Failed to create ucg group, %s", ucs_status_string(status)); + goto err_cleanup_params; } ucs_list_add_tail(&mca_coll_ucx_component.group_head, 
&module->ucs_list); - status = OMPI_SUCCESS; + return OMPI_SUCCESS; -out: - mca_coll_ucg_arg_free(comm, &args); - return status; +err_cleanup_params: + mca_coll_ucx_cleanup_group_params(¶ms); + return OMPI_ERROR; } /* @@ -558,12 +551,11 @@ static void mca_coll_ucx_module_construct(mca_coll_ucx_module_t *module) module->super.coll_allreduce = mca_coll_ucx_allreduce; module->super.coll_barrier = mca_coll_ucx_barrier; module->super.coll_bcast = mca_coll_ucx_bcast; - ucs_list_head_init(&module->ucs_list); } static void mca_coll_ucx_module_destruct(mca_coll_ucx_module_t *module) { - if (module->ucg_group) { + if (module->ucg_group != NULL) { ucg_group_destroy(module->ucg_group); } ucs_list_del(&module->ucs_list); diff --git a/ompi/mca/coll/ucx/coll_ucx_op.c b/ompi/mca/coll/ucx/coll_ucx_op.c index b1867446b7e..b7fc6241cdc 100644 --- a/ompi/mca/coll/ucx/coll_ucx_op.c +++ b/ompi/mca/coll/ucx/coll_ucx_op.c @@ -174,7 +174,7 @@ int mca_coll_ucx_allreduce(const void *sbuf, void *rbuf, int count, goto exit; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx allreduce", (void)0); COLL_UCX_TRACE("%s", sbuf, rbuf, count, dtype, comm, "allreduce END"); @@ -290,7 +290,7 @@ int mca_coll_ucx_reduce(const void *sbuf, void* rbuf, int count, return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx reduce", (void)0); } @@ -327,7 +327,7 @@ int mca_coll_ucx_scatter(const void *sbuf, int scount, struct ompi_datatype_t *s return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx scatter", (void)0); 
} @@ -364,7 +364,7 @@ int mca_coll_ucx_gather(const void *sbuf, int scount, struct ompi_datatype_t *sd return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx gather", (void)0); } @@ -401,7 +401,7 @@ int mca_coll_ucx_allgather(const void *sbuf, int scount, struct ompi_datatype_t return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx allgather", (void)0); } @@ -438,7 +438,7 @@ int mca_coll_ucx_alltoall(const void *sbuf, int scount, struct ompi_datatype_t * return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx alltoall", (void)0); } @@ -465,7 +465,7 @@ int mca_coll_ucx_barrier(struct ompi_communicator_t *comm, mca_coll_base_module_ return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx barrier", (void)0); } @@ -502,6 +502,6 @@ int mca_coll_ucx_bcast(void *buff, int count, struct ompi_datatype_t *dtype, int return OMPI_SUCCESS; } - ucp_worker_h ucp_worker = mca_coll_ucx_component.ucg_worker; + ucp_worker_h ucp_worker = mca_coll_ucx_component.ucp_worker; MCA_COMMON_UCX_WAIT_LOOP(req, OPAL_COMMON_UCX_REQUEST_TYPE_UCG, ucp_worker, "ucx bcast", (void)0); } diff --git a/ompi/mca/coll/ucx/coll_ucx_request.c b/ompi/mca/coll/ucx/coll_ucx_request.c index 1c5161eee7e..280509d1496 100644 --- a/ompi/mca/coll/ucx/coll_ucx_request.c +++ b/ompi/mca/coll/ucx/coll_ucx_request.c @@ -10,7 +10,7 @@ * 
* $HEADER$ */ - +#include "ompi_config.h" #include #include "coll_ucx_request.h" #include "ompi/message/message.h" @@ -30,7 +30,7 @@ static int mca_coll_ucx_request_free(ompi_request_t **rptr) static int mca_coll_ucx_request_cancel(ompi_request_t *req, int flag) { - ucg_request_cancel(mca_coll_ucx_component.ucg_worker, req); + ucg_request_cancel(mca_coll_ucx_component.ucg_group, req); return OMPI_SUCCESS; } @@ -137,7 +137,7 @@ static int mca_coll_ucx_persistent_op_cancel(ompi_request_t *req, int flag) mca_coll_ucx_persistent_op_t* preq = (mca_coll_ucx_persistent_op_t*)req; if (preq->tmp_req != NULL) { - ucg_request_cancel(preq->ucg_worker, preq->tmp_req); + ucg_request_cancel(preq->ucg_group, preq->tmp_req); } return OMPI_SUCCESS; } diff --git a/ompi/mca/coll/ucx/coll_ucx_request.h b/ompi/mca/coll/ucx/coll_ucx_request.h index d988419746e..e59ef0146de 100644 --- a/ompi/mca/coll/ucx/coll_ucx_request.h +++ b/ompi/mca/coll/ucx/coll_ucx_request.h @@ -26,7 +26,7 @@ struct coll_ucx_persistent_op { ompi_request_t ompi; ompi_request_t *tmp_req; ucg_coll_h coll_desc; - ucg_worker_h ucg_worker; + ucg_group_h ucg_group; unsigned flags; }; From 06b6116bbf029dcdf2fcd9b51d001e5585797444 Mon Sep 17 00:00:00 2001 From: RainybIue <993835762@qq.com> Date: Mon, 22 Feb 2021 15:50:24 +0800 Subject: [PATCH 19/20] MCA/COLL/UCX: Optimize part of the code format --- ompi/mca/coll/ucx/coll_ucx_component.c | 7 +++---- ompi/mca/coll/ucx/coll_ucx_module.c | 22 ++++++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx_component.c b/ompi/mca/coll/ucx/coll_ucx_component.c index ec4827f540d..c0030a4b21f 100644 --- a/ompi/mca/coll/ucx/coll_ucx_component.c +++ b/ompi/mca/coll/ucx/coll_ucx_component.c @@ -197,9 +197,8 @@ static int mca_coll_ucg_datatype_convert(ompi_datatype_t *mpi_dt, static ptrdiff_t coll_ucx_datatype_span(void *dt_ext, int count, ptrdiff_t *gap) { struct ompi_datatype_t *dtype = (struct ompi_datatype_t *)dt_ext; - ptrdiff_t 
dsize, gp= 0; - - dsize = opal_datatype_span(&dtype->super, count, &gp); + ptrdiff_t gp = 0; + ptrdiff_t dsize = opal_datatype_span(&dtype->super, count, &gp); *gap = gp; return dsize; } @@ -326,7 +325,7 @@ static int mca_coll_ucx_check_worker() if (UCS_OK != status) { COLL_UCX_ERROR("Failed to query UCP worker thread mode, %s", ucs_status_string(status)); return OMPI_ERROR; - } + } /* UCX does not support multithreading, disqualify current PML for now */ if (ompi_mpi_thread_multiple && (attr.thread_mode != UCS_THREAD_MODE_MULTI)) { diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 662338a7f2a..5617da4d3b6 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -64,16 +64,16 @@ static uint16_t* mca_coll_ucx_obtain_node_index(struct ompi_communicator_t *comm return NULL; } uint16_t invalid_node_idx = (uint16_t)-1; - for(unsigned i = 0; i < member_count; ++i) { + for (unsigned i = 0; i < member_count; ++i) { node_idx[i] = invalid_node_idx; } - /*get ip address */ + /* get ip address */ struct in_addr *ip_address = malloc(sizeof(struct in_addr) * member_count); if (ip_address == NULL) { goto err_free_node_idx; } - for(unsigned i = 0; i < member_count; ++i) { + for (unsigned i = 0; i < member_count; ++i) { ompi_proc_t *rank = ompi_comm_peer_lookup(comm, i); status = mca_coll_ucx_obtain_addr_from_hostname(rank->super.proc_hostname, ip_address + i); @@ -88,8 +88,8 @@ static uint16_t* mca_coll_ucx_obtain_node_index(struct ompi_communicator_t *comm if (node_idx[i] == invalid_node_idx) { node_idx[i] = last_node_idx; /* find the node with same ipaddr, assign the same node idx */ - for (unsigned j = i+1; j < member_count; ++j) { - if (0 == memcmp(&ip_address[i], &ip_address[j], sizeof(struct in_addr))) { + for (unsigned j = i + 1; j < member_count; ++j) { + if (memcmp(&ip_address[i], &ip_address[j], sizeof(struct in_addr)) == 0) { node_idx[j] = last_node_idx; } } @@ -267,7 +267,7 @@ static int 
mca_coll_ucx_create_global_topo_map(mca_coll_ucx_module_t *module, if (mca_coll_ucx_component.topo_map != NULL) { return OMPI_SUCCESS; } - /*get my locality string*/ + /* get my locality string */ int ret; char *locality = NULL; OPAL_MODEX_RECV_VALUE_OPTIONAL(ret, OPAL_PMIX_LOCALITY_STRING, @@ -391,8 +391,14 @@ static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params unsigned ppn = ucg_builtin_calculate_ppx(group_params, UCG_GROUP_MEMBER_DISTANCE_HOST); char is_socket_balance = (pps == (ppn - pps) || pps == ppn); char result = is_socket_balance; - int status = ompi_coll_base_allreduce_intra_basic_linear(&is_socket_balance, &result, 1, MPI_CHAR, MPI_MIN, - comm, &module->super); + int status = ompi_coll_base_barrier_intra_basic_linear(comm, &module->super); + if (status != OMPI_SUCCESS) { + int error = MPI_ERR_INTERN; + COLL_UCX_ERROR("ompi_coll_base_barrier_intra_basic_linear failed"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &error, "Failed to init is_socket_balance"); + } + status = ompi_coll_base_allreduce_intra_basic_linear(&is_socket_balance, &result, 1, MPI_CHAR, MPI_MIN, + comm, &module->super); if (status != OMPI_SUCCESS) { int error = MPI_ERR_INTERN; COLL_UCX_ERROR("ompi_coll_base_allreduce_intra_basic_linear failed"); From 41098d006212abb2bc3ee6c1caf8c42632319727 Mon Sep 17 00:00:00 2001 From: shizhibao Date: Mon, 15 Mar 2021 15:25:07 +0800 Subject: [PATCH 20/20] Delete topo_map --- ompi/mca/coll/ucx/coll_ucx.h | 42 +- ompi/mca/coll/ucx/coll_ucx_component.c | 26 +- ompi/mca/coll/ucx/coll_ucx_module.c | 783 ++++++++++++++++--------- 3 files changed, 572 insertions(+), 279 deletions(-) diff --git a/ompi/mca/coll/ucx/coll_ucx.h b/ompi/mca/coll/ucx/coll_ucx.h index b2f5381d544..659639a822c 100644 --- a/ompi/mca/coll/ucx/coll_ucx.h +++ b/ompi/mca/coll/ucx/coll_ucx.h @@ -49,6 +49,40 @@ BEGIN_C_DECLS typedef struct coll_ucx_persistent_op mca_coll_ucx_persistent_op_t; typedef struct coll_ucx_convertor mca_coll_ucx_convertor_t; 
+typedef enum { + COLL_UCX_TOPO_LEVEL_ROOT, + COLL_UCX_TOPO_LEVEL_NODE, + COLL_UCX_TOPO_LEVEL_SOCKET, + COLL_UCX_TOPO_LEVEL_L3CACHE, +} coll_ucx_topo_level_t; + +typedef union coll_ucx_topo_tree { + struct { + int rank_nums; + int child_nums; + union coll_ucx_topo_tree *child; + } inter; + struct { + int rank_nums; + int rank_min; + int rank_max; + } leaf; +} coll_ucx_topo_tree_t; + +typedef struct { + uint32_t node_id : 24; + uint32_t sock_id : 8; +} rank_location_t; + +typedef struct { + int rank_nums; + int node_nums; + int sock_nums; + coll_ucx_topo_level_t level; + coll_ucx_topo_tree_t tree; + rank_location_t *locs; +} coll_ucx_topo_info_t; + typedef struct mca_coll_ucx_component { /* base MCA collectives component */ mca_coll_base_component_t super; @@ -57,7 +91,7 @@ typedef struct mca_coll_ucx_component { int priority; int verbose; int num_disconnect; - bool enable_topo_map; + int topo_aware_level; /* UCX global objects */ ucp_context_h ucp_context; @@ -66,8 +100,7 @@ typedef struct mca_coll_ucx_component { ucg_group_h ucg_group; int output; ucs_list_link_t group_head; - char **topo_map; - unsigned world_member_count; + coll_ucx_topo_info_t topo; /* Requests */ mca_coll_ucx_freelist_t persistent_ops; @@ -86,6 +119,9 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucx_component_t mca_coll_ucx_component; typedef struct mca_coll_ucx_module { mca_coll_base_module_t super; + /* per-communicator topo info and op interface */ + coll_ucx_topo_tree_t *topo_tree; + /* UCX per-communicator context */ ucg_group_h ucg_group; diff --git a/ompi/mca/coll/ucx/coll_ucx_component.c b/ompi/mca/coll/ucx/coll_ucx_component.c index c0030a4b21f..27be6f71f55 100644 --- a/ompi/mca/coll/ucx/coll_ucx_component.c +++ b/ompi/mca/coll/ucx/coll_ucx_component.c @@ -70,8 +70,11 @@ mca_coll_ucx_component_t mca_coll_ucx_component = { .priority = 91, /* priority */ .verbose = 0, /* verbose level */ .num_disconnect = 0, /* ucx_enable */ - .enable_topo_map = 1, /* enable topology map */ - .topo_map = 
NULL + .topo_aware_level = COLL_UCX_TOPO_LEVEL_SOCKET, + .topo = { + .level = COLL_UCX_TOPO_LEVEL_ROOT, + .locs = NULL, + }, }; static int mca_coll_ucx_send_worker_address(void) @@ -452,14 +455,7 @@ static int mca_coll_ucx_close(void) mca_coll_ucx_component.ucp_context = NULL; } - if (mca_coll_ucx_component.topo_map) { - for (unsigned i = 0; i < mca_coll_ucx_component.world_member_count; i++) { - free(mca_coll_ucx_component.topo_map[i]); - mca_coll_ucx_component.topo_map[i] = NULL; - } - free(mca_coll_ucx_component.topo_map); - mca_coll_ucx_component.topo_map = NULL; - } + mca_coll_ucx_destroy_global_topo(); return OMPI_SUCCESS; } @@ -523,13 +519,13 @@ static int ucx_register(void) return OMPI_ERROR; } - mca_coll_ucx_component.enable_topo_map = 1; - status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "enable_topo_map", - "Enable global topology map for ucg", - MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, + mca_coll_ucx_component.topo_aware_level = COLL_UCX_TOPO_LEVEL_SOCKET; + status = mca_base_component_var_register(&mca_coll_ucx_component.super.collm_version, "topo_aware_level", + "Topology aware level for ucg", + MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_3, MCA_BASE_VAR_SCOPE_LOCAL, - &mca_coll_ucx_component.enable_topo_map); + &mca_coll_ucx_component.topo_aware_level); if (status < OPAL_SUCCESS) { return OMPI_ERROR; } diff --git a/ompi/mca/coll/ucx/coll_ucx_module.c b/ompi/mca/coll/ucx/coll_ucx_module.c index 5617da4d3b6..73fafe2fa85 100644 --- a/ompi/mca/coll/ucx/coll_ucx_module.c +++ b/ompi/mca/coll/ucx/coll_ucx_module.c @@ -28,6 +28,516 @@ #include #include +static inline int mca_coll_ucx_get_world_rank(ompi_communicator_t *comm, int rank) +{ + ompi_proc_t *proc = ompi_comm_peer_lookup(comm, rank); + + return ((ompi_process_name_t*)&proc->super.proc_name)->vpid; +} + +static inline int mca_coll_ucx_get_node_nums(uint32_t *node_nums) +{ + int rc; + opal_process_name_t wildcard_rank; + + wildcard_rank.jobid = 
ORTE_PROC_MY_NAME->jobid; + wildcard_rank.vpid = ORTE_NAME_WILDCARD->vpid; + + /* get number of nodes in the job */ + OPAL_MODEX_RECV_VALUE_OPTIONAL(rc, OPAL_PMIX_NUM_NODES, + &wildcard_rank, &node_nums, OPAL_UINT32); + + return rc; +} + +static inline int mca_coll_ucx_get_nodeid(ompi_communicator_t *comm, int rank, uint32_t *nodeid) +{ + int rc; + ompi_proc_t *proc; + + proc = ompi_comm_peer_lookup(comm, rank); + OPAL_MODEX_RECV_VALUE_OPTIONAL(rc, OPAL_PMIX_NODEID, + &(proc->super.proc_name), &nodeid, OPAL_UINT32); + + return rc; +} + +static int mca_coll_ucx_fill_loc_detail(mca_coll_ucx_module_t *module, rank_location_t *locs, int size) +{ + int i, rc; + char *val, *beg; + uint8_t sockid, max_sockid, *sockids; + + OPAL_MODEX_RECV_VALUE_OPTIONAL(rc, OPAL_PMIX_LOCALITY_STRING, + &(opal_proc_local_get()->proc_name), &val, OPAL_STRING); + if (rc != OMPI_SUCCESS || val == NULL) { + COLL_UCX_ERROR("fail to get locality string, error code:%d", rc); + return OMPI_ERROR; + } + + sockids = (uint8_t *)malloc(sizeof(uint8_t) * size); + if (sockids == NULL) { + free(val); + COLL_UCX_ERROR("fail to alloc sockid array, rank_nums:%d", size); + return OMPI_ERR_OUT_OF_RESOURCE; + } + + beg = strstr(val, "SK") + strlen("SK"); + sockid = (uint8_t)atoi(beg); + rc = ompi_coll_base_allgather_intra_bruck(&sockid, 1, MPI_UINT8_T, sockids, 1, MPI_UINT8_T, + MPI_COMM_WORLD, &module->super); + if (rc != OMPI_SUCCESS) { + free(val); + free(sockids); + COLL_UCX_ERROR("ompi_coll_base_allgather_intra_bruck fail"); + ompi_mpi_errors_are_fatal_comm_handler(NULL, &rc, "fail to gather sockids"); + } + + max_sockid = 0; + for (i = 0; i < size; i++) { + locs[i].sock_id = sockids[i]; + if (sockids[i] > max_sockid) { + max_sockid = sockids[i]; + } + } + mca_coll_ucx_component.topo.sock_nums = max_sockid + 1; + + free(val); + free(sockids); + return OMPI_SUCCESS; +} + +static inline coll_ucx_topo_level_t mca_coll_ucx_get_topo_level() +{ + if (mca_coll_ucx_component.topo_aware_level >= 
COLL_UCX_TOPO_LEVEL_SOCKET && + OPAL_GET_BINDING_POLICY(opal_hwloc_binding_policy) >= OPAL_BIND_TO_SOCKET) { + return COLL_UCX_TOPO_LEVEL_SOCKET; + } + + return COLL_UCX_TOPO_LEVEL_NODE; +} + +static inline int mca_coll_ucx_get_topo_child_nums(coll_ucx_topo_level_t level) +{ + if (level >= COLL_UCX_TOPO_LEVEL_SOCKET) { + return mca_coll_ucx_component.topo.sock_nums; + } + + return mca_coll_ucx_component.topo.node_nums; +} + +static inline coll_ucx_topo_tree_t *mca_coll_ucx_get_topo_child(coll_ucx_topo_tree_t *root, + coll_ucx_topo_level_t level, + int rank) +{ + int nodeid, sockid; + + if (level >= COLL_UCX_TOPO_LEVEL_SOCKET) { + sockid = mca_coll_ucx_component.topo.locs[rank].sock_id; + return &root->inter.child[sockid]; + } + + nodeid = mca_coll_ucx_component.topo.locs[rank].node_id; + return &root->inter.child[nodeid]; +} + +static int mca_coll_ucx_build_topo_tree(coll_ucx_topo_tree_t *root, + coll_ucx_topo_level_t level) +{ + int i, rc, child_nums; + coll_ucx_topo_tree_t *child; + + if (level >= mca_coll_ucx_component.topo.level) { + root->leaf.rank_nums = 0; + return OMPI_SUCCESS; + } + + level++; + + root->inter.rank_nums = 0; + root->inter.child = NULL; + child_nums = mca_coll_ucx_get_topo_child_nums(level); + child = (coll_ucx_topo_tree_t *)malloc(sizeof(*child) * child_nums); + if (child == NULL) { + COLL_UCX_ERROR("fail to alloc children, child_nums:%d, child_level:%d, component_level:%d", + child_nums, level, mca_coll_ucx_component.topo.level); + return OMPI_ERR_OUT_OF_RESOURCE; + } + + root->inter.child_nums = child_nums; + root->inter.child = child; + for (i = 0; i < child_nums; i++) { + rc = mca_coll_ucx_build_topo_tree(&child[i], level); + if (rc != OMPI_SUCCESS) { + return rc; + } + } + + return OMPI_SUCCESS; +} + +static void mca_coll_ucx_destroy_topo_tree(coll_ucx_topo_tree_t *root, + coll_ucx_topo_level_t level) +{ + int i, child_nums; + coll_ucx_topo_tree_t *child; + + if (level >= mca_coll_ucx_component.topo.level) { + return; + } + + level++; 
+ + child = root->inter.child; + if (child == NULL) { + return; + } + + child_nums = root->inter.child_nums; + for (i = 0; i < child_nums; i++) { + mca_coll_ucx_destroy_topo_tree(&child[i], level); + } + + free(child); + root->inter.child = NULL; +} + +static void mca_coll_ucx_update_topo_tree(coll_ucx_topo_tree_t *root, + coll_ucx_topo_level_t level, + int rank) +{ + int i, rc, child_nums; + coll_ucx_topo_tree_t *child; + + if (level >= mca_coll_ucx_component.topo.level) { + if (root->leaf.rank_nums == 0) { + root->leaf.rank_min = rank; + } + root->leaf.rank_max = rank; + root->leaf.rank_nums++; + return; + } + + root->inter.rank_nums++; + level++; + + child = mca_coll_ucx_get_topo_child(root, level, rank); + return mca_coll_ucx_update_topo_tree(child, level, rank); +} + +static int mca_coll_ucx_init_global_topo(mca_coll_ucx_module_t *module) +{ + int i, rc, rank_nums; + uint32_t node_nums, nodeid; + rank_location_t *locs; + coll_ucx_topo_tree_t *root; + + if (mca_coll_ucx_component.topo.locs != NULL) { + return OMPI_SUCCESS; + } + + rank_nums = ompi_comm_size(MPI_COMM_WORLD); + locs = (rank_location_t *)malloc(sizeof(*locs) * rank_nums); + if (locs == NULL) { + COLL_UCX_ERROR("fail to alloc rank location array, rank_nums:%d", rank_nums); + return OMPI_ERR_OUT_OF_RESOURCE; + } + + mca_coll_ucx_component.topo.locs = locs; + mca_coll_ucx_component.topo.rank_nums = rank_nums; + + rc = mca_coll_ucx_get_node_nums(&node_nums); + if (rc != OMPI_SUCCESS) { + COLL_UCX_ERROR("fail to get node_nums, error code:%d", rc); + return rc; + } + + mca_coll_ucx_component.topo.node_nums = node_nums; + + mca_coll_ucx_component.topo.level = mca_coll_ucx_get_topo_level(); + if (mca_coll_ucx_component.topo.level > COLL_UCX_TOPO_LEVEL_NODE) { + rc = mca_coll_ucx_fill_loc_detail(module, locs, rank_nums); + if (rc != OMPI_SUCCESS) { + return rc; + } + } + + root = &mca_coll_ucx_component.topo.tree; + rc = mca_coll_ucx_build_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT); + if (rc != 
OMPI_SUCCESS) { + COLL_UCX_ERROR("fail to init global topo tree"); + return rc; + } + + for (i = 0; i < rank_nums; i++) { + rc = mca_coll_ucx_get_nodeid(MPI_COMM_WORLD, i, &nodeid); + if (rc != OMPI_SUCCESS) { + COLL_UCX_ERROR("fail to get nodeid, error code:%d", rc); + return rc; + } + locs[i].node_id = nodeid; + mca_coll_ucx_update_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT, i); + } + + return OMPI_SUCCESS; +} + +void mca_coll_ucx_destroy_global_topo() +{ + rank_location_t *locs; + coll_ucx_topo_tree_t *root; + + root = &mca_coll_ucx_component.topo.tree; + mca_coll_ucx_destroy_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT); + + locs = mca_coll_ucx_component.topo.locs; + if (locs != NULL) { + free(locs); + mca_coll_ucx_component.topo.locs = NULL; + } +} + +static int mca_coll_ucx_init_comm_topo(mca_coll_ucx_module_t *module, ompi_communicator_t *comm) +{ + int i, rc, rank_nums, global_rank; + coll_ucx_topo_tree_t *root; + + if (comm == MPI_COMM_WORLD) { + module->topo_tree = &mca_coll_ucx_component.topo.tree; + return OMPI_SUCCESS; + } + + root = (coll_ucx_topo_tree_t *)malloc(sizeof(*root)); + if (root == NULL) { + COLL_UCX_ERROR("fail to alloc communicator topo tree root"); + return OMPI_ERR_OUT_OF_RESOURCE; + } + + module->topo_tree = root; + rc = mca_coll_ucx_build_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT); + if (rc != OMPI_SUCCESS) { + COLL_UCX_ERROR("fail to init communicator topo tree"); + return rc; + } + + rank_nums = ompi_comm_size(comm); + for (i = 0; i < rank_nums; i++) { + global_rank = mca_coll_ucx_get_world_rank(comm, i); + mca_coll_ucx_update_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT, global_rank); + } + + return OMPI_SUCCESS; +} + +static void mca_coll_ucx_destroy_comm_topo(mca_coll_ucx_module_t *module) +{ + coll_ucx_topo_tree_t *root = module->topo_tree; + + if (root == NULL || root == &mca_coll_ucx_component.topo.tree) { + return; + } + + mca_coll_ucx_destroy_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT); + free(root); + module->topo_tree = NULL; +} + 
+static int mca_coll_ucx_init_topo_info(mca_coll_ucx_module_t *module, ompi_communicator_t *comm) +{ + int rc; + + if (comm == MPI_COMM_WORLD) { + rc = mca_coll_ucx_init_global_topo(module); + if (rc != OMPI_SUCCESS) { + return rc; + } + } + + return mca_coll_ucx_init_comm_topo(module, comm); +} + +static void mca_coll_ucx_check_node_aware_tree(coll_ucx_topo_tree_t *root, ucg_topo_args_t *arg) +{ + int i; + coll_ucx_topo_tree_t *node; + + arg->rank_continuous_in_node = 1; + arg->rank_continuous_in_sock = 0; + arg->rank_balance_in_node = 1; + arg->rank_balance_in_sock = 0; + + for (i = 0; i < root->inter.child_nums; i++) { + node = &root->inter.child[i]; + if (node->leaf.rank_nums == 0) { + continue; + } + + COLL_UCX_VERBOSE(1, "node%d:rank_nums=%d,min_rank=%d,max_rank=%d", + i, node->leaf.rank_nums, node->leaf.rank_min, node->leaf.rank_max); + + if (node->leaf.rank_max - node->leaf.rank_min + 1 != node->leaf.rank_nums) { + arg->rank_continuous_in_node = 0; + return; + } + } +} + +static void mca_coll_ucx_check_sock_aware_tree(coll_ucx_topo_tree_t *root, ucg_topo_args_t *arg) +{ + int i, j, sock_nums, min, max, rank_nums1, rank_nums2; + coll_ucx_topo_tree_t *node; + coll_ucx_topo_tree_t *sock; + + arg->rank_continuous_in_node = 1; + arg->rank_continuous_in_sock = 1; + arg->rank_balance_in_node = 1; + arg->rank_balance_in_sock = 1; + + for (i = 0; i < root->inter.child_nums; i++) { + node = &root->inter.child[i]; + if (node->inter.rank_nums == 0) { + continue; + } + sock_nums = 0; + rank_nums1 = 0; + rank_nums2 = 0; + for (j = 0; j < node->inter.child_nums; j++) { + sock = &node->inter.child[j]; + if (sock->leaf.rank_nums == 0) { + continue; + } + if (sock_nums == 0) { + min = sock->leaf.rank_min; + max = sock->leaf.rank_max; + rank_nums1 = sock->leaf.rank_nums; + } else { + min = sock->leaf.rank_min < min ? sock->leaf.rank_min : min; + max = sock->leaf.rank_max > max ? 
sock->leaf.rank_max : max; + rank_nums2 = sock->leaf.rank_nums; + } + sock_nums++; + if (sock->leaf.rank_max - sock->leaf.rank_min + 1 != sock->leaf.rank_nums) { + arg->rank_continuous_in_sock = 0; + } + } + + COLL_UCX_VERBOSE(1, "node%d:rank_nums=%d,min_rank=%d,max_rank=%d,sock_num=%d,sock1_nums=%d,sock2_nums=%d", + i, node->inter.rank_nums, min, max, sock_nums, rank_nums1, rank_nums2); + + if (max - min + 1 != node->inter.rank_nums) { + arg->rank_continuous_in_node = 0; + } + if (sock_nums > 2 || (sock_nums == 2 && rank_nums1 != rank_nums2)) { + arg->rank_balance_in_sock = 0; + } + } +} + +static void mca_coll_ucx_print_ucg_topo_args(const ucg_topo_args_t *arg) +{ + COLL_UCX_VERBOSE(1, "ucg_topo_args:rank_continuous_in_node=%d", arg->rank_continuous_in_node); + COLL_UCX_VERBOSE(1, "ucg_topo_args:rank_continuous_in_sock=%d", arg->rank_continuous_in_sock); + COLL_UCX_VERBOSE(1, "ucg_topo_args:rank_balance_in_node=%d", arg->rank_balance_in_node); + COLL_UCX_VERBOSE(1, "ucg_topo_args:rank_balance_in_sock=%d", arg->rank_balance_in_sock); +} + +static void mca_coll_ucx_set_ucg_topo_args(mca_coll_ucx_module_t *module, ucg_topo_args_t *arg) +{ + if (mca_coll_ucx_component.topo.level >= COLL_UCX_TOPO_LEVEL_SOCKET) { + mca_coll_ucx_check_sock_aware_tree(module->topo_tree, arg); + } else { + mca_coll_ucx_check_node_aware_tree(module->topo_tree, arg); + } + + mca_coll_ucx_print_ucg_topo_args(arg); +} + +static void mca_coll_ucx_print_topo_tree(coll_ucx_topo_tree_t *root, + coll_ucx_topo_level_t level) +{ + int i, child_nums; + coll_ucx_topo_tree_t *child; + + if (level >= mca_coll_ucx_component.topo.level) { + COLL_UCX_VERBOSE(1, "ranks info:nums=%d,min=%d,max=%d", + root->leaf.rank_nums, + root->leaf.rank_min, + root->leaf.rank_max); + return; + } + + level++; + + child = root->inter.child; + if (child == NULL) { + return; + } + + child_nums = root->inter.child_nums; + for (i = 0; i < child_nums; i++) { + COLL_UCX_VERBOSE(1, "%s %d/%d:rank_nums=%d", level == 
COLL_UCX_TOPO_LEVEL_NODE ? + "node" : "socket", i, child_nums, child[i].inter.rank_nums); + mca_coll_ucx_print_topo_tree(&child[i], level); + } +} + +static void mca_coll_ucx_print_global_topo() +{ + int i, j, rows, cols, len, rank_nums; + char logbuf[512]; + char *buf = logbuf; + rank_location_t *locs; + coll_ucx_topo_tree_t *root; + + locs = mca_coll_ucx_component.topo.locs; + if (locs == NULL) { + return; + } + + cols = 32; + rank_nums = mca_coll_ucx_component.topo.rank_nums; + rows = rank_nums / cols; + for (i = 0; i < rows; i++) { + for (j = 0; j < cols; j++) { + len = sprintf(buf, "(%u,%u)", locs->node_id, locs->sock_id); + locs++; + buf += len; + } + *buf = '\0'; + buf = logbuf; + COLL_UCX_VERBOSE(1, "rank %d~%d location:%s", i * cols, (i + 1) * cols - 1, buf); + } + + for (j = rows * cols; j < rank_nums; j++) { + len = sprintf(buf, "(%u,%u)", locs->node_id, locs->sock_id); + locs++; + buf += len; + } + *buf = '\0'; + buf = logbuf; + COLL_UCX_VERBOSE(1, "rank %d~%d location:%s", rows * cols, rank_nums - 1, buf); +} + +static void mca_coll_ucx_print_comm_topo(mca_coll_ucx_module_t *module) +{ + coll_ucx_topo_tree_t *root = module->topo_tree; + + if (root == NULL) { + return; + } + + mca_coll_ucx_print_topo_tree(root, COLL_UCX_TOPO_LEVEL_ROOT); +} + +static void mca_coll_ucx_print_topo_info(mca_coll_ucx_module_t *module, ompi_communicator_t *comm) +{ + if (comm == MPI_COMM_WORLD) { + mca_coll_ucx_print_global_topo(); + } + + return mca_coll_ucx_print_comm_topo(module); +} + static int mca_coll_ucx_obtain_addr_from_hostname(const char *hostname, struct in_addr *ip_addr) { @@ -134,256 +644,6 @@ static enum ucg_group_member_distance* mca_coll_ucx_obtain_distance(struct ompi_ return distance; } -static void mca_coll_ucx_deallocate_topo_map(char **topo_map, unsigned member_count) -{ - if (topo_map == NULL) { - return; - } - for (unsigned i = 0; i < member_count; ++i) { - if (topo_map[i] == NULL) { - /* The following are NULL too, so break */ - break; - } - 
free(topo_map[i]); - topo_map[i] = NULL; - } - free(topo_map); - topo_map = NULL; - return; -} - -static char** mca_coll_ucx_allocate_topo_map(unsigned member_count) -{ - char **topo_map = malloc(sizeof(char*) * member_count); - if (topo_map == NULL) { - return NULL; - } - memset(topo_map, 0, sizeof(char*) * member_count); - - for (unsigned i = 0; i < member_count; ++i) { - topo_map[i] = malloc(sizeof(char) * member_count); - if (topo_map[i] == NULL) { - goto err; - } - } - - return topo_map; -err: - mca_coll_ucx_deallocate_topo_map(topo_map, member_count); - return NULL; -} - -static char** mca_coll_ucx_create_topo_map(const uint16_t *node_index, - char *localities, - unsigned loc_size, - unsigned member_count) -{ - char **topo_map = mca_coll_ucx_allocate_topo_map(member_count); - if (topo_map == NULL) { - return NULL; - } - - unsigned i, j; - enum ucg_group_member_distance distance; - opal_hwloc_locality_t rel_loc; - for (i = 0; i < member_count; ++i) { - for (j = 0; j <= i; j++) { - if (i == j) { - topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_SELF; - continue; - } - - if (node_index[i] != node_index[j]) { - topo_map[i][j] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; - topo_map[j][i] = (char)UCG_GROUP_MEMBER_DISTANCE_NET; - continue; - } - - rel_loc = opal_hwloc_compute_relative_locality(localities + i * loc_size, - localities + j * loc_size); - if (OPAL_PROC_ON_LOCAL_L3CACHE(rel_loc)) { - distance = UCG_GROUP_MEMBER_DISTANCE_L3CACHE; - } else if (OPAL_PROC_ON_LOCAL_SOCKET(rel_loc)) { - distance = UCG_GROUP_MEMBER_DISTANCE_SOCKET; - } else if (OPAL_PROC_ON_LOCAL_HOST(rel_loc)) { - distance = UCG_GROUP_MEMBER_DISTANCE_HOST; - } else { - distance = UCG_GROUP_MEMBER_DISTANCE_NET; - } - topo_map[i][j] = (char)distance; - topo_map[j][i] = (char)distance; - } - } - return topo_map; -} - -static int mca_coll_ucx_print_topo_map(unsigned rank_cnt, char **topo_map) -{ - int status = OMPI_SUCCESS; - - /* Print topo map for rank 0. 
*/ - if (ompi_comm_rank(MPI_COMM_WORLD) == 0) { - unsigned i; - for (i = 0; i < rank_cnt; i++) { - char *topo_print = (char*)malloc(rank_cnt + 1); - if (topo_print == NULL) { - status = OMPI_ERROR; - return status; - } - for (unsigned j = 0; j < rank_cnt; j++) { - topo_print[j] = '0' + (int)topo_map[i][j]; - } - topo_print[rank_cnt] = '\0'; - COLL_UCX_VERBOSE(8, "%s\n", topo_print); - free(topo_print); - topo_print = NULL; - } - } - return status; -} - -static int mca_coll_ucx_convert_to_global_rank(struct ompi_communicator_t *comm, int rank) -{ - struct ompi_proc_t *proc = ompi_comm_peer_lookup(comm, rank); - if (proc == NULL) { - return -1; - } - - unsigned i; - unsigned member_count = ompi_comm_size(MPI_COMM_WORLD); - for (i = 0; i < member_count; ++i) { - struct ompi_proc_t *global_proc = ompi_comm_peer_lookup(MPI_COMM_WORLD, i); - if (global_proc == proc) { - return i; - } - } - - return -1; -} - -static int mca_coll_ucx_create_global_topo_map(mca_coll_ucx_module_t *module, - struct ompi_communicator_t *comm) -{ - if (mca_coll_ucx_component.topo_map != NULL) { - return OMPI_SUCCESS; - } - /* get my locality string */ - int ret; - char *locality = NULL; - OPAL_MODEX_RECV_VALUE_OPTIONAL(ret, OPAL_PMIX_LOCALITY_STRING, - &opal_proc_local_get()->proc_name, &locality, OPAL_STRING); - if (locality == NULL || ret != OMPI_SUCCESS) { - free(locality); - return OMPI_ERROR; - } - int locality_size = strlen(locality); - - /* gather all members locality */ - int member_count = ompi_comm_size(comm); - COLL_UCX_ASSERT(locality_size <= 64); - unsigned one_locality_size = 64 * sizeof(char); - unsigned total_locality_size = one_locality_size * member_count; - char *localities = (char*)malloc(total_locality_size); - if (localities == NULL) { - ret = OMPI_ERROR; - goto err_free_locality; - } - memset(localities, 0, total_locality_size); - ret = ompi_coll_base_allgather_intra_bruck(locality, locality_size, MPI_CHAR, - localities, one_locality_size, MPI_CHAR, - MPI_COMM_WORLD, 
&module->super); - if (ret != OMPI_SUCCESS) { - int err = MPI_ERR_INTERN; - COLL_UCX_ERROR("ompi_coll_base_allgather_intra_bruck failed"); - ompi_mpi_errors_are_fatal_comm_handler(NULL, &err, "Failed to init topo map"); - } - /* get node index */ - uint16_t* node_idx = mca_coll_ucx_obtain_node_index(comm); - if (node_idx == NULL) { - ret = OMPI_ERROR; - goto err_free_localities; - } - - /* create topology map */ - char **topo_map = mca_coll_ucx_create_topo_map(node_idx, - localities, - one_locality_size, - member_count); - if (topo_map == NULL) { - ret = OMPI_ERROR; - goto err_free_node_idx; - } - - /* save to global variable */ - mca_coll_ucx_component.topo_map = topo_map; - mca_coll_ucx_component.world_member_count = member_count; - ret = OMPI_SUCCESS; - -err_free_node_idx: - free(node_idx); -err_free_localities: - free(localities); -err_free_locality: - free(locality); - - return ret; -} - -static char** mca_coll_ucx_obtain_topo_map(mca_coll_ucx_module_t *module, - struct ompi_communicator_t *comm) -{ - if (mca_coll_ucx_component.topo_map == NULL) { - /* global topo map is always needed. */ - if (OMPI_SUCCESS != mca_coll_ucx_create_global_topo_map(module, comm)) { - return NULL; - } - } - - if (comm == MPI_COMM_WORLD) { - return mca_coll_ucx_component.topo_map; - } - - unsigned member_count = ompi_comm_size(comm); - char **topo_map = mca_coll_ucx_allocate_topo_map(member_count); - if (topo_map == NULL) { - return NULL; - } - /* Create a topo matrix. As it is Diagonal symmetry, only half of the matrix will be computed. */ - for (unsigned i = 0; i < member_count; ++i) { - /* Find the rank in the MPI_COMM_WORLD for rank i in the comm. 
*/ - int i_global_rank = mca_coll_ucx_convert_to_global_rank(comm, i); - if (i_global_rank == -1) { - goto err_free_topo_map; - } - for (unsigned j = 0; j <= i; ++j) { - int j_global_rank = mca_coll_ucx_convert_to_global_rank(comm, j); - if (j_global_rank == -1) { - goto err_free_topo_map; - } - topo_map[i][j] = mca_coll_ucx_component.topo_map[i_global_rank][j_global_rank]; - topo_map[j][i] = mca_coll_ucx_component.topo_map[j_global_rank][i_global_rank]; - } - } - - mca_coll_ucx_print_topo_map(member_count, topo_map); - - return topo_map; - -err_free_topo_map: - mca_coll_ucx_deallocate_topo_map(topo_map, member_count); - return NULL; -} - -static void mca_coll_ucx_free_topo_map(char **topo_map, unsigned member_count) -{ - /* mca_coll_ucx_component.topo_map will be freed in mca_coll_ucx_module_destruct() */ - if (topo_map != mca_coll_ucx_component.topo_map) { - mca_coll_ucx_deallocate_topo_map(topo_map, member_count); - } - - return; -} - static void mca_coll_ucg_init_is_socket_balance(ucg_group_params_t *group_params, mca_coll_ucx_module_t *module, struct ompi_communicator_t *comm) { @@ -425,6 +685,7 @@ static int mca_coll_ucx_init_ucg_group_params(mca_coll_ucx_module_t *module, params->ucp_worker = mca_coll_ucx_component.ucp_worker; params->group_id = ompi_comm_get_cid(comm); params->member_count = ompi_comm_size(comm); + mca_coll_ucx_set_ucg_topo_args(module, ¶ms->topo_args); params->distance = mca_coll_ucx_obtain_distance(comm); if (params->distance == NULL) { return OMPI_ERROR; @@ -436,13 +697,6 @@ static int mca_coll_ucx_init_ucg_group_params(mca_coll_ucx_module_t *module, params->is_bind_to_none = binding_policy == OPAL_BIND_TO_NONE; params->cb_group_obj = comm; mca_coll_ucg_init_is_socket_balance(params, module, comm); - if (mca_coll_ucx_component.enable_topo_map && binding_policy == OPAL_BIND_TO_CORE) { - params->field_mask |= UCG_GROUP_PARAM_FIELD_TOPO_MAP; - params->topo_map = mca_coll_ucx_obtain_topo_map(module, comm); - if (params->topo_map == NULL) { - 
goto err_node_idx; - } - } return OMPI_SUCCESS; err_node_idx: free(params->node_index); @@ -455,10 +709,6 @@ static int mca_coll_ucx_init_ucg_group_params(mca_coll_ucx_module_t *module, static void mca_coll_ucx_cleanup_group_params(ucg_group_params_t *params) { - if (params->topo_map != NULL) { - mca_coll_ucx_free_topo_map(params->topo_map, params->member_count); - params->topo_map = NULL; - } if (params->node_index != NULL) { free(params->node_index); params->node_index = NULL; @@ -476,6 +726,12 @@ static int mca_coll_ucg_create(mca_coll_ucx_module_t *module, struct ompi_commun COLL_UCX_ERROR("Sparse process groups are not supported"); return UCS_ERR_UNSUPPORTED; #endif + + if (OMPI_SUCCESS != mca_coll_ucx_init_topo_info(module, comm)) { + COLL_UCX_ERROR("fail to init topo info"); + return OMPI_ERROR; + } + ucg_group_params_t params; if (OMPI_SUCCESS != mca_coll_ucx_init_ucg_group_params(module, comm, ¶ms)) { return OMPI_ERROR; @@ -557,6 +813,8 @@ static void mca_coll_ucx_module_construct(mca_coll_ucx_module_t *module) module->super.coll_allreduce = mca_coll_ucx_allreduce; module->super.coll_barrier = mca_coll_ucx_barrier; module->super.coll_bcast = mca_coll_ucx_bcast; + + ucs_list_head_init(&module->ucs_list); } static void mca_coll_ucx_module_destruct(mca_coll_ucx_module_t *module) @@ -564,7 +822,10 @@ static void mca_coll_ucx_module_destruct(mca_coll_ucx_module_t *module) if (module->ucg_group != NULL) { ucg_group_destroy(module->ucg_group); } + ucs_list_del(&module->ucs_list); + + mca_coll_ucx_destroy_comm_topo(module); } OBJ_CLASS_INSTANCE(mca_coll_ucx_module_t,