diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD b/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD index e1788dd0ac..7c5f50f883 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/BUILD @@ -58,6 +58,33 @@ py_test( ], ) +py_library( + name = "iblt_subsampling", + srcs = ["iblt_subsampling.py"], + deps = [ + ":iblt_factory", + ":subsample_process", + "//tensorflow_federated/python/aggregators:factory", + "//tensorflow_federated/python/core/impl/federated_context:federated_computation", + "//tensorflow_federated/python/core/impl/federated_context:intrinsics", + "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation", + "//tensorflow_federated/python/core/impl/types:computation_types", + "//tensorflow_federated/python/core/templates:aggregation_process", + ], +) + +py_test( + name = "iblt_subsampling_test", + srcs = ["iblt_subsampling_test.py"], + deps = [ + ":iblt_factory", + ":iblt_subsampling", + ":subsample_process", + "//tensorflow_federated/python/core/backends/native:execution_contexts", + "//tensorflow_federated/python/core/impl/types:computation_types", + ], +) + py_library( name = "iblt_lib", srcs = [ @@ -195,6 +222,7 @@ py_library( py_test( name = "subsample_process_test", + timeout = "long", srcs = ["subsample_process_test.py"], deps = [ ":iblt_factory", diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling.py new file mode 100644 index 0000000000..3331e618a5 --- /dev/null +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling.py @@ -0,0 +1,99 @@ +# Copyright 2023, The TensorFlow Federated Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Factory for subsampling client strings before aggregation via IBLT.""" + +import collections + +import tensorflow as tf + +from tensorflow_federated.python.aggregators import factory +from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory +from tensorflow_federated.python.analytics.heavy_hitters.iblt import subsample_process +from tensorflow_federated.python.core.impl.federated_context import federated_computation +from tensorflow_federated.python.core.impl.federated_context import intrinsics +from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation +from tensorflow_federated.python.core.impl.types import computation_types +from tensorflow_federated.python.core.templates import aggregation_process + + +class SubsampledIbltFactory(factory.UnweightedAggregationFactory): + """Factory for subsampling client data before aggregation.""" + + def __init__( + self, + inner_iblt_agg: iblt_factory.IbltFactory, + sampling_process: subsample_process.SubsampleProcess, + unique_counts: bool = False, + ): + """Initializes SubsampledIbltFactory. + + Args: + inner_iblt_agg: An instance of IbltFactory. + sampling_process: An instance of SubsampleProcess specifying parameters + and methods related to dataset subsampling at client side. + unique_counts: Whether the input dataset contains unique counts in its + values; if yes, each value will be of the form `[count, 1]`. 
+ """ + self.inner_iblt_agg = inner_iblt_agg + self.sampling_process = sampling_process + self.unique_counts = unique_counts + + def create( + self, value_type: factory.ValueType + ) -> aggregation_process.AggregationProcess: + expected_value_type = computation_types.SequenceType( + collections.OrderedDict([ + (iblt_factory.DATASET_KEY, tf.string), + ( + iblt_factory.DATASET_VALUE, + computation_types.TensorType(shape=[None], dtype=tf.int64), + ), + ]) + ) + + if not expected_value_type.is_assignable_from(value_type): + raise ValueError( + 'value_shape must be compatible with ' + f'{expected_value_type}. Found {value_type} instead.' + ) + + if self.sampling_process.is_process_adaptive: + raise ValueError( + 'Current implementaion only support nonadaptive process.' + ) + + subsample_param = self.sampling_process.get_init_param() + if self.unique_counts: + subsample_fn = self.sampling_process.subsample_fn_with_unique_count + else: + subsample_fn = self.sampling_process.subsample_fn + + @tensorflow_computation.tf_computation(value_type) + @tf.function + def subsample(client_data): + return subsample_fn(client_data, subsample_param) + + inner_process = self.inner_iblt_agg.create(subsample.type_signature.result) + + @federated_computation.federated_computation( + inner_process.initialize.type_signature.result, + computation_types.at_clients(value_type), + ) + def next_fn(state, client_data): + preprocessed = intrinsics.federated_map(subsample, client_data) + return inner_process.next(state, preprocessed) + + return aggregation_process.AggregationProcess( + inner_process.initialize, next_fn + ) diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling_test.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling_test.py new file mode 100644 index 0000000000..402ff92763 --- /dev/null +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/iblt_subsampling_test.py @@ -0,0 +1,149 @@ +# Copyright 2023, The TensorFlow 
Federated Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import math +from typing import Union + +from absl.testing import parameterized +import tensorflow as tf + +from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_factory +from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_subsampling +from tensorflow_federated.python.analytics.heavy_hitters.iblt import subsample_process +from tensorflow_federated.python.core.backends.native import execution_contexts +from tensorflow_federated.python.core.impl.types import computation_types + + +DATA = [ + ( + ['seattle', 'hello', 'world', 'bye'], + [[1], [4], [1], [2]], + ), + (['hi', 'seattle'], [[2], [5]]), + ( + ['good', 'morning', 'hi', 'bye'], + [[3], [6], [2], [3]], + ), +] + +AGGREGATED_DATA = { + 'seattle': [6], + 'hello': [4], + 'world': [1], + 'hi': [4], + 'good': [3], + 'morning': [6], + 'bye': [5], +} + + +def _generate_client_data( + input_structure: list[tuple[list[Union[str, bytes]], list[list[int]]]] +) -> list[tf.data.Dataset]: + client_data = [] + for input_strings, string_values in input_structure: + client = collections.OrderedDict([ + ( + iblt_factory.DATASET_KEY, + tf.constant(input_strings, dtype=tf.string), + ), + ( + iblt_factory.DATASET_VALUE, + tf.constant(string_values, dtype=tf.int64), + ), + ]) + client_data.append(tf.data.Dataset.from_tensor_slices(client)) + return client_data + + +CLIENT_DATA = _generate_client_data(DATA) 
+ + +class IbltSubsamplingTest(tf.test.TestCase, parameterized.TestCase): + + def test_incorrect_value_type(self): + iblt_fac = iblt_factory.IbltFactory( + capacity=100, string_max_bytes=10, repetitions=3, seed=0 + ) + sampling_process = subsample_process.ThresholdSamplingProcess(1.0) + subsample_fac = iblt_subsampling.SubsampledIbltFactory( + iblt_fac, sampling_process + ) + wrong_type = computation_types.SequenceType( + collections.OrderedDict([ + (iblt_factory.DATASET_KEY, tf.string), + ( + iblt_factory.DATASET_VALUE, + computation_types.TensorType(shape=[None], dtype=tf.int32), + ), + ]) + ) + with self.assertRaises(ValueError): + subsample_fac.create(wrong_type) + + @parameterized.named_parameters( + {'testcase_name': 'threshold 1.0', 'threshold': 1.0}, + {'testcase_name': 'threshold 2.0', 'threshold': 2.0}, + {'testcase_name': 'threshold 5.0', 'threshold': 5.0}, + {'testcase_name': 'threshold 10.0', 'threshold': 10.0}, + ) + def test_subsampling_factory(self, threshold: float): + iblt_fac = iblt_factory.IbltFactory( + capacity=100, string_max_bytes=10, repetitions=3 + ) + sampling_process = subsample_process.ThresholdSamplingProcess(threshold) + subsample_fac = iblt_subsampling.SubsampledIbltFactory( + iblt_fac, sampling_process + ) + value_type = computation_types.SequenceType( + collections.OrderedDict([ + (iblt_factory.DATASET_KEY, tf.string), + ( + iblt_factory.DATASET_VALUE, + computation_types.TensorType(shape=(1,), dtype=tf.int64), + ), + ]) + ) + agg_process = subsample_fac.create(value_type) + state = agg_process.initialize() + num_rounds = 100 + output_counts = { + 'seattle': 0, + 'hello': 0, + 'world': 0, + 'hi': 0, + 'bye': 0, + 'good': 0, + 'morning': 0, + } + for _ in range(num_rounds): + process_output = agg_process.next(state, CLIENT_DATA) + state = process_output.state + heavy_hitters = process_output.result.output_strings + heavy_hitters_counts = process_output.result.string_values[:, 0] + hist_round = dict(zip(heavy_hitters, 
heavy_hitters_counts)) + for x in hist_round: + output_counts[x.decode('utf-8')] += hist_round[x] + for x in AGGREGATED_DATA: + self.assertAllClose( + output_counts[x] / float(num_rounds), + AGGREGATED_DATA[x][0], + atol=threshold / math.sqrt(num_rounds), + ) + + +if __name__ == '__main__': + execution_contexts.set_sync_local_cpp_execution_context() + tf.test.main() diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process.py index ec47296489..30f139c74e 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process.py @@ -164,6 +164,12 @@ def subsample_fn( subsampled client dataset with the same format as client_data. """ + @abc.abstractmethod + def subsample_fn_with_unique_count( + self, client_data: tf.data.Dataset, subsampling_param: float + ): + """Performs subsampling when client values have unique counts appended.""" + class ThresholdSamplingProcess(SubsampleProcess): """Implements threshold sampling. @@ -190,7 +196,10 @@ class ThresholdSamplingProcess(SubsampleProcess): """ def __init__( - self, init_param: float, is_adaptive: bool = False, beta: float = 0.5 + self, + init_param: float, + is_adaptive: bool = False, + beta: float = 0.5, ): """Initialize the subsamping precoess. @@ -246,24 +255,31 @@ def get_init_param(self): return self._init_param def subsample_fn( - self, client_data: tf.data.Dataset, subsampling_param: float - ): + self, + client_data: tf.data.Dataset, + subsampling_param: float, + ) -> tf.data.Dataset: """See base class. 
Raise ValueError if client data has negative counts.""" - generator = tf.random.Generator.from_non_deterministic_state() - @tf.function def threshold_sampling(element): - count = element[iblt_factory.DATASET_VALUE] + count = element[iblt_factory.DATASET_VALUE][0] tf.debugging.assert_non_negative( count, 'Current implementation only supports positive values.' ) - if count >= subsampling_param: + if count >= tf.cast(subsampling_param, dtype=count.dtype): return element - random_val = generator.uniform( - shape=(), minval=0, maxval=subsampling_param, dtype=count.dtype + + random_val = tf.random.uniform( + shape=(), + minval=0, + maxval=tf.cast(subsampling_param, dtype=count.dtype), + dtype=count.dtype, ) - thresholded_val = subsampling_param if count > random_val else 0 + if count > random_val: + thresholded_val = tf.cast(subsampling_param, dtype=count.dtype) + else: + thresholded_val = tf.cast(0, dtype=count.dtype) return collections.OrderedDict([ (iblt_factory.DATASET_KEY, element[iblt_factory.DATASET_KEY]), ( @@ -271,6 +287,44 @@ def threshold_sampling(element): tf.cast([thresholded_val], dtype=count.dtype), ), ]) + subsampled_client_data = client_data.map(threshold_sampling) + return subsampled_client_data.filter( + lambda x: x[iblt_factory.DATASET_VALUE][0] > 0 + ) + + def subsample_fn_with_unique_count( + self, + client_data: tf.data.Dataset, + subsampling_param: float, + ) -> tf.data.Dataset: + """See base class. Raise ValueError if client data has negative counts.""" + + @tf.function + def threshold_sampling(element): + count = element[iblt_factory.DATASET_VALUE][0] + tf.debugging.assert_non_negative( + count, 'Current implementation only supports positive values.' 
+ ) + if count >= tf.cast(subsampling_param, dtype=count.dtype): + return element + + random_val = tf.random.uniform( + shape=(), + minval=0, + maxval=tf.cast(subsampling_param, dtype=count.dtype), + dtype=count.dtype, + ) + if count > random_val: + thresholded_val = tf.cast(subsampling_param, dtype=count.dtype) + else: + thresholded_val = tf.cast(0, dtype=count.dtype) + return collections.OrderedDict([ + (iblt_factory.DATASET_KEY, element[iblt_factory.DATASET_KEY]), + ( + iblt_factory.DATASET_VALUE, + tf.cast([thresholded_val, 1], dtype=count.dtype), + ), + ]) subsampled_client_data = client_data.map(threshold_sampling) return subsampled_client_data.filter( diff --git a/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process_test.py b/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process_test.py index 1c8dfeb42e..ece7ac0b05 100644 --- a/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process_test.py +++ b/tensorflow_federated/python/analytics/heavy_hitters/iblt/subsample_process_test.py @@ -62,6 +62,9 @@ def _generate_client_local_data( (['good', 'morning', 'hi', 'bye'], [[3], [1], [2], [5]]) ) DATA_ALL_ZERO = _generate_client_local_data((['new', 'york'], [[0], [0]])) +DATA_SOME_IN_BETWEEN_WITH_UNIQUE_COUNT = _generate_client_local_data( + (['good', 'morning', 'hi', 'bye'], [[3, 1], [1, 1], [2, 1], [5, 1]]) +) class ThresholdSubsampleProcessTest(tf.test.TestCase, parameterized.TestCase): @@ -150,6 +153,24 @@ def test_sampling_in_between(self): counts[j] += _get_count_from_dataset(sampled_dataset, strings[j]) self.assertAllClose(counts / rep, expected_freqs, atol=0.45) + def test_sampling_in_between_with_unique_count(self): + threshold_sampling = subsample_process.ThresholdSamplingProcess( + init_param=THRESHOLD + ) + rep = 300 + strings = ['good', 'morning', 'hi', 'bye'] + expected_freqs = np.array([3, 1, 2, 5]) + counts = np.zeros(len(strings)) + subsample_param = threshold_sampling.get_init_param() + for 
i in range(rep): + tf.random.set_seed(i) + sampled_dataset = threshold_sampling.subsample_fn_with_unique_count( + DATA_SOME_IN_BETWEEN_WITH_UNIQUE_COUNT, subsample_param + ) + for j, _ in enumerate(strings): + counts[j] += _get_count_from_dataset(sampled_dataset, strings[j]) + self.assertAllClose(counts / rep, expected_freqs, atol=0.45) + @parameterized.named_parameters( { 'testcase_name': 'init_threshold_one',