Skip to content

Commit

Permalink
[Huawei]HMPI: Fix some bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Jiakun Liang <[email protected]>
  • Loading branch information
JKLiang9714 committed Apr 20, 2023
1 parent 951640a commit 1aa6802
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 70 deletions.
57 changes: 0 additions & 57 deletions ompi/mca/coll/ucg/coll_ucg_allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,55 +49,6 @@ static int mca_coll_ucg_request_allgatherv_init(mca_coll_ucg_req_t *coll_req,
return OMPI_SUCCESS;
}

static int mca_coll_ucg_allgatherv_check(const void *sbuf, ompi_datatype_t *sdtype,
const int *rcounts, ompi_datatype_t *rdtype,
ompi_communicator_t *comm)
{
size_t sdtype_size, rdtype_size;
ompi_datatype_type_size(sdtype, &sdtype_size);
ompi_datatype_type_size(rdtype, &rdtype_size);

size_t dt_size = (sbuf == MPI_IN_PLACE) ? rdtype_size : sdtype_size;
size_t total_msg_size = 0;
int group_size = ompi_comm_size(comm);
for (int i = 0; i < group_size; ++i) {
total_msg_size += dt_size * rcounts[i];
}
size_t avg_size = total_msg_size / group_size;
// TODO: Small message performance is bad
int supported = 1;
if (group_size <= 16) {
if (avg_size <= 64) {
supported = 0;
}
} else if (group_size <= 32) {
if (avg_size <= 1024) {
supported = 0;
}
} else if (group_size <= 64) {
if (avg_size <= 2048) {
supported = 0;
}
} else if (group_size <= 128) {
if (avg_size <= 512) {
supported = 0;
}
} else if (group_size <= 256) {
if (avg_size <= 1024) {
supported = 0;
}
} else if (group_size <= 512) {
if (avg_size <= 256) {
supported = 0;
}
}else {
if (avg_size <= 64) {
supported = 0;
}
}
return supported ? OMPI_SUCCESS : OMPI_ERR_NOT_SUPPORTED;
}

int mca_coll_ucg_allgatherv(const void *sbuf, int scount, ompi_datatype_t *sdtype,
void *rbuf, const int *rcounts, const int *disps,
ompi_datatype_t *rdtype, ompi_communicator_t *comm,
Expand All @@ -109,10 +60,6 @@ int mca_coll_ucg_allgatherv(const void *sbuf, int scount, ompi_datatype_t *sdtyp
mca_coll_ucg_req_t coll_req;
OBJ_CONSTRUCT(&coll_req, mca_coll_ucg_req_t);
int rc;
rc = mca_coll_ucg_allgatherv_check(sbuf, sdtype, rcounts, rdtype, comm);
if (rc != OMPI_SUCCESS) {
goto fallback;
}
rc = mca_coll_ucg_request_common_init(&coll_req, false, false);
if (rc != OMPI_SUCCESS) {
goto fallback;
Expand Down Expand Up @@ -163,10 +110,6 @@ int mca_coll_ucg_allgatherv_cache(const void *sbuf, int scount, ompi_datatype_t
};

int rc;
rc = mca_coll_ucg_allgatherv_check(sbuf, sdtype, rcounts, rdtype, comm);
if (rc != OMPI_SUCCESS) {
goto fallback;
}
rc = mca_coll_ucg_request_execute_cache(&args);
if (rc == OMPI_SUCCESS) {
return rc;
Expand Down
11 changes: 10 additions & 1 deletion ompi/mca/coll/ucg/coll_ucg_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,13 @@ static int mca_coll_ucg_init_once(ompi_communicator_t *comm)
goto err_cleanup_conv_pool;
}

uint32_t size = (uint32_t)ompi_comm_size(comm);
rc = mca_coll_ucg_subargs_pool_init(size);
if (rc != OMPI_SUCCESS) {
UCG_ERROR("Failed to init subargs mpool, %d", rc);
goto err_cleanup_rpool;
}

if (ompi_mpi_thread_multiple) {
UCG_DEBUG("rcache is non-thread-safe, disable it");
cm->max_rcache_size = 0;
Expand All @@ -372,7 +379,7 @@ static int mca_coll_ucg_init_once(ompi_communicator_t *comm)
UCG_DEBUG("max rcache size is %d", cm->max_rcache_size);
rc = mca_coll_ucg_rcache_init(cm->max_rcache_size);
if (rc != OMPI_SUCCESS) {
goto err_cleanup_rpool;
goto err_cleanup_subargs_pool;
}
}

Expand Down Expand Up @@ -405,6 +412,8 @@ static int mca_coll_ucg_init_once(ompi_communicator_t *comm)
if (cm->max_rcache_size > 0) {
mca_coll_ucg_rcache_cleanup();
}
err_cleanup_subargs_pool:
mca_coll_ucg_subargs_pool_cleanup();
err_cleanup_rpool:
mca_coll_ucg_rpool_cleanup();
err_cleanup_conv_pool:
Expand Down
183 changes: 171 additions & 12 deletions ompi/mca/coll/ucg/coll_ucg_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
mca_coll_ucg_rpool_t mca_coll_ucg_rpool = {0};
static mca_coll_ucg_rcache_t mca_coll_ucg_rcache;

mca_coll_ucg_subargs_pool_t mca_coll_ucg_subargs_pool = {0};

static void ucg_coll_ucg_rcache_ref(mca_coll_ucg_req_t *coll_req)
{
mca_coll_ucg_args_t *args = &coll_req->args;
Expand Down Expand Up @@ -148,7 +150,6 @@ static void mca_coll_ucg_request_complete(void *arg, ucg_status_t status)
} else {
ompi_req->req_status.MPI_ERROR = MPI_ERR_INTERN;
}
ompi_req->req_state = OMPI_REQUEST_INACTIVE;
ompi_request_complete(ompi_req, true);
return;
}
Expand Down Expand Up @@ -184,6 +185,11 @@ OBJ_CLASS_INSTANCE(mca_coll_ucg_req_t,
NULL,
NULL);

OBJ_CLASS_INSTANCE(mca_coll_ucg_subargs_t,
opal_free_list_item_t,
NULL,
NULL);

int mca_coll_ucg_rpool_init()
{
OBJ_CONSTRUCT(&mca_coll_ucg_rpool.flist, opal_free_list_t);
Expand All @@ -201,6 +207,24 @@ void mca_coll_ucg_rpool_cleanup()
return;
}

int mca_coll_ucg_subargs_pool_init(uint32_t size)
{
OBJ_CONSTRUCT(&mca_coll_ucg_subargs_pool.flist, opal_free_list_t);
int rc = opal_free_list_init(&mca_coll_ucg_subargs_pool.flist,
sizeof(mca_coll_ucg_subargs_t) + 4 * size * sizeof(int),
opal_cache_line_size, OBJ_CLASS(mca_coll_ucg_subargs_t),
0, 0,
0, INT_MAX, 128,
NULL, 0, NULL, NULL, NULL);
return rc == OPAL_SUCCESS ? OMPI_SUCCESS : OMPI_ERROR;
}

void mca_coll_ucg_subargs_pool_cleanup()
{
OBJ_DESTRUCT(&mca_coll_ucg_subargs_pool.flist);
return;
}

int mca_coll_ucg_rcache_init(int size)
{
if (size <= 0) {
Expand All @@ -226,11 +250,124 @@ void mca_coll_ucg_rcache_cleanup()
return;
}

static void mca_coll_ucg_rcache_coll_req_args_init(mca_coll_ucg_args_t *dst,
const mca_coll_ucg_args_t *src)
{
*dst = *src;
int *scounts, *sdispls, *rcounts, *rdispls, *disps;
uint32_t size = (uint32_t)ompi_comm_size(src->comm);
mca_coll_ucg_subargs_t *args = mca_coll_ucg_subargs_pool_get();

switch (src->coll_type) {
case MCA_COLL_UCG_TYPE_ALLTOALLV:
case MCA_COLL_UCG_TYPE_IALLTOALLV:
if (src->alltoallv.scounts == NULL ||
src->alltoallv.sdispls == NULL ||
src->alltoallv.rcounts == NULL ||
src->alltoallv.rdispls == NULL) {
return;
}
scounts = args->buf;
sdispls = scounts + size;
rcounts = sdispls + size;
rdispls = rcounts + size;
for (int i = 0; i < size; ++i) {
scounts[i] = src->alltoallv.scounts[i];
sdispls[i] = src->alltoallv.sdispls[i];
rcounts[i] = src->alltoallv.rcounts[i];
rdispls[i] = src->alltoallv.rdispls[i];
}
dst->alltoallv.scounts = scounts;
dst->alltoallv.sdispls = sdispls;
dst->alltoallv.rcounts = rcounts;
dst->alltoallv.rdispls = rdispls;
break;
case MCA_COLL_UCG_TYPE_SCATTERV:
case MCA_COLL_UCG_TYPE_ISCATTERV:
if (src->scatterv.scounts == NULL ||
src->scatterv.disps == NULL) {
return;
}
scounts = args->buf;
disps = scounts + size;
for (int i = 0; i < size; ++i) {
scounts[i] = src->scatterv.scounts[i];
disps[i] = src->scatterv.disps[i];
}
dst->scatterv.scounts = scounts;
dst->scatterv.disps = disps;
break;
case MCA_COLL_UCG_TYPE_GATHERV:
case MCA_COLL_UCG_TYPE_IGATHERV:
if (src->gatherv.rcounts == NULL ||
src->gatherv.disps == NULL) {
return;
}
rcounts = args->buf;
disps = rcounts + size;
for (int i = 0; i < size; ++i) {
rcounts[i] = src->gatherv.rcounts[i];
disps[i] = src->gatherv.disps[i];
}
dst->gatherv.rcounts = rcounts;
dst->gatherv.disps = disps;
break;
case MCA_COLL_UCG_TYPE_ALLGATHERV:
case MCA_COLL_UCG_TYPE_IALLGATHERV:
if (src->allgatherv.rcounts == NULL ||
src->allgatherv.disps == NULL) {
return;
}
rcounts = args->buf;
disps = rcounts + size;
for (int i = 0; i < size; ++i) {
rcounts[i] = src->allgatherv.rcounts[i];
disps[i] = src->allgatherv.disps[i];
}
dst->allgatherv.rcounts = rcounts;
dst->allgatherv.disps = disps;
break;
default:
break;
}
return;
}

static void mca_coll_ucg_rcache_coll_req_args_uninit(mca_coll_ucg_args_t *args)
{
void *buf = NULL;
switch (args->coll_type) {
case MCA_COLL_UCG_TYPE_ALLTOALLV:
case MCA_COLL_UCG_TYPE_IALLTOALLV:
buf = (void *)args->alltoallv.scounts;
break;
case MCA_COLL_UCG_TYPE_SCATTERV:
case MCA_COLL_UCG_TYPE_ISCATTERV:
buf = (void *)args->scatterv.scounts;
break;
case MCA_COLL_UCG_TYPE_GATHERV:
case MCA_COLL_UCG_TYPE_IGATHERV:
buf = (void *)args->gatherv.rcounts;
break;
case MCA_COLL_UCG_TYPE_ALLGATHERV:
case MCA_COLL_UCG_TYPE_IALLGATHERV:
buf = (void *)args->allgatherv.rcounts;
break;
default:
break;
}
if (buf != NULL) {
mca_coll_ucg_subargs_t *data = container_of(buf, mca_coll_ucg_subargs_t, buf);
mca_coll_ucg_subargs_pool_put(data);
}
return;
}

void mca_coll_ucg_rcache_mark_cacheable(mca_coll_ucg_req_t *coll_req,
mca_coll_ucg_args_t *key)
{
OBJ_CONSTRUCT(&coll_req->list, opal_list_item_t);
coll_req->args = *key;
mca_coll_ucg_rcache_coll_req_args_init(&coll_req->args, key); // deep copy
ucg_coll_ucg_rcache_ref(coll_req);
coll_req->cacheable = true;
return;
Expand All @@ -247,6 +384,22 @@ int mca_coll_ucg_rcache_add(mca_coll_ucg_req_t *coll_req, mca_coll_ucg_args_t *k
return OMPI_SUCCESS;
}

static bool mca_coll_ucg_rcache_compare(int size, const int *array1, const int *array2)
{
if (array1 == NULL || array2 == NULL) {
return true;
}
if (array1 != array2) {
return false;
}
for (int i = 0; i < size; ++i) {
if (array1[i] != array2[i]) {
return false;
}
}
return true;
}

static bool mca_coll_ucg_rcache_is_same(const mca_coll_ucg_args_t *key1,
const mca_coll_ucg_args_t *key2)
{
Expand All @@ -258,6 +411,7 @@ static bool mca_coll_ucg_rcache_is_same(const mca_coll_ucg_args_t *key1,
return false;
}

uint32_t comm_size = (uint32_t)ompi_comm_size(key1->comm);
bool is_same = false;
switch (key1->coll_type) {
case MCA_COLL_UCG_TYPE_BCAST:
Expand Down Expand Up @@ -291,27 +445,29 @@ static bool mca_coll_ucg_rcache_is_same(const mca_coll_ucg_args_t *key1,
const mca_coll_alltoallv_args_t *args1 = &key1->alltoallv;
const mca_coll_alltoallv_args_t *args2 = &key2->alltoallv;
is_same = args1->sbuf == args2->sbuf &&
args1->scounts == args2->scounts &&
args1->sdispls == args2->sdispls &&
args1->sdtype == args2->sdtype &&
args1->rbuf == args2->rbuf &&
args1->rcounts == args2->rcounts &&
args1->rdispls == args2->rdispls &&
args1->rdtype == args2->rdtype;
is_same = is_same &&
mca_coll_ucg_rcache_compare(comm_size, args1->scounts, args2->scounts) &&
mca_coll_ucg_rcache_compare(comm_size, args1->sdispls, args2->sdispls) &&
mca_coll_ucg_rcache_compare(comm_size, args1->rcounts, args2->rcounts) &&
mca_coll_ucg_rcache_compare(comm_size, args1->rdispls, args2->rdispls);
break;
}
case MCA_COLL_UCG_TYPE_SCATTERV:
case MCA_COLL_UCG_TYPE_ISCATTERV: {
const mca_coll_scatterv_args_t *args1 = &key1->scatterv;
const mca_coll_scatterv_args_t *args2 = &key2->scatterv;
is_same = args1->sbuf == args2->sbuf &&
args1->scounts == args2->scounts &&
args1->disps == args2->disps &&
args1->sdtype == args2->sdtype &&
args1->rbuf == args2->rbuf &&
args1->rcount == args2->rcount &&
args1->rdtype == args2->rdtype &&
args1->root == args2->root;
is_same = is_same &&
mca_coll_ucg_rcache_compare(comm_size, args1->scounts, args2->scounts) &&
mca_coll_ucg_rcache_compare(comm_size, args1->disps, args2->disps);
break;
}
case MCA_COLL_UCG_TYPE_GATHERV:
Expand All @@ -322,10 +478,11 @@ static bool mca_coll_ucg_rcache_is_same(const mca_coll_ucg_args_t *key1,
args1->scount == args2->scount &&
args1->sdtype == args2->sdtype &&
args1->rbuf == args2->rbuf &&
args1->rcounts == args2->rcounts &&
args1->disps == args2->disps &&
args1->rdtype == args2->rdtype &&
args1->root == args2->root;
is_same = is_same &&
mca_coll_ucg_rcache_compare(comm_size, args1->rcounts, args2->rcounts) &&
mca_coll_ucg_rcache_compare(comm_size, args1->disps, args2->disps);
break;
}
case MCA_COLL_UCG_TYPE_ALLGATHERV:
Expand All @@ -336,9 +493,10 @@ static bool mca_coll_ucg_rcache_is_same(const mca_coll_ucg_args_t *key1,
args1->scount == args2->scount &&
args1->sdtype == args2->sdtype &&
args1->rbuf == args2->rbuf &&
args1->rcounts == args2->rcounts &&
args1->disps == args2->disps &&
args1->rdtype == args2->rdtype;
is_same = is_same &&
mca_coll_ucg_rcache_compare(comm_size, args1->rcounts, args2->rcounts) &&
mca_coll_ucg_rcache_compare(comm_size, args1->disps, args2->disps);
break;
}
default:
Expand Down Expand Up @@ -385,6 +543,7 @@ void mca_coll_ucg_rcache_del(mca_coll_ucg_req_t *coll_req)

coll_req->cacheable = false;
ucg_coll_ucg_rcache_deref(coll_req);
mca_coll_ucg_rcache_coll_req_args_uninit(&coll_req->args);
OBJ_DESTRUCT(&coll_req->list);

mca_coll_ucg_request_cleanup(coll_req);
Expand Down
Loading

0 comments on commit 1aa6802

Please sign in to comment.