diff --git a/tensorflow_federated/python/core/impl/compiler/transformations.py b/tensorflow_federated/python/core/impl/compiler/transformations.py index cf90241d75..925f559a43 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations.py @@ -201,11 +201,13 @@ def _build(comp, scope): def get_normalized_call_dominant_lambda( comp: building_blocks.Lambda, + normalize_all_equal_bit: bool = True, ) -> building_blocks.Lambda: """Creates normalized call dominant form for a lambda computation. Args: comp: A computation to normalize. + normalize_all_equal_bit: Whether to normalize the all-equal bit. Returns: A transformed but semantically-equivalent `comp`. The result will be a @@ -258,7 +260,9 @@ def get_normalized_call_dominant_lambda( ) comp.result.check_block() - comp = tree_transformations.normalize_all_equal_bit(comp) + if normalize_all_equal_bit: + comp = tree_transformations.normalize_all_equal_bit(comp) + tree_analysis.check_contains_no_unbound_references(comp) return comp @@ -1474,9 +1478,15 @@ def divisive_force_align_and_split_by_intrinsics( ############################### Step 8 ###################################### # Normalize all of the output computations. - before_comp = get_normalized_call_dominant_lambda(before_comp) - intrinsic_comp = get_normalized_call_dominant_lambda(intrinsic_comp) - after_comp = get_normalized_call_dominant_lambda(after_comp) + before_comp = get_normalized_call_dominant_lambda( + before_comp, normalize_all_equal_bit=False + ) + intrinsic_comp = get_normalized_call_dominant_lambda( + intrinsic_comp, normalize_all_equal_bit=False + ) + after_comp = get_normalized_call_dominant_lambda( + after_comp, normalize_all_equal_bit=False + ) # Check that the intrinsic comp consists of a block containing locals that are # exclusively calls for the allowed intrinsics and that the results are