UCS/UCP/RCACHE: Store memh completion in request #10190

Open
wants to merge 3 commits into
base: master
Changes from 1 commit
3 changes: 2 additions & 1 deletion src/ucp/core/ucp_ep.c
@@ -3568,7 +3568,8 @@ static void ucp_ep_req_purge_send(ucp_request_t *req, ucs_status_t status)
     ucs_assertv(UCS_STATUS_IS_ERR(status), "req %p: status %s", req,
                 ucs_status_string(status));

-    if (ucp_request_memh_invalidate(req, status)) {
+    if (ucp_request_memh_check_invalidate(req)) {
+        ucp_request_memh_invalidate(req, status);
         return;
     }

10 changes: 4 additions & 6 deletions src/ucp/core/ucp_mm.c
@@ -379,7 +379,7 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
 }

 void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
-                         ucs_rcache_invalidate_comp_func_t cb, void *arg,
+                         ucs_rcache_comp_entry_t *comp,
                          ucp_md_map_t inv_md_map)
 {
     ucs_trace("memh %p: invalidate address %p length %zu md_map %" PRIx64
@@ -393,7 +393,7 @@ void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
     UCP_THREAD_CS_ENTER(&context->mt_lock);
     memh->inv_md_map |= inv_md_map;
     UCP_THREAD_CS_EXIT(&context->mt_lock);
-    ucs_rcache_region_invalidate(context->rcache, &memh->super, cb, arg);
+    ucs_rcache_region_invalidate(context->rcache, &memh->super, comp);
 }

 static void ucp_memh_put_rcache(ucp_context_h context, ucp_mem_h memh)
@@ -903,8 +903,7 @@ ucp_memh_find_slow(ucp_context_h context, void *address, size_t length,
             uct_flags |= UCP_MM_UCT_ACCESS_FLAGS(memh->uct_flags);

             /* Invalidate the mismatching region and get a new one */
-            ucs_rcache_region_invalidate(context->rcache, &memh->super,
-                                         ucs_empty_function, NULL);
+            ucs_rcache_region_invalidate(context->rcache, &memh->super, NULL);
             ucp_memh_put(memh);
         }
     }
@@ -1930,8 +1929,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
                      "This may indicate that exported memory handle was "
                      "destroyed, but imported memory handle was not",
                      rregion->refcount);
-            ucs_rcache_region_invalidate(rcache, rregion,
-                                         ucs_empty_function, NULL);
+            ucs_rcache_region_invalidate(rcache, rregion, NULL);
             ucs_rcache_region_put_unsafe(rcache, rregion);
         }
     }
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_mm.h
@@ -174,7 +174,7 @@ ucs_status_t ucp_memh_register(ucp_context_h context, ucp_mem_h memh,
                                const char *alloc_name);

 void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
-                         ucs_rcache_invalidate_comp_func_t cb, void *arg,
+                         ucs_rcache_comp_entry_t *comp,
                          ucp_md_map_t inv_md_map);

 void ucp_memh_put_slow(ucp_context_h context, ucp_mem_h memh);
51 changes: 32 additions & 19 deletions src/ucp/core/ucp_request.c
@@ -395,50 +395,63 @@ static ucp_md_map_t ucp_request_get_invalidation_map(ucp_ep_h ep)
     return inv_map;
 }

-int ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status)
+static UCS_F_ALWAYS_INLINE ucp_mem_h* ucp_request_get_memh(ucp_request_t *req)
Contributor

It seems there is no need to return a pointer here.

Suggested change
-static UCS_F_ALWAYS_INLINE ucp_mem_h* ucp_request_get_memh(ucp_request_t *req)
+static UCS_F_ALWAYS_INLINE ucp_mem_h ucp_request_get_memh(ucp_request_t *req)

Contributor Author

The pointer is actually used to assign NULL to the memh, on line 456:

    *memh_p = NULL;

Contributor

The pointer is not needed as the return value of this function.

Contributor Author

It is needed, please check again.
We assign NULL to the request field whose address is stored in memh_p.

Contributor

Ah, found it. But IMO it is not a good idea to set a value through a "getter". Maybe reset the field directly with req->send.state.dt_iter.type.contig.memh = NULL at line 448?

Contributor Author

Thanks for the hint. That would work if we always unconditionally replaced the original memh field with NULL.
But it's slightly more involved here, because I use the same function in ucp_request_memh_check_invalidate, which must not change the request, so it must be a pure get.
I'll think a bit more about how to make it less confusing.

Contributor Author

I refactored it in the latest PR, which contains the entire invalidation solution, using an extract approach:
03123f7

Contributor

👍 Does it mean that this PR is no longer relevant?
(Just cross-referencing PR #10204.)

Contributor Author

I was hoping to merge this small, low-risk PR into the current release, but no luck.
I will confirm with the tech leads on Monday whether it makes sense to split the overall solution into several pieces.

Contributor

Then maybe cherry-pick 03123f7 into this PR?
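
The getter-versus-setter concern in this thread can be sketched in isolation. The snippet below is a minimal illustration, not UCX code: `toy_memh_t`, `toy_request_t`, and the `toy_request_*` functions are hypothetical stand-ins. It shows one way to separate a read-only accessor, which a check function can call safely, from an extract operation that hands the handle to the caller and clears the request field. The refactoring in 03123f7 may differ in detail.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-ins for ucp_mem_h and ucp_request_t; not UCX types. */
typedef struct toy_memh {
    int refcount;
} toy_memh_t;

typedef struct toy_request {
    toy_memh_t *memh;
} toy_request_t;

/* Pure getter: reads the handle and never modifies the request. */
static toy_memh_t *toy_request_peek_memh(const toy_request_t *req)
{
    assert(req != NULL);
    return req->memh;
}

/* Extract: hands the handle to the caller and clears the request field,
 * so the mutation is visible in the function name. */
static toy_memh_t *toy_request_extract_memh(toy_request_t *req)
{
    toy_memh_t *memh;

    assert(req != NULL);
    memh      = req->memh;
    req->memh = NULL;
    return memh;
}

/* Check step: only inspects the request, so it can share the pure getter. */
static int toy_request_can_invalidate(const toy_request_t *req)
{
    return toy_request_peek_memh(req) != NULL;
}
```

With this split, the check path cannot modify the request by accident, and the only place that clears the field is the function whose name says so.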

 {
-    ucp_ep_h ep = req->send.ep;
-    ucp_err_handling_mode_t err_mode = ucp_ep_config(ep)->key.err_mode;
-    ucp_worker_h worker = ep->worker;
-    ucp_context_h context = worker->context;
-    ucp_mem_h *memh_p;
-    ucp_md_map_t invalidate_map;
-
-    if ((err_mode != UCP_ERR_HANDLING_MODE_PEER) ||
-        !(req->flags & UCP_REQUEST_FLAG_RKEY_INUSE)) {
-        return 0;
-    }
+    ucp_context_h context = req->send.ep->worker->context;

     /* Get the contig memh from the request basing on the proto version */
     if (context->config.ext.proto_enable) {
         ucs_assertv(req->send.state.dt_iter.dt_class == UCP_DATATYPE_CONTIG,
                     "dt_class=%s",
                     ucp_datatype_class_names[req->send.state.dt_iter.dt_class]);
-        memh_p = &req->send.state.dt_iter.type.contig.memh;
+        return &req->send.state.dt_iter.type.contig.memh;
     } else {
         ucs_assertv(UCP_DT_IS_CONTIG(req->send.datatype), "datatype=0x%" PRIx64,
                     req->send.datatype);
-        memh_p = &req->send.state.dt.dt.contig.memh;
+        return &req->send.state.dt.dt.contig.memh;
     }
+}
+
+int ucp_request_memh_check_invalidate(ucp_request_t *req)
+{
+    ucp_ep_h ep = req->send.ep;
+    ucp_err_handling_mode_t err_mode = ucp_ep_config(ep)->key.err_mode;
+    ucp_mem_h *memh_p = ucp_request_get_memh(req);
+
+    if ((err_mode != UCP_ERR_HANDLING_MODE_PEER) ||
+        !(req->flags & UCP_REQUEST_FLAG_RKEY_INUSE)) {
+        return 0;
+    }

     if ((*memh_p == NULL) || ucp_memh_is_user_memh(*memh_p)) {
         return 0;
     }

+    return 1;
+}
+
+void ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status)
+{
+    ucp_ep_h ep = req->send.ep;
+    ucp_context_h context = ep->worker->context;
+    ucp_mem_h *memh_p = ucp_request_get_memh(req);
+    ucp_md_map_t invalidate_map;
+
     ucs_assert(status != UCS_OK);

-    req->send.invalidate.worker = worker;
-    req->status = status;
+    req->send.invalidate.worker = ep->worker;
+    req->send.invalidate.comp.func = ucp_request_mem_invalidate_completion;
+    req->send.invalidate.comp.arg = req;
+    req->status = status;

     invalidate_map = ucp_request_get_invalidation_map(ep);
     ucp_trace_req(req, "mem invalidate buffer md_map 0x%" PRIx64 "/0x%" PRIx64,
                   invalidate_map, (*memh_p)->md_map);
-    ucp_memh_invalidate(context, *memh_p, ucp_request_mem_invalidate_completion,
-                        req, invalidate_map);
+    ucp_memh_invalidate(context, *memh_p, &req->send.invalidate.comp,
+                        invalidate_map);

     ucp_memh_put(*memh_p);
     *memh_p = NULL;
-    return 1;
 }

 UCS_PROFILE_FUNC(ucs_status_t, ucp_request_memory_reg,
18 changes: 13 additions & 5 deletions src/ucp/core/ucp_request.h
@@ -333,7 +333,8 @@ struct ucp_request {
                 } flush;

                 struct {
-                    ucp_worker_h worker;
+                    ucp_worker_h            worker;
+                    ucs_rcache_comp_entry_t comp;
                 } invalidate;

                 struct {
@@ -536,14 +537,21 @@ void ucp_request_memory_dereg(ucp_datatype_t datatype, ucp_dt_state_t *state,
                               ucp_request_t *req);

 /**
- * @brief Invalidates the request associated memh if required.
+ * @brief Detects whether request memh can be invalidated
  *
  * @param [in] req Request that contains memh
- * @param [in] status Status of the error which caused abortion
  *
- * @return 1 if invalidation happened, 0 if invalidation isn't required/supported
+ * @return 1 if invalidation supported, 0 if invalidation isn't required/supported
  */
+int ucp_request_memh_check_invalidate(ucp_request_t *req);
+
+/**
+ * @brief Invalidates the request associated memh.
+ *
+ * @param [in] req Request that contains memh
+ * @param [in] status Status of the error which caused abortion
+ */
-int ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status);
+void ucp_request_memh_invalidate(ucp_request_t *req, ucs_status_t status);

 ucs_status_t ucp_request_send_start(ucp_request_t *req, ssize_t max_short,
                                     size_t zcopy_thresh, size_t zcopy_max,
3 changes: 2 additions & 1 deletion src/ucp/rndv/proto_rndv.c
@@ -573,7 +573,8 @@ void ucp_proto_rndv_rts_abort(ucp_request_t *req, ucs_status_t status)
 {
     ucp_am_release_user_header(req);

-    if (ucp_request_memh_invalidate(req, status)) {
+    if (ucp_request_memh_check_invalidate(req)) {
+        ucp_request_memh_invalidate(req, status);
         ucp_proto_rndv_rts_reset(req);
         return;
     }
3 changes: 2 additions & 1 deletion src/ucp/rndv/rndv_rtr.c
@@ -222,11 +222,12 @@ static void ucp_proto_rndv_rtr_abort(ucp_request_t *req, ucs_status_t status)
     rreq->status = status;
     ucp_request_set_callback(req, send.cb, ucp_proto_rndv_rtr_abort_super);

-    if (ucp_request_memh_invalidate(req, status)) {
+    if (ucp_request_memh_check_invalidate(req)) {
         if (req->send.rndv.rkey != NULL) {
             ucp_proto_rndv_rkey_destroy(req);
         }
         ucp_proto_request_zcopy_id_reset(req);
+        ucp_request_memh_invalidate(req, status);
Contributor

Can we call ucp_request_memh_invalidate() right after the check on line 226, or might some req field be touched afterwards?

Contributor Author

I think we should put the ucp_request_memh_invalidate call just before ucp_proto_request_zcopy_id_reset.
But we cannot put it before ucp_proto_rndv_rkey_destroy, because that function uses a field from the rndv part of the request union, which is overwritten by the invalidate struct.
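
The ordering constraint in this reply follows from how union members share storage. The sketch below is a minimal illustration under assumed names: `toy_request_t` and its fields are hypothetical and do not reproduce the real `ucp_request_t` layout. It shows why anything that reads the `rndv` part has to run before the `invalidate` part is written.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical request: two per-state structs overlap in one union. */
typedef struct toy_request {
    union {
        struct {
            void *rkey;
        } rndv;
        struct {
            void *worker;
        } invalidate;
    } send;
} toy_request_t;

/* Must run while the rndv member is still the live one: it reads rndv.rkey
 * and returns it so the caller can release it. */
static void *toy_take_rkey(toy_request_t *req)
{
    void *rkey;

    assert(req != NULL);
    rkey                = req->send.rndv.rkey;
    req->send.rndv.rkey = NULL;
    return rkey;
}

/* Switches the request to the invalidate state. Writing this member reuses
 * the storage that rndv.rkey occupied. */
static void toy_start_invalidate(toy_request_t *req, void *worker)
{
    assert(req != NULL);
    req->send.invalidate.worker = worker;
}
```

Calling `toy_start_invalidate` first would leave `toy_take_rkey` reading the worker pointer instead of the rkey, which is the hazard the reply describes.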

         return;
     }

26 changes: 1 addition & 25 deletions src/ucs/memory/rcache.c
@@ -62,13 +62,6 @@ typedef struct ucs_rcache_inv_entry {
 } ucs_rcache_inv_entry_t;


-typedef struct ucs_rcache_comp_entry {
-    ucs_list_link_t list;
-    ucs_rcache_invalidate_comp_func_t func;
-    void *arg;
-} ucs_rcache_comp_entry_t;
-
-
 typedef struct {
     ucs_rcache_t *rcache;
     ucs_rcache_region_t *region;
@@ -458,9 +451,6 @@ void ucs_mem_region_destroy_internal(ucs_rcache_t *rcache,
         comp = ucs_list_extract_head(&region->comp_list,
                                      ucs_rcache_comp_entry_t, list);
         comp->func(comp->arg);
-        ucs_spin_lock(&rcache->lock);
-        ucs_mpool_put(comp);
-        ucs_spin_unlock(&rcache->lock);
     }

     ucs_free(region);
@@ -1122,24 +1112,12 @@ void ucs_rcache_region_put(ucs_rcache_t *rcache, ucs_rcache_region_t *region)

 void ucs_rcache_region_invalidate(ucs_rcache_t *rcache,
                                   ucs_rcache_region_t *region,
-                                  ucs_rcache_invalidate_comp_func_t cb,
-                                  void *arg)
+                                  ucs_rcache_comp_entry_t *comp)
 {
-    ucs_rcache_comp_entry_t *comp;
-
     /* Completion entry should be added before region is invalidated */
-    ucs_spin_lock(&rcache->lock);
-    comp = ucs_mpool_get(&rcache->mp);
-    ucs_spin_unlock(&rcache->lock);
-
     pthread_rwlock_wrlock(&rcache->pgt_lock);
     if (comp != NULL) {
-        comp->func = cb;
-        comp->arg = arg;
         ucs_list_add_tail(&region->comp_list, &comp->list);
-    } else {
-        ucs_rcache_region_error(rcache, region,
-                                "failed to allocate completion object");
     }

     /* coverity[double_lock] */
@@ -1313,8 +1291,6 @@ static UCS_CLASS_INIT_FUNC(ucs_rcache_t, const ucs_rcache_params_t *params,
     }

     mp_obj_size = ucs_max(sizeof(ucs_pgt_dir_t), sizeof(ucs_rcache_inv_entry_t));
-    mp_obj_size = ucs_max(mp_obj_size, sizeof(ucs_rcache_comp_entry_t));
-
     mp_align = ucs_max(sizeof(void *), UCS_PGT_ENTRY_MIN_ALIGN);

     ucs_mpool_params_reset(&mp_params);
17 changes: 11 additions & 6 deletions src/ucs/memory/rcache.h
@@ -75,6 +75,13 @@ extern ucs_config_field_t ucs_config_rcache_table[];
 typedef void (*ucs_rcache_invalidate_comp_func_t)(void *arg);


+typedef struct {
+    ucs_list_link_t list;
+    ucs_rcache_invalidate_comp_func_t func;
+    void *arg;
+} ucs_rcache_comp_entry_t;
+
+
 /*
  * Registration cache operations.
  */
@@ -261,15 +268,13 @@ void ucs_rcache_region_put(ucs_rcache_t *rcache, ucs_rcache_region_t *region);
  *
  * @param [in] rcache Memory registration cache.
  * @param [in] region Memory region to invalidate.
- * @param [in] cb Completion callback, is called when region is
- * released. Callback cannot do any operations which may
- * access the rcache.
- * @param [in] arg Completion argument passed to completion callback.
+ * @param [in] comp Completion entry, called when region is released.
+ * Callback cannot do any operations which may access the
+ * rcache.
  */
 void ucs_rcache_region_invalidate(ucs_rcache_t *rcache,
                                   ucs_rcache_region_t *region,
-                                  ucs_rcache_invalidate_comp_func_t cb,
-                                  void *arg);
+                                  ucs_rcache_comp_entry_t *comp);


 /**
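
The effect of this signature change can be sketched outside UCX. The snippet below is a simplified illustration under stated assumptions: the `toy_*` names are hypothetical, a plain `next` pointer stands in for `ucs_list_link_t`, and locking and the page table are not modeled. It shows a caller-owned completion entry being queued on a region without any allocation, with NULL meaning that no completion is requested.

```c
#include <assert.h>
#include <stddef.h>

typedef void (*toy_comp_func_t)(void *arg);

/* Hypothetical analogue of ucs_rcache_comp_entry_t; storage is owned by the
 * caller, for example embedded in a request. */
typedef struct toy_comp_entry {
    struct toy_comp_entry *next;
    toy_comp_func_t        func;
    void                  *arg;
} toy_comp_entry_t;

typedef struct toy_region {
    toy_comp_entry_t  *head;
    toy_comp_entry_t **tail_p;
} toy_region_t;

static void toy_region_init(toy_region_t *region)
{
    region->head   = NULL;
    region->tail_p = &region->head;
}

/* Queue a caller-owned entry. Nothing is allocated, so this step has no
 * out-of-memory failure path. */
static void toy_region_invalidate(toy_region_t *region, toy_comp_entry_t *comp)
{
    if (comp != NULL) {
        assert(comp->func != NULL);
        comp->next      = NULL;
        *region->tail_p = comp;
        region->tail_p  = &comp->next;
    }
}

/* Run queued completions in FIFO order. The next pointer is read before the
 * callback runs, because the callback may release the entry's storage. */
static void toy_region_destroy(toy_region_t *region)
{
    toy_comp_entry_t *comp;

    while ((comp = region->head) != NULL) {
        region->head = comp->next;
        comp->func(comp->arg);
    }
    region->tail_p = &region->head;
}

/* Example callback: counts how many completions fired. */
static void toy_count_cb(void *arg)
{
    ++*(int*)arg;
}
```

The trade-off in this sketch is that the caller must keep each entry alive until its callback has run.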
8 changes: 6 additions & 2 deletions test/gtest/ucs/test_rcache.cc
@@ -321,6 +321,7 @@ UCS_MT_TEST_F(test_rcache, put_and_invalidate, 1)
 {
     static const size_t size = 1 * UCS_MBYTE;
     std::vector<region*> regions;
+    std::vector<std::unique_ptr<ucs_rcache_comp_entry_t>> comps;
     region *reg;
     void *ptr;
     size_t region_get_count; /* how many get operation to apply */
@@ -340,8 +341,11 @@ UCS_MT_TEST_F(test_rcache, put_and_invalidate, 1)
         ASSERT_EQ(0, m_comp_count);
         region *region = regions.back();
         if ((iter & 1) == 0) { /* on even iteration invalidate region */
-            ucs_rcache_region_invalidate(m_rcache, &region->super,
-                                         &completion_cb, this);
+            auto *comp = new ucs_rcache_comp_entry_t();
+            comp->func = completion_cb;
+            comp->arg = this;
+            comps.emplace_back(comp);
+            ucs_rcache_region_invalidate(m_rcache, &region->super, comp);
             /* after invalidation region should not be acquired again */
             reg = get(ptr, size);
             EXPECT_NE(reg, region);