Skip to content

Commit

Permalink
Merge pull request #10199 from Artemy-Mellanox/topic/go-req-params
Browse files Browse the repository at this point in the history
BINDINGS/GO: Add multi-send flag and user memh support request params
  • Loading branch information
yosefe authored Oct 6, 2024
2 parents 63ce776 + bfe7f32 commit 6aacca7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 46 deletions.
15 changes: 1 addition & 14 deletions bindings/go/src/ucx/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,7 @@ type UcpEp struct {
var errorHandles = make(map[C.ucp_ep_h]UcpEpErrHandler)

func setSendParams(goRequestParams *UcpRequestParams, cRequestParams *C.ucp_request_param_t) uint64 {
var cbId uint64
if goRequestParams != nil {
if goRequestParams.Cb != nil {
cbId = register(goRequestParams.Cb)
cRequestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA
cbAddr := (*C.ucp_send_nbx_callback_t)(unsafe.Pointer(&cRequestParams.cb[0]))
*cbAddr = (C.ucp_send_nbx_callback_t)(C.ucxgo_completeGoSendRequest)
cRequestParams.user_data = unsafe.Pointer(uintptr(cbId))
}

setMemType(goRequestParams, cRequestParams)
}

return cbId
return packParams(goRequestParams, cRequestParams, unsafe.Pointer(C.ucxgo_completeGoSendRequest))
}

// This routine flushes all outstanding AMO and RMA communications on the endpoint.
Expand Down
46 changes: 41 additions & 5 deletions bindings/go/src/ucx/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type UcpRequestParams struct {
memTypeSet bool
memType UcsMemoryType
Cb UcpCallback
multi bool
Memory *UcpMemory
}

func (p *UcpRequestParams) SetMemType(memType UcsMemoryType) *UcpRequestParams {
Expand All @@ -29,18 +31,52 @@ func (p *UcpRequestParams) SetMemType(memType UcsMemoryType) *UcpRequestParams {
return p
}

func setMemType(params *UcpRequestParams, p *C.ucp_request_param_t) {
if (params != nil) && params.memTypeSet {
p.op_attr_mask = C.UCP_OP_ATTR_FIELD_MEMORY_TYPE
p.memory_type = C.ucs_memory_type_t(params.memType)
}
func (p *UcpRequestParams) SetMulti() *UcpRequestParams {
p.multi = true
return p
}

func (p *UcpRequestParams) SetMemory(m *UcpMemory) *UcpRequestParams {
p.Memory = m
return p
}

func (p *UcpRequestParams) SetCallback(cb UcpCallback) *UcpRequestParams {
p.Cb = cb
return p
}

func packParams(params *UcpRequestParams, p *C.ucp_request_param_t, cb unsafe.Pointer) uint64 {
if params == nil {
return 0
}

var cbId uint64
if params.Cb != nil {
cbId = register(params.Cb)
p.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA
cbAddr := (*unsafe.Pointer)(unsafe.Pointer(&p.cb[0]))
*cbAddr = cb
p.user_data = unsafe.Pointer(uintptr(cbId))
}

if params.memTypeSet {
p.op_attr_mask = C.UCP_OP_ATTR_FIELD_MEMORY_TYPE
p.memory_type = C.ucs_memory_type_t(params.memType)
}

if params.multi {
p.op_attr_mask |= C.UCP_OP_ATTR_FLAG_MULTI_SEND
}

if params.Memory != nil {
p.op_attr_mask |= C.UCP_OP_ATTR_FIELD_MEMH
p.memh = params.Memory.memHandle
}

return cbId
}

// Checks whether request is a pointer
func isRequestPtr(request C.ucs_status_ptr_t) bool {
errLast := UCS_ERR_LAST
Expand Down
29 changes: 2 additions & 27 deletions bindings/go/src/ucx/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,24 +227,12 @@ func (w *UcpWorker) RecvTagNonBlocking(address unsafe.Pointer, size uint64,
tag uint64, tagMask uint64, params *UcpRequestParams) (*UcpRequest, error) {
var requestParams C.ucp_request_param_t
var recvInfo C.ucp_tag_recv_info_t
var cbId uint64

requestParams.op_attr_mask = C.UCP_OP_ATTR_FIELD_RECV_INFO
recvInfoPtr := (*C.ucp_tag_recv_info_t)(unsafe.Pointer(&requestParams.recv_info[0]))
*recvInfoPtr = recvInfo

if params != nil {
setMemType(params, &requestParams)

if params.Cb != nil {
cbId = register(params.Cb)
requestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA
cbAddr := (*C.ucp_tag_recv_nbx_callback_t)(unsafe.Pointer(&requestParams.cb[0]))
*cbAddr = (C.ucp_tag_recv_nbx_callback_t)(C.ucxgo_completeGoTagRecvRequest)
requestParams.user_data = unsafe.Pointer(uintptr(cbId))
}
}

cbId := packParams(params, &requestParams, unsafe.Pointer(C.ucxgo_completeGoTagRecvRequest))
request := C.ucp_tag_recv_nbx(w.worker, address, C.size_t(size), C.ucp_tag_t(tag),
C.ucp_tag_t(tagMask), &requestParams)

Expand Down Expand Up @@ -298,26 +286,13 @@ func (w *UcpWorker) SetAmRecvHandler(id uint, flags UcpAmCbFlags, cb UcpAmRecvCa
func (w *UcpWorker) RecvAmDataNonBlocking(dataDesc *UcpAmData, recvBuffer unsafe.Pointer, size uint64,
params *UcpRequestParams) (*UcpRequest, error) {
var requestParams C.ucp_request_param_t
var cbId uint64
var length C.size_t

requestParams.op_attr_mask = C.UCP_OP_ATTR_FIELD_RECV_INFO
recvInfoPtr := (**C.size_t)(unsafe.Pointer(&requestParams.recv_info[0]))
*recvInfoPtr = &length

if params != nil {
setMemType(params, &requestParams)

if params.Cb != nil {
cbId = register(params.Cb)
requestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA
cbAddr := (*C.ucp_am_recv_data_nbx_callback_t)(unsafe.Pointer(&requestParams.cb[0]))
*cbAddr = (C.ucp_am_recv_data_nbx_callback_t)(C.ucxgo_completeAmRecvData)

requestParams.user_data = unsafe.Pointer(uintptr(cbId))
}
}

cbId := packParams(params, &requestParams, unsafe.Pointer(C.ucxgo_completeAmRecvData))
request := C.ucp_am_recv_data_nbx(w.worker, dataDesc.dataPtr, recvBuffer, C.size_t(size), &requestParams)

return NewRequest(request, cbId, length)
Expand Down

0 comments on commit 6aacca7

Please sign in to comment.