Skip to content

Commit

Permalink
Update subsample_process and iblt_subsampling to support client value…
Browse files Browse the repository at this point in the history
… unique count.

PiperOrigin-RevId: 563118274
  • Loading branch information
zitengsun authored and tensorflow-copybara committed Sep 7, 2023
1 parent a18a5b6 commit 46531d1
Show file tree
Hide file tree
Showing 5 changed files with 361 additions and 10 deletions.
28 changes: 28 additions & 0 deletions tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ py_test(
],
)

# Library target for iblt_subsampling.py. The deps below mirror that module's
# project-internal imports: the sibling IBLT factory and subsample process,
# plus the TFF aggregator, federated/TensorFlow computation and type packages.
py_library(
    name = "iblt_subsampling",
    srcs = ["iblt_subsampling.py"],
    deps = [
        ":iblt_factory",
        ":subsample_process",
        "//tensorflow_federated/python/aggregators:factory",
        "//tensorflow_federated/python/core/impl/federated_context:federated_computation",
        "//tensorflow_federated/python/core/impl/federated_context:intrinsics",
        "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation",
        "//tensorflow_federated/python/core/impl/types:computation_types",
        "//tensorflow_federated/python/core/templates:aggregation_process",
    ],
)

# Test target for iblt_subsampling_test.py. Besides the library under test it
# depends on the IBLT factory and subsample process used to build fixtures,
# and on the native execution contexts needed to actually run computations.
py_test(
    name = "iblt_subsampling_test",
    srcs = ["iblt_subsampling_test.py"],
    deps = [
        ":iblt_factory",
        ":iblt_subsampling",
        ":subsample_process",
        "//tensorflow_federated/python/core/backends/native:execution_contexts",
        "//tensorflow_federated/python/core/impl/types:computation_types",
    ],
)

py_library(
name = "iblt_lib",
srcs = [
Expand Down Expand Up @@ -195,6 +222,7 @@ py_library(

py_test(
name = "subsample_process_test",
timeout = "long",
srcs = ["subsample_process_test.py"],
deps = [
":iblt_factory",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for subsampling client strings before aggregation via IBLT."""

import collections

import tensorflow as tf

from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory
from tensorflow_federated.python.analytics.heavy_hitters.iblt import subsample_process
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.templates import aggregation_process


class SubsampledIbltFactory(factory.UnweightedAggregationFactory):
  """Factory for subsampling client data before aggregation.

  Each client's dataset is first mapped through the subsampling function
  provided by `sampling_process`; the subsampled datasets are then aggregated
  by the process created from `inner_iblt_agg`.
  """

  def __init__(
      self,
      inner_iblt_agg: iblt_factory.IbltFactory,
      sampling_process: subsample_process.SubsampleProcess,
      unique_counts: bool = False,
  ):
    """Initializes SubsampledIbltFactory.

    Args:
      inner_iblt_agg: An instance of IbltFactory.
      sampling_process: An instance of SubsampleProcess specifying parameters
        and methods related to dataset subsampling at client side.
      unique_counts: Whether the input dataset contains unique counts in its
        values; if so, each value will be of the form `[count, 1]`.
    """
    self.inner_iblt_agg = inner_iblt_agg
    self.sampling_process = sampling_process
    self.unique_counts = unique_counts

  def create(
      self, value_type: factory.ValueType
  ) -> aggregation_process.AggregationProcess:
    """Creates an aggregation process that subsamples, then aggregates.

    Args:
      value_type: The type of each client's data. It must be assignable to a
        sequence of ordered dicts that map `iblt_factory.DATASET_KEY` to a
        string and `iblt_factory.DATASET_VALUE` to a rank-1 int64 tensor.

    Returns:
      An `AggregationProcess` whose `initialize` is that of the inner IBLT
      process and whose `next` applies the subsampling function to every
      client's data before calling the inner process's `next`.

    Raises:
      ValueError: If `value_type` is not compatible with the expected type, or
        if the sampling process is adaptive.
    """
    expected_value_type = computation_types.SequenceType(
        collections.OrderedDict([
            (iblt_factory.DATASET_KEY, tf.string),
            (
                iblt_factory.DATASET_VALUE,
                computation_types.TensorType(shape=[None], dtype=tf.int64),
            ),
        ])
    )

    if not expected_value_type.is_assignable_from(value_type):
      raise ValueError(
          'value_type must be compatible with '
          f'{expected_value_type}. Found {value_type} instead.'
      )

    if self.sampling_process.is_process_adaptive:
      raise ValueError(
          'Current implementation only supports nonadaptive processes.'
      )

    # The process is nonadaptive, so the initial parameter is captured once
    # here and reused by every call to the subsampling computation.
    subsample_param = self.sampling_process.get_init_param()
    if self.unique_counts:
      subsample_fn = self.sampling_process.subsample_fn_with_unique_count
    else:
      subsample_fn = self.sampling_process.subsample_fn

    @tensorflow_computation.tf_computation(value_type)
    @tf.function
    def subsample(client_data):
      return subsample_fn(client_data, subsample_param)

    # The inner process is built for the type produced by subsampling rather
    # than for `value_type` directly.
    inner_process = self.inner_iblt_agg.create(subsample.type_signature.result)

    @federated_computation.federated_computation(
        inner_process.initialize.type_signature.result,
        computation_types.at_clients(value_type),
    )
    def next_fn(state, client_data):
      preprocessed = intrinsics.federated_map(subsample, client_data)
      return inner_process.next(state, preprocessed)

    return aggregation_process.AggregationProcess(
        inner_process.initialize, next_fn
    )
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2023, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import math
from typing import Union

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory
from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_subsampling
from tensorflow_federated.python.analytics.heavy_hitters.iblt import subsample_process
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import computation_types


# Per-client fixtures: one `(strings, values)` pair per client, where
# `values[i]` is the single-element count vector attached to `strings[i]`.
DATA = [
    (['seattle', 'hello', 'world', 'bye'], [[1], [4], [1], [2]]),
    (['hi', 'seattle'], [[2], [5]]),
    (['good', 'morning', 'hi', 'bye'], [[3], [6], [2], [3]]),
]

# Expected totals for each string, summed over every client in `DATA`.
AGGREGATED_DATA = {
    'seattle': [6],  # 1 + 5
    'hello': [4],
    'world': [1],
    'hi': [4],  # 2 + 2
    'good': [3],
    'morning': [6],
    'bye': [5],  # 2 + 3
}


def _generate_client_data(
    input_structure: list[tuple[list[Union[str, bytes]], list[list[int]]]]
) -> list[tf.data.Dataset]:
  """Builds one `tf.data.Dataset` per client.

  Args:
    input_structure: One `(strings, values)` pair per client, where `strings`
      are the client's keys and `values` the integer vectors paired with them.

  Returns:
    A list with one dataset per input pair, in the same order. Each dataset is
    sliced from an ordered dict holding the strings (as `tf.string`) under
    `iblt_factory.DATASET_KEY` and the values (as `tf.int64`) under
    `iblt_factory.DATASET_VALUE`.
  """

  def _to_dataset(keys, counts):
    tensors = collections.OrderedDict()
    tensors[iblt_factory.DATASET_KEY] = tf.constant(keys, dtype=tf.string)
    tensors[iblt_factory.DATASET_VALUE] = tf.constant(counts, dtype=tf.int64)
    return tf.data.Dataset.from_tensor_slices(tensors)

  return [_to_dataset(keys, counts) for keys, counts in input_structure]


# Datasets for the fixture clients, shared by the tests below.
CLIENT_DATA = _generate_client_data(DATA)


class IbltSubsamplingTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `iblt_subsampling.SubsampledIbltFactory`."""

  def test_incorrect_value_type(self):
    # Values typed as int32 instead of int64 must be rejected by `create`.
    factory_under_test = iblt_subsampling.SubsampledIbltFactory(
        iblt_factory.IbltFactory(
            capacity=100, string_max_bytes=10, repetitions=3, seed=0
        ),
        subsample_process.ThresholdSamplingProcess(1.0),
    )
    int32_values = computation_types.TensorType(shape=[None], dtype=tf.int32)
    wrong_type = computation_types.SequenceType(
        collections.OrderedDict([
            (iblt_factory.DATASET_KEY, tf.string),
            (iblt_factory.DATASET_VALUE, int32_values),
        ])
    )
    with self.assertRaises(ValueError):
      factory_under_test.create(wrong_type)

  @parameterized.named_parameters(
      ('threshold 1.0', 1.0),
      ('threshold 2.0', 2.0),
      ('threshold 5.0', 5.0),
      ('threshold 10.0', 10.0),
  )
  def test_subsampling_factory(self, threshold: float):
    # Runs the subsampled aggregation repeatedly on the fixture clients and
    # checks that the per-round average of each reported count is close to the
    # exact aggregate.
    inner_factory = iblt_factory.IbltFactory(
        capacity=100, string_max_bytes=10, repetitions=3
    )
    sampler = subsample_process.ThresholdSamplingProcess(threshold)
    factory_under_test = iblt_subsampling.SubsampledIbltFactory(
        inner_factory, sampler
    )
    int64_values = computation_types.TensorType(shape=(1,), dtype=tf.int64)
    value_type = computation_types.SequenceType(
        collections.OrderedDict([
            (iblt_factory.DATASET_KEY, tf.string),
            (iblt_factory.DATASET_VALUE, int64_values),
        ])
    )
    process = factory_under_test.create(value_type)
    state = process.initialize()

    num_rounds = 100
    # Running sum, over all rounds, of the count reported for each string.
    totals = {word: 0 for word in AGGREGATED_DATA}
    for _ in range(num_rounds):
      round_output = process.next(state, CLIENT_DATA)
      state = round_output.state
      round_result = round_output.result
      round_hist = dict(
          zip(round_result.output_strings, round_result.string_values[:, 0])
      )
      for raw_word, count in round_hist.items():
        totals[raw_word.decode('utf-8')] += count

    tolerance = threshold / math.sqrt(num_rounds)
    for word, expected in AGGREGATED_DATA.items():
      self.assertAllClose(
          totals[word] / float(num_rounds),
          expected[0],
          atol=tolerance,
      )


# Script entry point: install the TFF execution context (the synchronous local
# C++ one, going by the function name) before handing control to the
# TensorFlow test runner, so `initialize`/`next` can actually be executed.
if __name__ == '__main__':
  execution_contexts.set_sync_local_cpp_execution_context()
  tf.test.main()
Loading

0 comments on commit 46531d1

Please sign in to comment.