From 9face59d8d0586891a5959b3bc35d88c8c8f4ea3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 7 Sep 2023 13:27:10 -0700 Subject: [PATCH] Disable normalize_all_equal_bit call for DistributeAggregateForm. PiperOrigin-RevId: 563528327 --- .../core/impl/compiler/transformations.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tensorflow_federated/python/core/impl/compiler/transformations.py b/tensorflow_federated/python/core/impl/compiler/transformations.py index cf90241d75..925f559a43 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations.py @@ -201,11 +201,13 @@ def _build(comp, scope): def get_normalized_call_dominant_lambda( comp: building_blocks.Lambda, + normalize_all_equal_bit: bool = True, ) -> building_blocks.Lambda: """Creates normalized call dominant form for a lambda computation. Args: comp: A computation to normalize. + normalize_all_equal_bit: Whether to normalize the all-equal bit. Returns: A transformed but semantically-equivalent `comp`. The result will be a @@ -258,7 +260,9 @@ def get_normalized_call_dominant_lambda( ) comp.result.check_block() - comp = tree_transformations.normalize_all_equal_bit(comp) + if normalize_all_equal_bit: + comp = tree_transformations.normalize_all_equal_bit(comp) + tree_analysis.check_contains_no_unbound_references(comp) return comp @@ -1474,9 +1478,15 @@ def divisive_force_align_and_split_by_intrinsics( ############################### Step 8 ###################################### # Normalize all of the output computations. - before_comp = get_normalized_call_dominant_lambda(before_comp) - intrinsic_comp = get_normalized_call_dominant_lambda(intrinsic_comp) - after_comp = get_normalized_call_dominant_lambda(after_comp) + before_comp = get_normalized_call_dominant_lambda( + before_comp, normalize_all_equal_bit=False + ) + intrinsic_comp = get_normalized_call_dominant_lambda( + intrinsic_comp, normalize_all_equal_bit=False + ) + after_comp = get_normalized_call_dominant_lambda( + after_comp, normalize_all_equal_bit=False + ) # Check that the intrinsic comp consists of a block containing locals that are # exclusively calls for the allowed intrinsics and that the results are