Skip to content

Commit

Permalink
UCS/UCP/RCACHE: Store memh completion in request
Browse files Browse the repository at this point in the history
  • Loading branch information
iyastreb committed Sep 27, 2024
1 parent 1b0cece commit 5183688
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 67 deletions.
3 changes: 2 additions & 1 deletion src/ucp/core/ucp_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -3568,7 +3568,8 @@ static void ucp_ep_req_purge_send(ucp_request_t *req, ucs_status_t status)
ucs_assertv(UCS_STATUS_IS_ERR(status), "req %p: status %s", req,
ucs_status_string(status));

if (ucp_request_memh_invalidate(req, status)) {
if (ucp_request_memh_check_invalidate(req)) {
ucp_request_memh_invalidate(req, status);
return;
}

Expand Down
10 changes: 4 additions & 6 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
}

void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
ucs_rcache_invalidate_comp_func_t cb, void *arg,
ucs_rcache_comp_entry_t *comp,
ucp_md_map_t inv_md_map)
{
ucs_trace("memh %p: invalidate address %p length %zu md_map %" PRIx64
Expand All @@ -393,7 +393,7 @@ void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
UCP_THREAD_CS_ENTER(&context->mt_lock);
memh->inv_md_map |= inv_md_map;
UCP_THREAD_CS_EXIT(&context->mt_lock);
ucs_rcache_region_invalidate(context->rcache, &memh->super, cb, arg);
ucs_rcache_region_invalidate(context->rcache, &memh->super, comp);
}

static void ucp_memh_put_rcache(ucp_context_h context, ucp_mem_h memh)
Expand Down Expand Up @@ -903,8 +903,7 @@ ucp_memh_find_slow(ucp_context_h context, void *address, size_t length,
uct_flags |= UCP_MM_UCT_ACCESS_FLAGS(memh->uct_flags);

/* Invalidate the mismatching region and get a new one */
ucs_rcache_region_invalidate(context->rcache, &memh->super,
ucs_empty_function, NULL);
ucs_rcache_region_invalidate(context->rcache, &memh->super, NULL);
ucp_memh_put(memh);
}
}
Expand Down Expand Up @@ -1930,8 +1929,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
"This may indicate that exported memory handle was "
"destroyed, but imported memory handle was not",
rregion->refcount);
ucs_rcache_region_invalidate(rcache, rregion,
ucs_empty_function, NULL);
ucs_rcache_region_invalidate(rcache, rregion, NULL);
ucs_rcache_region_put_unsafe(rcache, rregion);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ ucs_status_t ucp_memh_register(ucp_context_h context, ucp_mem_h memh,
const char *alloc_name);

void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
ucs_rcache_invalidate_comp_func_t cb, void *arg,
ucs_rcache_comp_entry_t *comp,
ucp_md_map_t inv_md_map);

void ucp_memh_put_slow(ucp_context_h context, ucp_mem_h memh);
Expand Down
51 changes: 32 additions & 19 deletions src/ucp/core/ucp_request.c
Original file line number Diff line number Diff line change
Expand Up @@ -395,50 +395,63 @@ static ucp_md_map_t ucp_request_get_invalidation_map(ucp_ep_h ep)
return inv_map;
}

int ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status)
static UCS_F_ALWAYS_INLINE ucp_mem_h* ucp_request_get_memh(ucp_request_t *req)
{
ucp_ep_h ep = req->send.ep;
ucp_err_handling_mode_t err_mode = ucp_ep_config(ep)->key.err_mode;
ucp_worker_h worker = ep->worker;
ucp_context_h context = worker->context;
ucp_mem_h *memh_p;
ucp_md_map_t invalidate_map;

if ((err_mode != UCP_ERR_HANDLING_MODE_PEER) ||
!(req->flags & UCP_REQUEST_FLAG_RKEY_INUSE)) {
return 0;
}
ucp_context_h context = req->send.ep->worker->context;

/* Get the contig memh from the request basing on the proto version */
if (context->config.ext.proto_enable) {
ucs_assertv(req->send.state.dt_iter.dt_class == UCP_DATATYPE_CONTIG,
"dt_class=%s",
ucp_datatype_class_names[req->send.state.dt_iter.dt_class]);
memh_p = &req->send.state.dt_iter.type.contig.memh;
return &req->send.state.dt_iter.type.contig.memh;
} else {
ucs_assertv(UCP_DT_IS_CONTIG(req->send.datatype), "datatype=0x%" PRIx64,
req->send.datatype);
memh_p = &req->send.state.dt.dt.contig.memh;
return &req->send.state.dt.dt.contig.memh;
}
}

int ucp_request_memh_check_invalidate(ucp_request_t *req)
{
ucp_ep_h ep = req->send.ep;
ucp_err_handling_mode_t err_mode = ucp_ep_config(ep)->key.err_mode;
ucp_mem_h *memh_p = ucp_request_get_memh(req);

if ((err_mode != UCP_ERR_HANDLING_MODE_PEER) ||
!(req->flags & UCP_REQUEST_FLAG_RKEY_INUSE)) {
return 0;
}

if ((*memh_p == NULL) || ucp_memh_is_user_memh(*memh_p)) {
return 0;
}

return 1;
}

void ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status)
{
ucp_ep_h ep = req->send.ep;
ucp_context_h context = ep->worker->context;
ucp_mem_h *memh_p = ucp_request_get_memh(req);
ucp_md_map_t invalidate_map;

ucs_assert(status != UCS_OK);

req->send.invalidate.worker = worker;
req->status = status;
req->send.invalidate.worker = ep->worker;
req->send.invalidate.comp.func = ucp_request_mem_invalidate_completion;
req->send.invalidate.comp.arg = req;
req->status = status;

invalidate_map = ucp_request_get_invalidation_map(ep);
ucp_trace_req(req, "mem invalidate buffer md_map 0x%" PRIx64 "/0x%" PRIx64,
invalidate_map, (*memh_p)->md_map);
ucp_memh_invalidate(context, *memh_p, ucp_request_mem_invalidate_completion,
req, invalidate_map);
ucp_memh_invalidate(context, *memh_p, &req->send.invalidate.comp,
invalidate_map);

ucp_memh_put(*memh_p);
*memh_p = NULL;
return 1;
}

UCS_PROFILE_FUNC(ucs_status_t, ucp_request_memory_reg,
Expand Down
18 changes: 13 additions & 5 deletions src/ucp/core/ucp_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ struct ucp_request {
} flush;

struct {
ucp_worker_h worker;
ucp_worker_h worker;
ucs_rcache_comp_entry_t comp;
} invalidate;

struct {
Expand Down Expand Up @@ -536,14 +537,21 @@ void ucp_request_memory_dereg(ucp_datatype_t datatype, ucp_dt_state_t *state,
ucp_request_t *req);

/**
* @brief Invalidates the request associated memh if required.
* @brief Detects whether request memh can be invalidated
*
* @param [in] req Request that contains memh
* @param [in] status Status of the error which caused abortion
*
* @return 1 if invalidation happened, 0 if invalidation isn't required/supported
* @return 1 if invalidation supported, 0 if invalidation isn't required/supported
*/
int ucp_request_memh_check_invalidate(ucp_request_t *req);

/**
* @brief Invalidates the request associated memh.
*
* @param [in] req Request that contains memh
* @param [in] status Status of the error which caused abortion
*/
int ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status);
void ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status);

ucs_status_t ucp_request_send_start(ucp_request_t *req, ssize_t max_short,
size_t zcopy_thresh, size_t zcopy_max,
Expand Down
3 changes: 2 additions & 1 deletion src/ucp/rndv/proto_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,9 @@ void ucp_proto_rndv_rts_abort(ucp_request_t *req, ucs_status_t status)
{
ucp_am_release_user_header(req);

if (ucp_request_memh_invalidate(req, status)) {
if (ucp_request_memh_check_invalidate(req)) {
ucp_proto_rndv_rts_reset(req);
ucp_request_memh_invalidate(req, status);
return;
}

Expand Down
3 changes: 2 additions & 1 deletion src/ucp/rndv/rndv_rtr.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ static void ucp_proto_rndv_rtr_abort(ucp_request_t *req, ucs_status_t status)
rreq->status = status;
ucp_request_set_callback(req, send.cb, ucp_proto_rndv_rtr_abort_super);

if (ucp_request_memh_invalidate(req, status)) {
if (ucp_request_memh_check_invalidate(req)) {
if (req->send.rndv.rkey != NULL) {
ucp_proto_rndv_rkey_destroy(req);
}
ucp_proto_request_zcopy_id_reset(req);
ucp_request_memh_invalidate(req, status);
return;
}

Expand Down
26 changes: 1 addition & 25 deletions src/ucs/memory/rcache.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,6 @@ typedef struct ucs_rcache_inv_entry {
} ucs_rcache_inv_entry_t;


typedef struct ucs_rcache_comp_entry {
ucs_list_link_t list;
ucs_rcache_invalidate_comp_func_t func;
void *arg;
} ucs_rcache_comp_entry_t;


typedef struct {
ucs_rcache_t *rcache;
ucs_rcache_region_t *region;
Expand Down Expand Up @@ -458,9 +451,6 @@ void ucs_mem_region_destroy_internal(ucs_rcache_t *rcache,
comp = ucs_list_extract_head(&region->comp_list,
ucs_rcache_comp_entry_t, list);
comp->func(comp->arg);
ucs_spin_lock(&rcache->lock);
ucs_mpool_put(comp);
ucs_spin_unlock(&rcache->lock);
}

ucs_free(region);
Expand Down Expand Up @@ -1122,24 +1112,12 @@ void ucs_rcache_region_put(ucs_rcache_t *rcache, ucs_rcache_region_t *region)

void ucs_rcache_region_invalidate(ucs_rcache_t *rcache,
ucs_rcache_region_t *region,
ucs_rcache_invalidate_comp_func_t cb,
void *arg)
ucs_rcache_comp_entry_t *comp)
{
ucs_rcache_comp_entry_t *comp;

/* Completion entry should be added before region is invalidated */
ucs_spin_lock(&rcache->lock);
comp = ucs_mpool_get(&rcache->mp);
ucs_spin_unlock(&rcache->lock);

pthread_rwlock_wrlock(&rcache->pgt_lock);
if (comp != NULL) {
comp->func = cb;
comp->arg = arg;
ucs_list_add_tail(&region->comp_list, &comp->list);
} else {
ucs_rcache_region_error(rcache, region,
"failed to allocate completion object");
}

/* coverity[double_lock] */
Expand Down Expand Up @@ -1313,8 +1291,6 @@ static UCS_CLASS_INIT_FUNC(ucs_rcache_t, const ucs_rcache_params_t *params,
}

mp_obj_size = ucs_max(sizeof(ucs_pgt_dir_t), sizeof(ucs_rcache_inv_entry_t));
mp_obj_size = ucs_max(mp_obj_size, sizeof(ucs_rcache_comp_entry_t));

mp_align = ucs_max(sizeof(void *), UCS_PGT_ENTRY_MIN_ALIGN);

ucs_mpool_params_reset(&mp_params);
Expand Down
17 changes: 11 additions & 6 deletions src/ucs/memory/rcache.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ extern ucs_config_field_t ucs_config_rcache_table[];
typedef void (*ucs_rcache_invalidate_comp_func_t)(void *arg);


typedef struct {
ucs_list_link_t list;
ucs_rcache_invalidate_comp_func_t func;
void *arg;
} ucs_rcache_comp_entry_t;


/*
* Registration cache operations.
*/
Expand Down Expand Up @@ -261,15 +268,13 @@ void ucs_rcache_region_put(ucs_rcache_t *rcache, ucs_rcache_region_t *region);
*
* @param [in] rcache Memory registration cache.
* @param [in] region Memory region to invalidate.
* @param [in] cb Completion callback, is called when region is
* released. Callback cannot do any operations which may
* access the rcache.
* @param [in] arg Completion argument passed to completion callback.
* @param [in] comp Completion entry, called when region is released.
* Callback cannot do any operations which may access the
* rcache.
*/
void ucs_rcache_region_invalidate(ucs_rcache_t *rcache,
ucs_rcache_region_t *region,
ucs_rcache_invalidate_comp_func_t cb,
void *arg);
ucs_rcache_comp_entry_t *comp);


/**
Expand Down
8 changes: 6 additions & 2 deletions test/gtest/ucs/test_rcache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ UCS_MT_TEST_F(test_rcache, put_and_invalidate, 1)
{
static const size_t size = 1 * UCS_MBYTE;
std::vector<region*> regions;
std::vector<std::unique_ptr<ucs_rcache_comp_entry_t>> comps;
region *reg;
void *ptr;
size_t region_get_count; /* how many get operation to apply */
Expand All @@ -340,8 +341,11 @@ UCS_MT_TEST_F(test_rcache, put_and_invalidate, 1)
ASSERT_EQ(0, m_comp_count);
region *region = regions.back();
if ((iter & 1) == 0) { /* on even iteration invalidate region */
ucs_rcache_region_invalidate(m_rcache, &region->super,
&completion_cb, this);
auto *comp = new ucs_rcache_comp_entry_t();
comp->func = completion_cb;
comp->arg = this;
comps.emplace_back(comp);
ucs_rcache_region_invalidate(m_rcache, &region->super, comp);
/* after invalidation region should not be acquired again */
reg = get(ptr, size);
EXPECT_NE(reg, region);
Expand Down

0 comments on commit 5183688

Please sign in to comment.