Skip to content

Commit

Permalink
BINDINGS/GO: Add RMA support
Browse files Browse the repository at this point in the history
  • Loading branch information
Artemy-Mellanox committed Oct 7, 2024
1 parent b1a268b commit e34c38c
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
78 changes: 78 additions & 0 deletions bindings/go/src/ucx/rma.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (C) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/

package ucx

// #include <ucp/api/ucp.h>
import "C"
import "unsafe"

type UcpRKey struct {
rkey C.ucp_rkey_h
}

type UcpRKeyBuffer struct {
buffer unsafe.Pointer
size C.size_t
}

func NewRKeyBuffer(buffer []byte) *UcpRKeyBuffer {
return &UcpRKeyBuffer{
buffer: unsafe.Pointer(&buffer[0]),
size: C.size_t(len(buffer)),
}
}

func (m *UcpMemory) Pack() (*UcpRKeyBuffer, error) {
result := &UcpRKeyBuffer{}

if status := C.ucp_rkey_pack(m.context, m.memHandle, &result.buffer, &result.size); status != C.UCS_OK {
return nil, newUcxError(status)
}

return result, nil
}

func (b *UcpRKeyBuffer) Bytes() []byte {
return unsafe.Slice((*byte)(b.buffer), b.size)
}

func (b *UcpRKeyBuffer) Close() {
var releaseParam C.ucp_memh_buffer_release_params_t
C.ucp_memh_buffer_release(b.buffer, &releaseParam)
}

func (e *UcpEp) Unpack(buffer *UcpRKeyBuffer) (*UcpRKey, error) {
result := &UcpRKey{}
if status := C.ucp_ep_rkey_unpack(e.ep, buffer.buffer, &result.rkey); status != C.UCS_OK {
return nil, newUcxError(status)
}

return result, nil
}

func (r *UcpRKey) Close() {
C.ucp_rkey_destroy(r.rkey)
}

func (e *UcpEp) RmaPut(buffer unsafe.Pointer, size uint64, remote_addr uint64, rkey *UcpRKey, params *UcpRequestParams) (*UcpRequest, error) {
var requestParams C.ucp_request_param_t

cbId := setSendParams(params, &requestParams)

request := C.ucp_put_nbx(e.ep, buffer, C.size_t(size), C.uint64_t(remote_addr), rkey.rkey, &requestParams)

return NewRequest(request, cbId, nil)
}

func (e *UcpEp) RmaGet(buffer unsafe.Pointer, size uint64, remote_addr uint64, rkey *UcpRKey, params *UcpRequestParams) (*UcpRequest, error) {
var requestParams C.ucp_request_param_t

cbId := setSendParams(params, &requestParams)

request := C.ucp_get_nbx(e.ep, buffer, C.size_t(size), C.uint64_t(remote_addr), rkey.rkey, &requestParams)

return NewRequest(request, cbId, nil)
}
63 changes: 63 additions & 0 deletions bindings/go/tests/rma_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (C) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package goucxtests

import (
"testing"
. "ucx"
)

func TestUcpRma(t *testing.T) {
const sendData string = "Hello GO"

for _, memType := range get_mem_types() {
sender := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag())
receiver := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag())
t.Logf("Testing RMA %v -> %v", memType.senderMemType, memType.recvMemType)

ucpWorkerParams := (&UcpWorkerParams{}).SetThreadMode(UCS_THREAD_MODE_MULTI)

receiver.worker, _ = receiver.context.NewWorker(ucpWorkerParams)
sender.worker, _ = sender.context.NewWorker(ucpWorkerParams)
connect(sender, receiver)

sendMem := memoryAllocate(sender, uint64(len(sendData)), memType.senderMemType)
memorySet(sender, []byte(sendData))

receiveMem := memoryAllocate(receiver, 4096, memType.recvMemType)

rkeyBuf, _ := receiver.mem.Pack()
rkey, _ := sender.ep.Unpack(rkeyBuf)
rkeyBuf.Close()

sendRequest, _ := sender.ep.RmaPut(sendMem, uint64(len(sendData)), uint64(uintptr(receiveMem)), rkey, &UcpRequestParams{
Cb: func(request *UcpRequest, status UcsStatus) {
if status != UCS_OK {
t.Fatalf("Request failed with status: %d", status)
}

request.Close()
}})

for sendRequest.GetStatus() == UCS_INPROGRESS {
sender.worker.Progress()
receiver.worker.Progress()
}
if recvString := string(memoryGet(receiver)[:len(sendData)]); recvString != sendData {
t.Fatalf("Send data %s != recv data %s", sendData, recvString)
}

closeReq, _ := sender.ep.CloseNonBlockingFlush(nil)
for closeReq.GetStatus() == UCS_INPROGRESS {
sender.worker.Progress()
receiver.worker.Progress()
}
closeReq.Close()
rkey.Close()

sender.Close()
receiver.Close()
}
}

0 comments on commit e34c38c

Please sign in to comment.