Skip to content

Commit

Permalink
Simple tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 13, 2024
1 parent 1cd4504 commit 449f1f4
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 9 deletions.
12 changes: 4 additions & 8 deletions integration/xgboost/processor/README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
# Build Instruction

This plugin build requires xgboost source code, checkout xgboost source and build it with FEDERATED plugin,

cd xgboost
mkdir build
cd build
cmake .. -DPLUGIN_FEDERATED=ON
make

``` sh
cd NVFlare/integration/xgboost/processor
mkdir build
cd build
cmake ..
make
```

See [tests](./tests) for simple examples.
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ NVF_C char const *FederatedPluginErrorMsg() {

FederatedPluginHandle NVF_C FederatedPluginCreate(int argc, char const **argv) {
using namespace nvflare;
CHandleT pptr = new std::shared_ptr<TensealPlugin>;
try {
CHandleT pptr = new std::shared_ptr<TensealPlugin>;
std::vector<std::pair<std::string_view, std::string_view>> args;
std::transform(
argv, argv + argc, std::back_inserter(args), [](char const *carg) {
Expand Down
73 changes: 73 additions & 0 deletions integration/xgboost/processor/tests/test_tenseal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ctypes
import os
from contextlib import contextmanager

import numpy as np


def _check_call(rc: int) -> None:
assert rc == 0


plugin_path = os.path.join(
os.path.dirname(os.path.normpath(os.path.abspath(__file__))), os.pardir, "build", "libproc_nvflare.so"
)


@contextmanager
def load_plugin():
nvflare = ctypes.cdll.LoadLibrary(plugin_path)
nvflare.FederatedPluginCreate.restype = ctypes.c_void_p
handle = ctypes.c_void_p(nvflare.FederatedPluginCreate(ctypes.c_int(0), None))
try:
yield nvflare, handle
finally:
_check_call(nvflare.FederatedPluginClose(handle))


def test_load():
with load_plugin() as nvflare:
pass


def test_grad():
array = np.arange(16, dtype=np.float32)
out = ctypes.POINTER(ctypes.c_uint8)()
out_len = ctypes.c_size_t()

with load_plugin() as (nvflare, handle):
_check_call(
nvflare.FederatedPluginEncryptGPairs(
handle,
array.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
array.size,
ctypes.byref(out),
ctypes.byref(out_len),
)
)

out1 = ctypes.POINTER(ctypes.c_uint8)()
out_len1 = ctypes.c_size_t()

_check_call(
nvflare.FederatedPluginEncryptGPairs(
handle,
out,
out_len,
ctypes.byref(out1),
ctypes.byref(out_len1),
)
)

0 comments on commit 449f1f4

Please sign in to comment.