Skip to content

Commit

Permalink
ch4/ofi: refactor pipeline_info into a union
Browse files Browse the repository at this point in the history
Make the code cleaner to separate the pipeline_info type into a union of
send and recv.
  • Loading branch information
hzhou committed Feb 6, 2024
1 parent 1c8c6fa commit 4b78f19
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 46 deletions.
60 changes: 28 additions & 32 deletions src/mpid/ch4/netmod/ofi/ofi_gpu_pipeline.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ int MPIDI_OFI_gpu_pipeline_send(MPIR_Request * sreq, const void *send_buf,
MPIDI_OFI_idata_set_gpuchunk_bits(&cq_data, n_chunks);
MPIDI_OFI_idata_set_gpu_packed_bit(&cq_data, is_packed);

MPIDI_OFI_REQUEST(sreq, pipeline_info.cq_data) = cq_data;
MPIDI_OFI_REQUEST(sreq, pipeline_info.remote_addr) = remote_addr;
MPIDI_OFI_REQUEST(sreq, pipeline_info.vci_local) = vci_local;
MPIDI_OFI_REQUEST(sreq, pipeline_info.ctx_idx) = ctx_idx;
MPIDI_OFI_REQUEST(sreq, pipeline_info.match_bits) = match_bits;
MPIDI_OFI_REQUEST(sreq, pipeline_info.data_sz) = data_sz;
MPIDI_OFI_REQUEST(sreq, pipeline_info.send.cq_data) = cq_data;
MPIDI_OFI_REQUEST(sreq, pipeline_info.send.remote_addr) = remote_addr;
MPIDI_OFI_REQUEST(sreq, pipeline_info.send.vci_local) = vci_local;
MPIDI_OFI_REQUEST(sreq, pipeline_info.send.ctx_idx) = ctx_idx;
MPIDI_OFI_REQUEST(sreq, pipeline_info.send.match_bits) = match_bits;

/* Send the initial empty packet for matching */
MPIDI_OFI_CALL_RETRY(fi_tinjectdata(MPIDI_OFI_global.ctx[ctx_idx].tx, NULL, 0, cq_data,
Expand Down Expand Up @@ -188,7 +187,7 @@ static int send_copy_poll(MPIR_Async_thing * thing)
static void send_copy_complete(MPIR_Request * sreq, const void *buf, MPI_Aint chunk_sz)
{
int mpi_errno = MPI_SUCCESS;
int vci_local = MPIDI_OFI_REQUEST(sreq, pipeline_info.vci_local);
int vci_local = MPIDI_OFI_REQUEST(sreq, pipeline_info.send.vci_local);

struct chunk_req *chunk_req = MPL_malloc(sizeof(struct chunk_req), MPL_MEM_BUFFER);
MPIR_Assertp(chunk_req);
Expand All @@ -197,11 +196,11 @@ static void send_copy_complete(MPIR_Request * sreq, const void *buf, MPI_Aint ch
chunk_req->event_id = MPIDI_OFI_EVENT_SEND_GPU_PIPELINE;
chunk_req->buf = (void *) buf;

int ctx_idx = MPIDI_OFI_REQUEST(sreq, pipeline_info.ctx_idx);
fi_addr_t remote_addr = MPIDI_OFI_REQUEST(sreq, pipeline_info.remote_addr);
uint64_t cq_data = MPIDI_OFI_REQUEST(sreq, pipeline_info.cq_data);
uint64_t match_bits = MPIDI_OFI_REQUEST(sreq, pipeline_info.match_bits);
match_bits |= MPIDI_OFI_GPU_PIPELINE_SEND;
int ctx_idx = MPIDI_OFI_REQUEST(sreq, pipeline_info.send.ctx_idx);
fi_addr_t remote_addr = MPIDI_OFI_REQUEST(sreq, pipeline_info.send.remote_addr);
uint64_t cq_data = MPIDI_OFI_REQUEST(sreq, pipeline_info.send.cq_data);
uint64_t match_bits = MPIDI_OFI_REQUEST(sreq, pipeline_info.send.match_bits) |
MPIDI_OFI_GPU_PIPELINE_SEND;
MPID_THREAD_CS_ENTER(VCI, MPIDI_VCI(vci_local).lock);
MPIDI_OFI_CALL_RETRY(fi_tsenddata(MPIDI_OFI_global.ctx[ctx_idx].tx,
buf, chunk_sz, NULL /* desc */ ,
Expand Down Expand Up @@ -245,7 +244,6 @@ int MPIDI_OFI_gpu_pipeline_send_event(struct fi_cq_tagged_entry *wc, MPIR_Reques
struct recv_alloc {
MPIR_Request *rreq;
struct chunk_req *chunk_req;
int idx;
int n_chunks;
};

Expand All @@ -259,14 +257,14 @@ int MPIDI_OFI_gpu_pipeline_recv(MPIR_Request * rreq,
{
int mpi_errno = MPI_SUCCESS;

MPIDI_OFI_REQUEST(rreq, pipeline_info.offset) = 0;
MPIDI_OFI_REQUEST(rreq, pipeline_info.is_sync) = false;
MPIDI_OFI_REQUEST(rreq, pipeline_info.remote_addr) = remote_addr;
MPIDI_OFI_REQUEST(rreq, pipeline_info.vci_local) = vci_local;
MPIDI_OFI_REQUEST(rreq, pipeline_info.match_bits) = match_bits;
MPIDI_OFI_REQUEST(rreq, pipeline_info.mask_bits) = mask_bits;
MPIDI_OFI_REQUEST(rreq, pipeline_info.data_sz) = data_sz;
MPIDI_OFI_REQUEST(rreq, pipeline_info.ctx_idx) = ctx_idx;
/* The 1st recv is an empty chunk for matching. We need initialize rreq. */
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.offset) = 0;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.is_sync) = false;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.remote_addr) = remote_addr;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.vci_local) = vci_local;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.match_bits) = match_bits;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.mask_bits) = mask_bits;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.ctx_idx) = ctx_idx;

/* Save original buf, datatype and count */
MPIDI_OFI_REQUEST(rreq, noncontig.pack.buf) = recv_buf;
Expand All @@ -278,7 +276,6 @@ int MPIDI_OFI_gpu_pipeline_recv(MPIR_Request * rreq,
MPIR_Assert(p);

p->rreq = rreq;
p->idx = 0;
p->n_chunks = -1; /* it's MPIDI_OFI_EVENT_RECV_GPU_PIPELINE_INIT */

mpi_errno = MPIR_Async_things_add(recv_alloc_poll, p);
Expand All @@ -296,7 +293,6 @@ static int start_recv_chunk(MPIR_Request * rreq, int idx, int n_chunks)
MPIR_Assert(p);

p->rreq = rreq;
p->idx = idx;
p->n_chunks = n_chunks;

mpi_errno = MPIR_Async_things_add(recv_alloc_poll, p);
Expand All @@ -319,11 +315,11 @@ static int recv_alloc_poll(MPIR_Async_thing * thing)
return MPIR_ASYNC_THING_NOPROGRESS;
}

fi_addr_t remote_addr = MPIDI_OFI_REQUEST(rreq, pipeline_info.remote_addr);
int ctx_idx = MPIDI_OFI_REQUEST(rreq, pipeline_info.ctx_idx);
fi_addr_t remote_addr = MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.remote_addr);
int ctx_idx = MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.ctx_idx);
int vci = MPIDI_Request_get_vci(rreq);
uint64_t match_bits = MPIDI_OFI_REQUEST(rreq, pipeline_info.match_bits);
uint64_t mask_bits = MPIDI_OFI_REQUEST(rreq, pipeline_info.mask_bits);
uint64_t match_bits = MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.match_bits);
uint64_t mask_bits = MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.mask_bits);

struct chunk_req *chunk_req;
chunk_req = MPL_malloc(sizeof(*chunk_req), MPL_MEM_BUFFER);
Expand Down Expand Up @@ -380,7 +376,7 @@ int MPIDI_OFI_gpu_pipeline_recv_event(struct fi_cq_tagged_entry *wc, MPIR_Reques
rreq->status.MPI_TAG = MPIDI_OFI_init_get_tag(wc->tag);

if (unlikely(MPIDI_OFI_is_tag_sync(wc->tag))) {
MPIDI_OFI_REQUEST(rreq, pipeline_info.is_sync) = true;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.is_sync) = true;
}

uint32_t packed = MPIDI_OFI_idata_get_gpu_packed_bit(wc->data);
Expand Down Expand Up @@ -435,7 +431,7 @@ static int start_recv_copy(MPIR_Request * rreq, void *buf, MPI_Aint chunk_sz,
{
int mpi_errno = MPI_SUCCESS;

MPI_Aint offset = MPIDI_OFI_REQUEST(rreq, pipeline_info.offset);
MPI_Aint offset = MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.offset);
int engine_type = MPIR_CVAR_CH4_OFI_GPU_PIPELINE_H2D_ENGINE_TYPE;

/* FIXME: current design unpacks all bytes from host buffer, overflow check is missing. */
Expand All @@ -445,7 +441,7 @@ static int start_recv_copy(MPIR_Request * rreq, void *buf, MPI_Aint chunk_sz,
MPL_GPU_COPY_H2D, engine_type, 1, &async_req);
MPIR_ERR_CHECK(mpi_errno);

MPIDI_OFI_REQUEST(rreq, pipeline_info.offset) += chunk_sz;
MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.offset) += chunk_sz;

struct recv_copy *p;
p = MPL_malloc(sizeof(*p), MPL_MEM_OTHER);
Expand Down Expand Up @@ -486,7 +482,7 @@ static void recv_copy_complete(MPIR_Request * rreq, void *buf)
MPIR_cc_decr(rreq->cc_ptr, &c);
if (c == 0) {
/* all chunks arrived and copied */
if (unlikely(MPIDI_OFI_REQUEST(rreq, pipeline_info.is_sync))) {
if (unlikely(MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.is_sync))) {
MPIR_Comm *comm = rreq->comm;
uint64_t ss_bits =
MPIDI_OFI_init_sendtag(MPL_atomic_relaxed_load_int
Expand All @@ -513,7 +509,7 @@ static void recv_copy_complete(MPIR_Request * rreq, void *buf)

MPIR_Datatype_release_if_not_builtin(MPIDI_OFI_REQUEST(rreq, datatype));
/* Set number of bytes in status. */
MPIR_STATUS_SET_COUNT(rreq->status, MPIDI_OFI_REQUEST(rreq, pipeline_info.offset));
MPIR_STATUS_SET_COUNT(rreq->status, MPIDI_OFI_REQUEST(rreq, pipeline_info.recv.offset));

MPIR_Request_free(rreq);
}
Expand Down
31 changes: 17 additions & 14 deletions src/mpid/ch4/netmod/ofi/ofi_pre.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,20 +216,23 @@ typedef struct {
struct iovec iov;
void *inject_buf; /* Internal buffer for inject emulation */
} util;
struct {
fi_addr_t remote_addr;
int ctx_idx;
int vci_local;
int chunk_sz;
bool is_sync;
uint64_t cq_data;
uint64_t match_bits;
uint64_t mask_bits;
size_t offset;
size_t data_sz;
char *pack_recv_buf;
void *usm_host_buf; /* recv */
MPIR_Request *req;
union {
struct {
int vci_local;
int ctx_idx;
fi_addr_t remote_addr;
uint64_t cq_data;
uint64_t match_bits;
} send;
struct {
int vci_local;
int ctx_idx;
fi_addr_t remote_addr;
uint64_t match_bits;
uint64_t mask_bits;
MPI_Aint offset;
bool is_sync;
} recv;
} pipeline_info; /* GPU pipeline */
} MPIDI_OFI_request_t;

Expand Down

0 comments on commit 4b78f19

Please sign in to comment.