diff --git a/bindings/go/src/ucx/endpoint.go b/bindings/go/src/ucx/endpoint.go index 52f797481f5..74cacd8d525 100644 --- a/bindings/go/src/ucx/endpoint.go +++ b/bindings/go/src/ucx/endpoint.go @@ -19,20 +19,7 @@ type UcpEp struct { var errorHandles = make(map[C.ucp_ep_h]UcpEpErrHandler) func setSendParams(goRequestParams *UcpRequestParams, cRequestParams *C.ucp_request_param_t) uint64 { - var cbId uint64 - if goRequestParams != nil { - if goRequestParams.Cb != nil { - cbId = register(goRequestParams.Cb) - cRequestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA - cbAddr := (*C.ucp_send_nbx_callback_t)(unsafe.Pointer(&cRequestParams.cb[0])) - *cbAddr = (C.ucp_send_nbx_callback_t)(C.ucxgo_completeGoSendRequest) - cRequestParams.user_data = unsafe.Pointer(uintptr(cbId)) - } - - setMemType(goRequestParams, cRequestParams) - } - - return cbId + return packParams(goRequestParams, cRequestParams, unsafe.Pointer(C.ucxgo_completeGoSendRequest)) } // This routine flushes all outstanding AMO and RMA communications on the endpoint. diff --git a/bindings/go/src/ucx/request.go b/bindings/go/src/ucx/request.go index 38ae22c28cb..d5daf01e5f6 100644 --- a/bindings/go/src/ucx/request.go +++ b/bindings/go/src/ucx/request.go @@ -21,6 +21,8 @@ type UcpRequestParams struct { memTypeSet bool memType UcsMemoryType Cb UcpCallback + multi bool + Memory *UcpMemory } func (p *UcpRequestParams) SetMemType(memType UcsMemoryType) *UcpRequestParams { @@ -29,11 +31,14 @@ func (p *UcpRequestParams) SetMemType(memType UcsMemoryType) *UcpRequestParams { return p } -func setMemType(params *UcpRequestParams, p *C.ucp_request_param_t) { - if (params != nil) && params.memTypeSet { - p.op_attr_mask = C.UCP_OP_ATTR_FIELD_MEMORY_TYPE - p.memory_type = C.ucs_memory_type_t(params.memType) - } +func (p *UcpRequestParams) SetMulti() *UcpRequestParams { + p.multi = true + return p +} + +func (p *UcpRequestParams) SetMemory(m *UcpMemory) *UcpRequestParams { + p.Memory = m + return p } func (p *UcpRequestParams) SetCallback(cb UcpCallback) *UcpRequestParams { @@ -41,6 +46,37 @@ func (p *UcpRequestParams) SetCallback(cb UcpCallback) *UcpRequestParams { return p } +func packParams(params *UcpRequestParams, p *C.ucp_request_param_t, cb unsafe.Pointer) uint64 { + if params == nil { + return 0 + } + + var cbId uint64 + if params.Cb != nil { + cbId = register(params.Cb) + p.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA + cbAddr := (*unsafe.Pointer)(unsafe.Pointer(&p.cb[0])) + *cbAddr = cb + p.user_data = unsafe.Pointer(uintptr(cbId)) + } + + if params.memTypeSet { + p.op_attr_mask = C.UCP_OP_ATTR_FIELD_MEMORY_TYPE + p.memory_type = C.ucs_memory_type_t(params.memType) + } + + if params.multi { + p.op_attr_mask |= C.UCP_OP_ATTR_FLAG_MULTI_SEND + } + + if params.Memory != nil { + p.op_attr_mask |= C.UCP_OP_ATTR_FIELD_MEMH + p.memh = params.Memory.memHandle + } + + return cbId +} + // Checks whether request is a pointer func isRequestPtr(request C.ucs_status_ptr_t) bool { errLast := UCS_ERR_LAST diff --git a/bindings/go/src/ucx/worker.go b/bindings/go/src/ucx/worker.go index 91cea632be4..67d0b648f3a 100644 --- a/bindings/go/src/ucx/worker.go +++ b/bindings/go/src/ucx/worker.go @@ -227,24 +227,12 @@ func (w *UcpWorker) RecvTagNonBlocking(address unsafe.Pointer, size uint64, tag uint64, tagMask uint64, params *UcpRequestParams) (*UcpRequest, error) { var requestParams C.ucp_request_param_t var recvInfo C.ucp_tag_recv_info_t - var cbId uint64 requestParams.op_attr_mask = C.UCP_OP_ATTR_FIELD_RECV_INFO recvInfoPtr := (*C.ucp_tag_recv_info_t)(unsafe.Pointer(&requestParams.recv_info[0])) *recvInfoPtr = recvInfo - if params != nil { - setMemType(params, &requestParams) - - if params.Cb != nil { - cbId = register(params.Cb) - requestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA - cbAddr := (*C.ucp_tag_recv_nbx_callback_t)(unsafe.Pointer(&requestParams.cb[0])) - *cbAddr = (C.ucp_tag_recv_nbx_callback_t)(C.ucxgo_completeGoTagRecvRequest) - requestParams.user_data = unsafe.Pointer(uintptr(cbId)) - } - } - + cbId := packParams(params, &requestParams, unsafe.Pointer(C.ucxgo_completeGoTagRecvRequest)) request := C.ucp_tag_recv_nbx(w.worker, address, C.size_t(size), C.ucp_tag_t(tag), C.ucp_tag_t(tagMask), &requestParams) @@ -298,26 +286,13 @@ func (w *UcpWorker) SetAmRecvHandler(id uint, flags UcpAmCbFlags, cb UcpAmRecvCa func (w *UcpWorker) RecvAmDataNonBlocking(dataDesc *UcpAmData, recvBuffer unsafe.Pointer, size uint64, params *UcpRequestParams) (*UcpRequest, error) { var requestParams C.ucp_request_param_t - var cbId uint64 var length C.size_t requestParams.op_attr_mask = C.UCP_OP_ATTR_FIELD_RECV_INFO recvInfoPtr := (**C.size_t)(unsafe.Pointer(&requestParams.recv_info[0])) *recvInfoPtr = &length - if params != nil { - setMemType(params, &requestParams) - - if params.Cb != nil { - cbId = register(params.Cb) - requestParams.op_attr_mask |= C.UCP_OP_ATTR_FIELD_CALLBACK | C.UCP_OP_ATTR_FIELD_USER_DATA - cbAddr := (*C.ucp_am_recv_data_nbx_callback_t)(unsafe.Pointer(&requestParams.cb[0])) - *cbAddr = (C.ucp_am_recv_data_nbx_callback_t)(C.ucxgo_completeAmRecvData) - - requestParams.user_data = unsafe.Pointer(uintptr(cbId)) - } - } - + cbId := packParams(params, &requestParams, unsafe.Pointer(C.ucxgo_completeAmRecvData)) request := C.ucp_am_recv_data_nbx(w.worker, dataDesc.dataPtr, recvBuffer, C.size_t(size), &requestParams) return NewRequest(request, cbId, length)