From 1c77a86c04d76e63978e45971774351280485477 Mon Sep 17 00:00:00 2001 From: Artemy Kovalyov Date: Mon, 7 Oct 2024 14:11:20 +0000 Subject: [PATCH] BINDINGS/GO: Add RMA support - 2 --- bindings/go/tests/rma_test.go | 66 +++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/bindings/go/tests/rma_test.go b/bindings/go/tests/rma_test.go index d8ad3997cb6..502faf0434c 100644 --- a/bindings/go/tests/rma_test.go +++ b/bindings/go/tests/rma_test.go @@ -10,29 +10,30 @@ import ( ) func TestUcpRma(t *testing.T) { - const sendData string = "Hello GO" + const data string = "Hello GO" + const length uint64 = uint64(len(data)) for _, memType := range get_mem_types() { - sender := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag()) - receiver := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag()) + requestor := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag()) + responder := prepareContext(t, (&UcpParams{}).EnableRMA().EnableTag()) t.Logf("Testing RMA %v -> %v", memType.senderMemType, memType.recvMemType) ucpWorkerParams := (&UcpWorkerParams{}).SetThreadMode(UCS_THREAD_MODE_MULTI) - receiver.worker, _ = receiver.context.NewWorker(ucpWorkerParams) - sender.worker, _ = sender.context.NewWorker(ucpWorkerParams) - connect(sender, receiver) + requestor.worker, _ = requestor.context.NewWorker(ucpWorkerParams) + responder.worker, _ = responder.context.NewWorker(ucpWorkerParams) + connect(requestor, responder) - sendMem := memoryAllocate(sender, uint64(len(sendData)), memType.senderMemType) - memorySet(sender, []byte(sendData)) + localMem := memoryAllocate(requestor, length, memType.senderMemType) + memorySet(requestor, []byte(data)) - receiveMem := memoryAllocate(receiver, 4096, memType.recvMemType) + remoteMem := memoryAllocate(responder, 4096, memType.recvMemType) - rkeyBuf, _ := receiver.mem.Pack() - rkey, _ := sender.ep.Unpack(rkeyBuf) + rkeyBuf, _ := responder.mem.Pack() + rkey, _ := requestor.ep.Unpack(rkeyBuf) rkeyBuf.Close() - sendRequest, _ := sender.ep.RmaPut(sendMem, uint64(len(sendData)), uint64(uintptr(receiveMem)), rkey, &UcpRequestParams{ + putRequest, _ := requestor.ep.RmaPut(localMem, length, uint64(uintptr(remoteMem)), rkey, &UcpRequestParams{ Cb: func(request *UcpRequest, status UcsStatus) { if status != UCS_OK { t.Fatalf("Request failed with status: %d", status) @@ -41,23 +42,44 @@ func TestUcpRma(t *testing.T) { request.Close() }}) - for sendRequest.GetStatus() == UCS_INPROGRESS { - sender.worker.Progress() - receiver.worker.Progress() + for putRequest.GetStatus() == UCS_INPROGRESS { + requestor.worker.Progress() + responder.worker.Progress() } - if recvString := string(memoryGet(receiver)[:len(sendData)]); recvString != sendData { - t.Fatalf("Send data %s != recv data %s", sendData, recvString) + + if remoteData := string(memoryGet(responder)[:length]); remoteData != data { + t.Fatalf("Remote data %s != data %s", remoteData, data) + } + + memorySet(requestor, make([]byte, length)) + + getRequest, _ := requestor.ep.RmaGet(localMem, length, uint64(uintptr(remoteMem)), rkey, &UcpRequestParams{ + Cb: func(request *UcpRequest, status UcsStatus) { + if status != UCS_OK { + t.Fatalf("Request failed with status: %d", status) + } + + request.Close() + }}) + + for getRequest.GetStatus() == UCS_INPROGRESS { + requestor.worker.Progress() + responder.worker.Progress() + } + + if localData := string(memoryGet(responder)[:length]); localData != data { + t.Fatalf("Local data %s != data %s", localData, data) } - closeReq, _ := sender.ep.CloseNonBlockingFlush(nil) + closeReq, _ := requestor.ep.CloseNonBlockingFlush(nil) for closeReq.GetStatus() == UCS_INPROGRESS { - sender.worker.Progress() - receiver.worker.Progress() + requestor.worker.Progress() + responder.worker.Progress() } closeReq.Close() rkey.Close() - sender.Close() - receiver.Close() + requestor.Close() + responder.Close() } }