Skip to content

Commit

Permalink
Updates MeanFactory to allow safe division to avoid nan from divi…
Browse files Browse the repository at this point in the history
…sion by zero.

PiperOrigin-RevId: 339245271
  • Loading branch information
PraChetit authored and ZacharyGarrett committed Oct 27, 2020
1 parent 7e254a6 commit ec6dfe0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
31 changes: 26 additions & 5 deletions tensorflow_federated/python/aggregators/mean_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,18 @@ class MeanFactory(factory.AggregationProcessFactory):
realize the sum of weighted values and weights.
- Division of summed weighted values and summed weights at `SERVER`.
Note that the the division at `SERVER` can protect against division by 0, as
specified by `no_nan_division` constructor argument.
The `state` is the composed `state` of the aggregation processes created by
the two inner aggregation factories. The same holds for `measurements`.
"""

def __init__(
self,
value_sum_factory: Optional[factory.AggregationProcessFactory] = None,
weight_sum_factory: Optional[factory.AggregationProcessFactory] = None):
weight_sum_factory: Optional[factory.AggregationProcessFactory] = None,
no_nan_division: Optional[bool] = False):
"""Initializes `MeanFactory`.
Args:
Expand All @@ -64,6 +68,8 @@ def __init__(
weight_sum_factory: An optional
`tff.aggregators.AggregationProcessFactory` responsible for summation of
weights. If not specified, `tff.aggregators.SumFactory` is used.
no_nan_division: A bool. If True, the computed mean is 0 if sum of weights
is equal to 0.
Raises:
TypeError: If provided `value_sum_factory` or `weight_sum_factory` is not
Expand All @@ -81,6 +87,9 @@ def __init__(
factory.AggregationProcessFactory)
self._weight_sum_factory = weight_sum_factory

py_typecheck.check_type(no_nan_division, bool)
self._no_nan_division = no_nan_division

def create(
self,
value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
Expand Down Expand Up @@ -136,8 +145,12 @@ def next_fn(state, value, weight):
weight)

# Server computation.
weighted_mean_value = intrinsics.federated_map(
_div, (value_output.result, weight_output.result))
if self._no_nan_division:
weighted_mean_value = intrinsics.federated_map(
_div_no_nan, (value_output.result, weight_output.result))
else:
weighted_mean_value = intrinsics.federated_map(
_div, (value_output.result, weight_output.result))

# Output preparation.
state = collections.OrderedDict(
Expand All @@ -160,5 +173,13 @@ def _mul(value, weight):

@computations.tf_computation()
def _div(weighted_value_sum, weight_sum):
return tf.nest.map_structure(lambda x: x / tf.cast(weight_sum, x.dtype),
weighted_value_sum)
return tf.nest.map_structure(
lambda x: tf.math.divide(x, tf.cast(weight_sum, x.dtype)),
weighted_value_sum)


@computations.tf_computation()
def _div_no_nan(weighted_value_sum, weight_sum):
return tf.nest.map_structure(
lambda x: tf.math.divide_no_nan(x, tf.cast(weight_sum, x.dtype)),
weighted_value_sum)
24 changes: 24 additions & 0 deletions tensorflow_federated/python/aggregators/mean_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import math
from absl.testing import parameterized
import tensorflow as tf

Expand Down Expand Up @@ -180,6 +181,29 @@ def test_weight_arg(self):
weights = [6.0, 3.0, 1.0]
self.assertEqual(1.5, process.next(state, client_data, weights).result)

def test_weight_arg_all_zeros_nan_division(self):
mean_f = mean_factory.MeanFactory(no_nan_division=False)
value_type = computation_types.to_type(tf.float32)
process = mean_f.create(value_type)

state = process.initialize()
client_data = [1.0, 2.0, 3.0]
weights = [0.0, 0.0, 0.0]
# Division by zero resulting in NaN/Inf *should* occur.
self.assertFalse(
math.isfinite(process.next(state, client_data, weights).result))

def test_weight_arg_all_zeros_no_nan_division(self):
mean_f = mean_factory.MeanFactory(no_nan_division=True)
value_type = computation_types.to_type(tf.float32)
process = mean_f.create(value_type)

state = process.initialize()
client_data = [1.0, 2.0, 3.0]
weights = [0.0, 0.0, 0.0]
# Division by zero resulting in NaN/Inf *should not* occur.
self.assertEqual(0.0, process.next(state, client_data, weights).result)

def test_inner_value_sum_factory(self):
sum_factory = aggregators_test_utils.SumPlusOneFactory()
mean_f = mean_factory.MeanFactory(value_sum_factory=sum_factory)
Expand Down

0 comments on commit ec6dfe0

Please sign in to comment.