Skip to content

Commit

Permalink
Revert "Revert "Add intel phe (#2612)""
Browse files Browse the repository at this point in the history
This reverts commit 1cd4504.
  • Loading branch information
trivialfis committed Jun 14, 2024
1 parent ec13f82 commit b16dffb
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 184 deletions.
2 changes: 1 addition & 1 deletion nvflare/app_opt/xgboost/histogram_based_v2/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Constant:

TASK_CHECK_INTERVAL = 0.5
JOB_STATUS_CHECK_INTERVAL = 2.0
MAX_CLIENT_OP_INTERVAL = 90.0
MAX_CLIENT_OP_INTERVAL = 600.0
WORKFLOW_PROGRESS_TIMEOUT = 3600.0

# message topics
Expand Down
76 changes: 0 additions & 76 deletions nvflare/app_opt/xgboost/histogram_based_v2/mock_he/util.py

This file was deleted.

12 changes: 6 additions & 6 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.adder import Adder
from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.decrypter import Decrypter
from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.encryptor import Encryptor
from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.util import (
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.adder import Adder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.decrypter import Decrypter
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.encryptor import Encryptor
from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.util import (
combine,
decode_encrypted_data,
decode_feature_aggregations,
Expand All @@ -32,8 +34,6 @@
generate_keys,
split,
)
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import (
DATA_SET_HISTOGRAMS,
ProcessorDataConverter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class Adder:
def __init__(self, max_workers=10):
self.exe = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
self.num_workers = max_workers

def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True):
"""
Expand Down Expand Up @@ -50,7 +51,9 @@ def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True):
gid, sample_id_list = g
items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list))

results = self.exe.map(_do_add, items)
chunk_size = int((len(items) - 1) / self.num_workers) + 1

results = self.exe.map(_do_add, items, chunksize=chunk_size)
rl = []
for r in results:
rl.append(r)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,27 @@ def decrypt(self, encrypted_number_groups):
Returns: list of lists of decrypted numbers
"""
# print(f"decrypting {len(encrypted_number_groups)} number groups")
items = []

for g in encrypted_number_groups:
items.append(
(
self.private_key,
g,
)
)

results = self.exe.map(_do_decrypt, items)
items = [None] * len(encrypted_number_groups)

for i, g in enumerate(encrypted_number_groups):
items[i] = (self.private_key, g)

chunk_size = int((len(items) - 1) / self.max_workers) + 1

results = self.exe.map(_do_decrypt, items, chunksize=chunk_size)
rl = []
for r in results:
rl.append(r)
return rl


def _do_decrypt(item):
# t = time.time()
private_key, numbers = item
return numbers
ev = [None] * len(numbers)
for i, v in enumerate(numbers):
if isinstance(v, int):
d = v
else:
d = private_key.decrypt(v)
ev[i] = d
return ev
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures


Expand All @@ -23,52 +24,22 @@ def __init__(self, pubkey, max_workers=10):
def encrypt(self, numbers):
"""
Encrypt a list of clear text numbers
Args:
numbers: clear text numbers to be encrypted
Returns: list of encrypted numbers
"""
items = [(self.pubkey, numbers[i]) for i in range(len(numbers))]
chunk_size = int(len(items) / self.max_workers)
if chunk_size == 0:
chunk_size = 1

num_values = len(numbers)
if num_values <= self.max_workers:
w_values = [numbers]
workers_needed = 1
else:
workers_needed = self.max_workers
w_values = [None for _ in range(self.max_workers)]
n = int(num_values / self.max_workers)
w_load = [n for _ in range(self.max_workers)]
r = num_values % self.max_workers
if r > 0:
for i in range(r):
w_load[i] += 1

start = 0
for i in range(self.max_workers):
end = start + w_load[i]
w_values[i] = numbers[start:end]
start = end

total_count = 0
for v in w_values:
total_count += len(v)
assert total_count == num_values

items = []
for i in range(workers_needed):
items.append((self.pubkey, w_values[i]))
return self._encrypt(items)

def _encrypt(self, items):
results = self.exe.map(_do_enc, items)
results = self.exe.map(_do_enc, items, chunksize=chunk_size)
rl = []
for r in results:
rl.extend(r)
rl.append(r)
return rl


def _do_enc(item):
pubkey, numbers = item
return numbers
pubkey, num = item
return pubkey.encrypt(num)
147 changes: 147 additions & 0 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/partial_he/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from base64 import urlsafe_b64decode, urlsafe_b64encode
from binascii import hexlify, unhexlify

import ipcl_python
from ipcl_python import PaillierEncryptedNumber as EncryptedNumber

# Multiplier applied to the first value when combine() packs two numbers into
# a single one; split() uses the same factor to take them apart again.
SCALE_FACTOR = 10**13
# Forwarded as ``enable_DJN`` whenever a key pair or a public key is created
# through ipcl_python in this module.
ENABLE_DJN = True


def generate_keys(n_length=1024):
    """
    Generate a new Paillier key pair through ipcl_python
    Args:
        n_length: forwarded unchanged as ``n_length`` to ipcl_python
    Returns: the result of ``PaillierKeypair.generate_keypair``
    """
    keypair_factory = ipcl_python.PaillierKeypair
    return keypair_factory.generate_keypair(n_length=n_length, enable_DJN=ENABLE_DJN)


def encrypt_number(pubkey, ciphertext, exponent):
    """
    Build an EncryptedNumber from an existing ciphertext
    Args:
        pubkey: public key handed to the EncryptedNumber constructor
        ciphertext: ciphertext object to wrap
        exponent: exponent value; it is passed on as a one-element list
    Returns: the newly constructed EncryptedNumber
    """
    exponents = [exponent]
    # NOTE(review): the fourth positional argument looks like the number of
    # wrapped values - confirm against the ipcl_python API.
    count = 1
    return EncryptedNumber(pubkey, ciphertext, exponents, count)


def create_pub_key(key, n_length=1024):
    """
    Create an ipcl_python public key object from a raw key value
    Args:
        key: forwarded unchanged as ``key`` to ``PaillierPublicKey``
        n_length: forwarded unchanged as ``n_length``
    Returns: the constructed ``ipcl_python.PaillierPublicKey``
    """
    public_key_cls = ipcl_python.PaillierPublicKey
    return public_key_cls(key=key, n_length=n_length, enable_DJN=ENABLE_DJN)


def ciphertext_to_int(d):
    """
    Convert the first ciphertext element of an encrypted number to an int
    Args:
        d: object exposing ``ciphertextBN()``
    Returns: result of ``BNUtils.BN2int`` applied to element 0 of ``d.ciphertextBN()``
    """
    big_numbers = d.ciphertextBN()
    first_big_number = big_numbers[0]
    return ipcl_python.BNUtils.BN2int(first_big_number)


def int_to_ciphertext(d, pubkey):
    """
    Wrap an integer into an ipcl_python ciphertext object
    Args:
        d: integer to convert with ``BNUtils.int2BN``
        pubkey: object whose ``pubkey`` attribute is handed to ``ipclCipherText``
    Returns: the constructed ``ipcl_python.ipclCipherText``
    """
    ciphertext_cls = ipcl_python.ipclCipherText
    native_key = pubkey.pubkey
    big_number = ipcl_python.BNUtils.int2BN(d)
    return ciphertext_cls(native_key, big_number)


def get_exponent(d):
    """
    Fetch the exponent stored at index 0 of an encrypted number
    Args:
        d: object exposing ``exponent(idx=...)``
    Returns: whatever ``d.exponent(idx=0)`` returns
    """
    first_index = 0
    return d.exponent(idx=first_index)


# base64 utils from jwcrypto
def base64url_encode(payload):
    """
    Encode a payload as URL-safe base64 without trailing "=" padding
    Args:
        payload: bytes, or a str that is first encoded as UTF-8
    Returns: the encoded text as a str
    """
    raw = payload if isinstance(payload, bytes) else payload.encode("utf-8")
    encoded_text = urlsafe_b64encode(raw).decode("utf-8")
    # Padding is dropped here; base64url_decode() puts it back.
    return encoded_text.rstrip("=")


def base64url_decode(payload):
    """
    Decode URL-safe base64 text whose "=" padding has been stripped
    Args:
        payload: str holding the unpadded base64 text
    Returns: the decoded bytes
    Raises:
        ValueError: if len(payload) % 4 == 1, which no valid encoding produces
    """
    remainder = len(payload) % 4
    if remainder == 1:
        raise ValueError("Invalid base64 string")
    if remainder:
        # remainder 2 needs "==", remainder 3 needs "="
        payload += "=" * (4 - remainder)
    return urlsafe_b64decode(payload.encode("utf-8"))


def base64_to_int(source):
    """
    Decode unpadded URL-safe base64 text into an integer
    Args:
        source: base64 text accepted by base64url_decode
    Returns: the decoded bytes read as one hexadecimal number
    """
    raw_bytes = base64url_decode(source)
    hex_digits = hexlify(raw_bytes)
    return int(hex_digits, 16)


def int_to_base64(source):
    """
    Encode an integer as unpadded URL-safe base64 text
    Args:
        source: integer to encode; must not be 0
    Returns: base64 text of the integer's hexadecimal digits converted to bytes
    """
    # NOTE(review): assert statements are removed when Python runs with -O, so
    # this guard is not active in optimized mode.
    assert source != 0
    hex_digits = hex(source).rstrip("L").lstrip("0x")
    if len(hex_digits) % 2:
        # unhexlify needs an even number of digits
        hex_digits = "0" + hex_digits
    return base64url_encode(unhexlify(hex_digits))


def combine(g, h, scale_factor=None):
    """
    Pack two numbers into one as ``g * scale_factor + h``
    Args:
        g: value placed in the high part
        h: value placed in the low part
        scale_factor: multiplier applied to g; defaults to the module-level
            SCALE_FACTOR when omitted, which is the original behavior
    Returns: the combined number
    """
    if scale_factor is None:
        scale_factor = SCALE_FACTOR
    return g * scale_factor + h


def split(d, scale_factor=None):
    """
    Take apart a number that was packed as ``g * scale_factor + h``
    Args:
        d: the combined number
        scale_factor: factor that was used when packing; defaults to the
            module-level SCALE_FACTOR when omitted, which is the original behavior
    Returns: tuple (g, h) where g is d / scale_factor rounded to the nearest
        integer (ties go to the even integer) and h is d - g * scale_factor
    """
    if scale_factor is None:
        scale_factor = SCALE_FACTOR

    if isinstance(d, int) and isinstance(scale_factor, int) and scale_factor > 0:
        # Exact integer arithmetic: a float division silently loses precision
        # once the quotient exceeds 2**53 and overflows for very large ints.
        g, remainder = divmod(d, scale_factor)
        doubled = 2 * remainder
        # Round to nearest with ties to even, the same rule round() applies.
        if doubled > scale_factor or (doubled == scale_factor and g % 2):
            g += 1
    else:
        # Non-integer input keeps the float-based rounding.
        g = int(round(d / scale_factor, 0))

    h = d - g * scale_factor
    return g, h


def _encode_encrypted_numbers(numbers):
    """
    Turn a list of numbers into JSON-friendly values
    Args:
        numbers: iterable mixing EncryptedNumber objects and other values
    Returns: list where each EncryptedNumber is replaced by a
        (base64 ciphertext, exponent) tuple and every other value is kept as is
    """

    def _encode_one(value):
        if not isinstance(value, EncryptedNumber):
            return value
        return int_to_base64(ciphertext_to_int(value)), get_exponent(value)

    return [_encode_one(x) for x in numbers]


def encode_encrypted_numbers_to_str(numbers):
    """
    Serialize a list of (possibly encrypted) numbers to a JSON string
    Args:
        numbers: values accepted by _encode_encrypted_numbers
    Returns: JSON text of the encoded list
    """
    encoded_items = _encode_encrypted_numbers(numbers)
    return json.dumps(encoded_items)


def encode_encrypted_data(pubkey, encrypted_numbers) -> str:
    """
    Serialize a public key together with encrypted numbers to a JSON string
    Args:
        pubkey: object whose ``n`` attribute is stored base64-encoded under key/n
        encrypted_numbers: values accepted by _encode_encrypted_numbers
    Returns: JSON text with the two entries "key" and "nums"
    """
    key_section = {"n": int_to_base64(pubkey.n)}
    nums_section = _encode_encrypted_numbers(encrypted_numbers)
    return json.dumps({"key": key_section, "nums": nums_section})


def decode_encrypted_data(encoded: str, n_length=1024):
    """
    Rebuild the public key and the numbers from JSON made by encode_encrypted_data
    Args:
        encoded: JSON text holding the "key" and "nums" entries
        n_length: forwarded to create_pub_key
    Returns: tuple (public key, list of decoded numbers)
    """
    payload = json.loads(encoded)
    key_value = base64_to_int(payload["key"]["n"])
    pubkey = create_pub_key(key=key_value, n_length=n_length)
    return pubkey, _decode_encrypted_numbers(pubkey, payload["nums"])


def decode_encrypted_numbers_from_str(pubkey, encoded: str):
    """
    Decode a JSON string of encoded numbers
    Args:
        pubkey: public key used to rebuild the encrypted numbers
        encoded: JSON text of a list of encoded numbers
    Returns: list produced by _decode_encrypted_numbers
    """
    encoded_items = json.loads(encoded)
    return _decode_encrypted_numbers(pubkey, encoded_items)


def _decode_encrypted_numbers(pubkey, data):
    """
    Rebuild numbers from their JSON-friendly form
    Args:
        pubkey: public key attached to every rebuilt encrypted number
        data: iterable of ints and (base64 ciphertext, exponent) pairs
    Returns: list where ints are kept as is and every other item becomes the
        result of encrypt_number
    """

    def _decode_one(item):
        if isinstance(item, int):
            return item
        ciphertext = int_to_ciphertext(base64_to_int(item[0]), pubkey=pubkey)
        return encrypt_number(pubkey, ciphertext=ciphertext, exponent=int(item[1]))

    return [_decode_one(v) for v in data]


def encode_feature_aggregations(aggrs: list):
    """
    Serialize feature aggregation results to a JSON string
    Args:
        aggrs: JSON-serializable list of aggregation entries
    Returns: JSON text of the list
    """
    serialized = json.dumps(aggrs)
    return serialized


def decode_feature_aggregations(pubkey, encoded: str):
    """
    Decode feature aggregation results from a JSON string
    Args:
        pubkey: public key used to rebuild the encrypted numbers
        encoded: JSON text of a list of (feature id, group id, encoded numbers
            string) entries
    Returns: list of (feature id, group id, decoded numbers) tuples
    """
    return [
        (feature_id, gid, decode_encrypted_numbers_from_str(pubkey, encoded_nums_str))
        for feature_id, gid, encoded_nums_str in json.loads(encoded)
    ]
Loading

0 comments on commit b16dffb

Please sign in to comment.