From dce2c3cf301194f337b46d4cfef3ee97b9c128b2 Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 7 Nov 2024 11:10:05 -0600 Subject: [PATCH] ch4: use am_tag_{send,recv} in MPIDIG get When the target sends the reply data back to the origin of a get, use am_tag_send if available. --- src/mpid/ch4/include/mpidpre.h | 1 + src/mpid/ch4/src/ch4_types.h | 1 + src/mpid/ch4/src/mpidig.h | 1 + src/mpid/ch4/src/mpidig_init.c | 1 + src/mpid/ch4/src/mpidig_rma.h | 17 +++++-- src/mpid/ch4/src/mpidig_rma_callbacks.c | 64 +++++++++++++++---------- src/mpid/ch4/src/mpidig_rma_callbacks.h | 1 + 7 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/mpid/ch4/include/mpidpre.h b/src/mpid/ch4/include/mpidpre.h index ea2ceb78eaa..0281a5c2d72 100644 --- a/src/mpid/ch4/include/mpidpre.h +++ b/src/mpid/ch4/include/mpidpre.h @@ -118,6 +118,7 @@ typedef struct MPIDIG_put_req_t { typedef struct MPIDIG_get_req_t { MPIR_Request *greq_ptr; void *flattened_dt; + int am_tag; } MPIDIG_get_req_t; typedef struct MPIDIG_cswap_req_t { diff --git a/src/mpid/ch4/src/ch4_types.h b/src/mpid/ch4/src/ch4_types.h index 7b2b4f0afc6..910e6ebaf4c 100644 --- a/src/mpid/ch4/src/ch4_types.h +++ b/src/mpid/ch4/src/ch4_types.h @@ -142,6 +142,7 @@ typedef struct MPIDIG_get_msg_t { MPI_Aint target_datatype; MPI_Aint target_true_lb; int flattened_sz; + int am_tag; } MPIDIG_get_msg_t; typedef struct MPIDIG_get_ack_msg_t { diff --git a/src/mpid/ch4/src/mpidig.h b/src/mpid/ch4/src/mpidig.h index 6f292064ee7..3b7d934467b 100644 --- a/src/mpid/ch4/src/mpidig.h +++ b/src/mpid/ch4/src/mpidig.h @@ -77,6 +77,7 @@ enum { enum { MPIDIG_TAG_RECV_COMPLETE = 0, + MPIDIG_TAG_GET_COMPLETE, MPIDIG_TAG_RECV_STATIC_MAX }; diff --git a/src/mpid/ch4/src/mpidig_init.c b/src/mpid/ch4/src/mpidig_init.c index 83683d63e4b..c09a0cd8c3e 100644 --- a/src/mpid/ch4/src/mpidig_init.c +++ b/src/mpid/ch4/src/mpidig_init.c @@ -158,6 +158,7 @@ int MPIDIG_am_init(void) MPIDIG_am_rndv_reg_cb(MPIDIG_RNDV_GENERIC, &MPIDIG_do_cts); 
MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_RECV_COMPLETE, &MPIDIG_tag_recv_complete); + MPIDIG_am_tag_recv_reg_cb(MPIDIG_TAG_GET_COMPLETE, &MPIDIG_tag_get_complete); MPIDIG_am_comm_abort_init(); diff --git a/src/mpid/ch4/src/mpidig_rma.h b/src/mpid/ch4/src/mpidig_rma.h index 63545a456d5..62da94f0141 100644 --- a/src/mpid/ch4/src/mpidig_rma.h +++ b/src/mpid/ch4/src/mpidig_rma.h @@ -220,6 +220,18 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co * counter in request, thus it can be decreased at request completion. */ MPIDIG_win_cmpl_cnts_incr(win, target_rank, &sreq->dev.completion_notification); + bool is_local; + is_local = MPIDI_rank_is_local(target_rank, win->comm_ptr); + if (MPIDIG_can_do_tag(is_local)) { + am_hdr.am_tag = MPIDIG_get_next_am_tag(win->comm_ptr); + CH4_CALL(am_tag_recv(target_rank, win->comm_ptr, MPIDIG_TAG_GET_COMPLETE, am_hdr.am_tag, + origin_addr, origin_count, origin_datatype, vci_target, vci, sreq), + is_local, mpi_errno); + MPIR_ERR_CHECK(mpi_errno); + } else { + am_hdr.am_tag = -1; + } + int is_contig; MPIR_Datatype_is_contig(target_datatype, &is_contig); if (MPIR_DATATYPE_IS_PREDEFINED(target_datatype) || is_contig) { @@ -228,8 +240,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co MPIR_T_PVAR_TIMER_END(RMA, rma_amhdr_set); CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr), - NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), - MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno); + NULL, 0, MPI_DATATYPE_NULL, vci, vci_target, sreq), is_local, mpi_errno); MPIR_ERR_CHECK(mpi_errno); goto fn_exit; } @@ -242,7 +253,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_get(void *origin_addr, MPI_Aint origin_co CH4_CALL(am_isend(target_rank, win->comm_ptr, MPIDIG_GET_REQ, &am_hdr, sizeof(am_hdr), flattened_dt, flattened_sz, MPI_BYTE, vci, vci_target, sreq), - MPIDI_rank_is_local(target_rank, win->comm_ptr), mpi_errno); + is_local, mpi_errno); 
MPIR_ERR_CHECK(mpi_errno); fn_exit: diff --git a/src/mpid/ch4/src/mpidig_rma_callbacks.c b/src/mpid/ch4/src/mpidig_rma_callbacks.c index 5b7014a107f..a280bd02c6a 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.c +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.c @@ -896,42 +896,44 @@ static int get_target_cmpl_cb(MPIR_Request * rreq) get_ack.greq_ptr = MPIDIG_REQUEST(rreq, req->greq.greq_ptr); win = rreq->u.rma.win; - int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); - int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); - if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt) == NULL) { + if (MPIDIG_REQUEST(rreq, req->greq.flattened_dt)) { + /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ + MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); + if (!dt) { + MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", + "MPIR_Datatype_mem"); + } + MPIR_Object_set_ref(dt, 1); + MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt)); + MPIDIG_REQUEST(rreq, datatype) = dt->handle; + /* count is still target_data_sz now, use it for reply */ + get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count); + MPIDIG_REQUEST(rreq, count) /= dt->size; + } else { MPIDI_Datatype_check_size(MPIDIG_REQUEST(rreq, datatype), MPIDIG_REQUEST(rreq, count), get_ack.target_data_sz); + } + + int local_vci = MPIDIG_REQUEST(rreq, req->local_vci); + int remote_vci = MPIDIG_REQUEST(rreq, req->remote_vci); + if (MPIDIG_REQUEST(rreq, req->greq.am_tag) >= 0) { + int src_rank = MPIDIG_REQUEST(rreq, u.target.origin_rank); + CH4_CALL(am_tag_send(src_rank, win->comm_ptr, MPIDIG_GET_ACK, + MPIDIG_REQUEST(rreq, req->greq.am_tag), + MPIDIG_REQUEST(rreq, buffer), + MPIDIG_REQUEST(rreq, count), + MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq), + MPIDI_REQUEST(rreq, is_local), mpi_errno); + } else { CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), MPIDIG_GET_ACK, &get_ack, sizeof(get_ack), 
MPIDIG_REQUEST(rreq, buffer), MPIDIG_REQUEST(rreq, count), MPIDIG_REQUEST(rreq, datatype), local_vci, remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno); - MPID_Request_complete(rreq); - MPIR_ERR_CHECK(mpi_errno); - goto fn_exit; - } - - /* FIXME: MPIR_Typerep_unflatten should allocate the new object */ - MPIR_Datatype *dt = (MPIR_Datatype *) MPIR_Handle_obj_alloc(&MPIR_Datatype_mem); - if (!dt) { - MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_OTHER, "**nomem", "**nomem %s", - "MPIR_Datatype_mem"); } - MPIR_Object_set_ref(dt, 1); - MPIR_Typerep_unflatten(dt, MPIDIG_REQUEST(rreq, req->greq.flattened_dt)); - MPIDIG_REQUEST(rreq, datatype) = dt->handle; - /* count is still target_data_sz now, use it for reply */ - get_ack.target_data_sz = MPIDIG_REQUEST(rreq, count); - MPIDIG_REQUEST(rreq, count) /= dt->size; - - CH4_CALL(am_isend_reply(win->comm_ptr, MPIDIG_REQUEST(rreq, u.target.origin_rank), - MPIDIG_GET_ACK, &get_ack, sizeof(get_ack), - MPIDIG_REQUEST(rreq, buffer), - MPIDIG_REQUEST(rreq, count), dt->handle, local_vci, - remote_vci, rreq), MPIDI_REQUEST(rreq, is_local), mpi_errno); MPID_Request_complete(rreq); - MPIR_ERR_CHECK(mpi_errno); + fn_exit: MPIR_FUNC_EXIT; return mpi_errno; @@ -2104,6 +2106,7 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, MPIDIG_REQUEST(rreq, req->greq.flattened_dt) = NULL; MPIDIG_REQUEST(rreq, req->greq.greq_ptr) = msg_hdr->greq_ptr; MPIDIG_REQUEST(rreq, u.target.origin_rank) = msg_hdr->src_rank; + MPIDIG_REQUEST(rreq, req->greq.am_tag) = msg_hdr->am_tag; if (msg_hdr->flattened_sz) { void *flattened_dt = MPL_malloc(msg_hdr->flattened_sz, MPL_MEM_BUFFER); @@ -2164,3 +2167,12 @@ int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, MPIR_FUNC_EXIT; return mpi_errno; } + +int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status) +{ + int mpi_errno = MPI_SUCCESS; + + mpi_errno = get_ack_target_cmpl_cb(req); + + return mpi_errno; +} diff --git 
a/src/mpid/ch4/src/mpidig_rma_callbacks.h b/src/mpid/ch4/src/mpidig_rma_callbacks.h index 3fe440cdc7d..ac8d7c6b887 100644 --- a/src/mpid/ch4/src/mpidig_rma_callbacks.h +++ b/src/mpid/ch4/src/mpidig_rma_callbacks.h @@ -112,5 +112,6 @@ int MPIDIG_get_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); int MPIDIG_get_ack_target_msg_cb(void *am_hdr, void *data, MPI_Aint in_data_sz, uint32_t attr, MPIR_Request ** req); +int MPIDIG_tag_get_complete(MPIR_Request * req, MPI_Status * status); #endif /* MPIDIG_RMA_CALLBACKS_H_INCLUDED */