From 3e8cf294933ee8bae1d731a49f04a3b4d0f9d58e Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Mon, 17 Oct 2022 12:06:49 -0700 Subject: [PATCH] Move `ModelWeights` class into the learning/models/ sub-directory. Deprecated the `tff.learning.ModelWeights` and `tff.learning.framework.ModelWeights` APIs, replaced by `tff.learning.models.ModelWeights`. PiperOrigin-RevId: 481704926 --- tensorflow_federated/python/learning/BUILD | 39 ++++--------------- .../python/learning/__init__.py | 3 +- .../python/learning/algorithms/BUILD | 20 +++++----- .../python/learning/algorithms/fed_avg.py | 6 +-- .../fed_avg_with_optimizer_schedule.py | 6 +-- .../python/learning/algorithms/fed_eval.py | 4 +- .../learning/algorithms/fed_eval_test.py | 12 +++--- .../python/learning/algorithms/fed_prox.py | 6 +-- .../learning/algorithms/fed_prox_test.py | 5 ++- .../python/learning/algorithms/fed_sgd.py | 12 +++--- .../learning/algorithms/fed_sgd_test.py | 6 +-- .../python/learning/algorithms/mime.py | 24 ++++++------ .../python/learning/algorithms/mime_test.py | 12 +++--- .../python/learning/federated_evaluation.py | 6 +-- .../learning/federated_evaluation_test.py | 18 ++++----- .../python/learning/framework/BUILD | 7 ++-- .../python/learning/framework/__init__.py | 10 ++++- .../learning/framework/optimizer_utils.py | 12 +++--- .../framework/optimizer_utils_test.py | 24 ++++++------ .../python/learning/keras_utils_test.py | 12 +++--- .../python/learning/models/BUILD | 38 ++++++++++++++++++ .../python/learning/models/__init__.py | 2 + .../model_weights.py} | 0 .../model_weights_test.py} | 36 ++++++++--------- .../python/learning/personalization_eval.py | 10 ++--- .../learning/personalization_eval_test.py | 4 +- .../python/learning/reconstruction/BUILD | 6 +-- .../reconstruction/reconstruction_utils.py | 10 ++--- .../reconstruction_utils_test.py | 6 +-- .../reconstruction/training_process_test.py | 6 +-- .../python/learning/templates/BUILD | 20 +++++----- .../templates/apply_optimizer_finalizer.py | 16 ++++---- .../apply_optimizer_finalizer_test.py | 30 +++++++------- .../learning/templates/client_works_test.py | 4 +- .../python/learning/templates/composers.py | 4 +- .../learning/templates/composers_test.py | 18 +++++---- .../learning/templates/finalizers_test.py | 12 +++--- .../templates/model_delta_client_work.py | 10 ++--- .../templates/model_delta_client_work_test.py | 12 +++--- .../templates/proximal_client_work.py | 10 ++--- .../templates/proximal_client_work_test.py | 8 ++-- 41 files changed, 269 insertions(+), 237 deletions(-) rename tensorflow_federated/python/learning/{model_utils.py => models/model_weights.py} (100%) rename tensorflow_federated/python/learning/{model_utils_test.py => models/model_weights_test.py} (88%) diff --git a/tensorflow_federated/python/learning/BUILD b/tensorflow_federated/python/learning/BUILD index 657fc645a3..c9ab63f45d 100644 --- a/tensorflow_federated/python/learning/BUILD +++ b/tensorflow_federated/python/learning/BUILD @@ -32,13 +32,14 @@ py_library( ":keras_utils", ":model", ":model_update_aggregator", - ":model_utils", ":personalization_eval", + "//tensorflow_federated/python/common_libs:deprecation", "//tensorflow_federated/python/learning/algorithms", "//tensorflow_federated/python/learning/framework", "//tensorflow_federated/python/learning/framework:optimizer_utils", "//tensorflow_federated/python/learning/metrics", "//tensorflow_federated/python/learning/models", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers", "//tensorflow_federated/python/learning/reconstruction", "//tensorflow_federated/python/learning/templates", @@ -84,7 +85,6 @@ py_library( srcs_version = "PY3", deps = [ ":model", - ":model_utils", "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/core/impl/federated_context:federated_computation", "//tensorflow_federated/python/core/impl/federated_context:intrinsics", @@ -95,6 +95,7 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -107,7 +108,6 @@ py_cpu_gpu_test( ":federated_evaluation", ":keras_utils", ":model", - ":model_utils", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/impl/federated_context:federated_computation", "//tensorflow_federated/python/core/impl/federated_context:intrinsics", @@ -119,6 +119,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/tensorflow_libs:tensorflow_test_utils", ], ) @@ -150,7 +151,6 @@ py_test( ":keras_utils", ":model", ":model_examples", - ":model_utils", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/impl/computation:computation_base", "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation", @@ -158,6 +158,7 @@ py_test( "//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/metrics:counters", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -187,19 +188,6 @@ py_test( ], ) -py_library( - name = "model_utils", - srcs = ["model_utils.py"], - srcs_version = "PY3", - deps = [ - ":model", - "//tensorflow_federated/python/common_libs:py_typecheck", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", - "//tensorflow_federated/python/core/impl/types:type_conversions", - ], -) - py_library( name = "model_update_aggregator", srcs = ["model_update_aggregator.py"], @@ -248,7 +236,6 @@ py_library( srcs = ["personalization_eval.py"], srcs_version = "PY3", deps = [ - ":model_utils", "//tensorflow_federated/python/aggregators:sampling", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", @@ -257,19 +244,7 @@ py_library( "//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:placements", - ], -) - -py_test( - name = "model_utils_test", - srcs = ["model_utils_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [ - ":model", - ":model_utils", - "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/types:computation_types", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -282,10 +257,10 @@ py_cpu_gpu_test( deps = [ ":keras_utils", ":model_examples", - ":model_utils", ":personalization_eval", "//tensorflow_federated/python/core/backends/native:execution_contexts", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning/framework:dataset_reduce", + "//tensorflow_federated/python/learning/models:model_weights", ], ) diff --git a/tensorflow_federated/python/learning/__init__.py b/tensorflow_federated/python/learning/__init__.py index 4ca53a2398..f69392054c 100644 --- a/tensorflow_federated/python/learning/__init__.py +++ b/tensorflow_federated/python/learning/__init__.py @@ -35,6 +35,7 @@ `tff.learning.models` for related model classes. """ +from tensorflow_federated.python.common_libs import deprecation from tensorflow_federated.python.learning import algorithms from tensorflow_federated.python.learning import framework from tensorflow_federated.python.learning import metrics @@ -59,5 +60,5 @@ from tensorflow_federated.python.learning.model_update_aggregator import entropy_compression_aggregator from tensorflow_federated.python.learning.model_update_aggregator import robust_aggregator from tensorflow_federated.python.learning.model_update_aggregator import secure_aggregator -from tensorflow_federated.python.learning.model_utils import ModelWeights +from tensorflow_federated.python.learning.models.model_weights import ModelWeights from tensorflow_federated.python.learning.personalization_eval import build_personalization_eval diff --git a/tensorflow_federated/python/learning/algorithms/BUILD b/tensorflow_federated/python/learning/algorithms/BUILD index 88c235f620..c4f817e230 100644 --- a/tensorflow_federated/python/learning/algorithms/BUILD +++ b/tensorflow_federated/python/learning/algorithms/BUILD @@ -43,9 +43,9 @@ py_library( "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer", "//tensorflow_federated/python/learning/templates:composers", @@ -94,8 +94,8 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer", "//tensorflow_federated/python/learning/templates:client_works", @@ -138,9 +138,9 @@ py_library( "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer", "//tensorflow_federated/python/learning/templates:composers", @@ -165,9 +165,9 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:model_examples", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", "//tensorflow_federated/python/learning/templates:distributors", @@ -191,10 +191,10 @@ py_library( "//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer", "//tensorflow_federated/python/learning/templates:client_works", @@ -217,10 +217,10 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/test:static_assert", "//tensorflow_federated/python/learning:model_examples", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/tensorflow_libs:tensorflow_test_utils", ], @@ -279,10 +279,10 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/optimizers:sgdm", "//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer", @@ -319,11 +319,11 @@ py_cpu_gpu_test( "//tensorflow_federated/python/learning:keras_utils", "//tensorflow_federated/python/learning:model_examples", "//tensorflow_federated/python/learning:model_update_aggregator", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/metrics:counters", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:adagrad", "//tensorflow_federated/python/learning/optimizers:adam", @@ -354,8 +354,8 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:federated_evaluation", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregation_factory", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/templates:client_works", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", @@ -383,9 +383,9 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregation_factory", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/templates:composers", "//tensorflow_federated/python/learning/templates:distributors", "//tensorflow_federated/python/learning/templates:learning_process", diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg.py b/tensorflow_federated/python/learning/algorithms/fed_avg.py index edc36a35aa..211a3db81d 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg.py @@ -46,9 +46,9 @@ from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import composers @@ -172,7 +172,7 @@ def build_weighted_fed_avg( @tensorflow_computation.tf_computation() def initial_model_weights_fn(): trainable_weights, non_trainable_weights = model_fn.initial_weights - return model_utils.ModelWeights( + return model_weights.ModelWeights( tuple(tf.convert_to_tensor(w) for w in trainable_weights), tuple(tf.convert_to_tensor(w) for w in non_trainable_weights)) @@ -186,7 +186,7 @@ def initial_model_weights_fn(): raise TypeError('When `model_fn` is a callable, it return instances of ' 'tff.learning.Model. Instead callable returned type: ' f'{type(model)}') - return model_utils.ModelWeights.from_model(model) + return model_weights.ModelWeights.from_model(model) model_weights_type = initial_model_weights_fn.type_signature.result diff --git a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py index bbffebdbea..b02f26be86 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py +++ b/tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py @@ -36,9 +36,9 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.algorithms import fed_avg from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import client_works @@ -101,7 +101,7 @@ def build_scheduled_client_work( metrics_aggregation_fn = metrics_aggregator( whimsy_model.metric_finalizers(), unfinalized_metrics_type) data_type = computation_types.SequenceType(whimsy_model.input_spec) - weights_type = model_utils.weights_type_from_model(whimsy_model) + weights_type = model_weights.weights_type_from_model(whimsy_model) if isinstance(whimsy_optimizer, optimizer_base.Optimizer): build_client_update_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer @@ -249,7 +249,7 @@ def build_weighted_fed_avg_with_optimizer_schedule( @tensorflow_computation.tf_computation() def initial_model_weights_fn(): - return model_utils.ModelWeights.from_model(model_fn()) + return model_weights.ModelWeights.from_model(model_fn()) model_weights_type = initial_model_weights_fn.type_signature.result diff --git a/tensorflow_federated/python/learning/algorithms/fed_eval.py b/tensorflow_federated/python/learning/algorithms/fed_eval.py index 52a0bceef6..c21388089e 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_eval.py +++ b/tensorflow_federated/python/learning/algorithms/fed_eval.py @@ -35,8 +35,8 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import federated_evaluation from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import aggregation_factory +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import distributors @@ -198,7 +198,7 @@ def build_fed_eval( @tensorflow_computation.tf_computation() def initial_model_weights_fn(): - return model_utils.ModelWeights.from_model(model_fn()) + return model_weights_lib.ModelWeights.from_model(model_fn()) model_weights_type = initial_model_weights_fn.type_signature.result diff --git a/tensorflow_federated/python/learning/algorithms/fed_eval_test.py b/tensorflow_federated/python/learning/algorithms/fed_eval_test.py index ff689359e0..0ce34852e9 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_eval_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_eval_test.py @@ -31,10 +31,10 @@ from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import model -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.algorithms import fed_eval from tensorflow_federated.python.learning.metrics import aggregation_factory from tensorflow_federated.python.learning.metrics import aggregator +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import distributors from tensorflow_federated.python.learning.templates import learning_process @@ -171,7 +171,7 @@ class FedEvalProcessTest(tf.test.TestCase): def test_fed_eval_process_type_properties(self): model_fn = TestModel test_model = model_fn() - model_weights_type = model_utils.weights_type_from_model(test_model) + model_weights_type = model_weights_lib.weights_type_from_model(test_model) metric_finalizers = test_model.metric_finalizers() unfinalized_metrics = test_model.report_local_unfinalized_metrics() local_unfinalized_metrics_type = type_conversions.type_from_tensors( @@ -239,9 +239,11 @@ def test_fed_eval_process_execution(self): # Update the state with the model weights to be evaluated, and verify that # the `get_model_weights` method returns the same model weights. state = eval_process.initialize() - model_weights = model_utils.ModelWeights(trainable=[5.0], non_trainable=[]) + model_weights = model_weights_lib.ModelWeights( + trainable=[5.0], non_trainable=[]) new_state = eval_process.set_model_weights( - state, model_utils.ModelWeights(trainable=[5.0], non_trainable=[])) + state, + model_weights_lib.ModelWeights(trainable=[5.0], non_trainable=[])) tf.nest.map_structure(self.assertAllEqual, model_weights, eval_process.get_model_weights(new_state)) @@ -269,7 +271,7 @@ def _temp_dict(temps): total_rounds_metrics=collections.OrderedDict(num_over=9.0)))) def test_fed_eval_with_model_distributor(self): - model_weights_type = model_utils.weights_type_from_model(TestModel) + model_weights_type = model_weights_lib.weights_type_from_model(TestModel) def test_distributor(): diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox.py b/tensorflow_federated/python/learning/algorithms/fed_prox.py index d7d85b8c65..89d54708bd 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox.py @@ -38,9 +38,9 @@ from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import composers @@ -175,7 +175,7 @@ def build_weighted_fed_prox( @tensorflow_computation.tf_computation() def initial_model_weights_fn(): trainable_weights, non_trainable_weights = model_fn.initial_weights - return model_utils.ModelWeights( + return model_weights.ModelWeights( tuple(tf.convert_to_tensor(w) for w in trainable_weights), tuple(tf.convert_to_tensor(w) for w in non_trainable_weights)) @@ -189,7 +189,7 @@ def initial_model_weights_fn(): raise TypeError('When `model_fn` is a callable, it returns instances of' ' tff.learning.Model. Instead callable returned type: ' f'{type(model)}') - return model_utils.ModelWeights.from_model(model) + return model_weights.ModelWeights.from_model(model) model_weights_type = initial_model_weights_fn.type_signature.result if model_distributor is None: diff --git a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py index 7d3ee77b02..da598c1e95 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_prox_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_prox_test.py @@ -25,10 +25,10 @@ from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import model_examples from tensorflow_federated.python.learning import model_update_aggregator -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.algorithms import fed_prox from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import test_models from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import distributors @@ -137,7 +137,8 @@ def test_raises_on_negative_proximal_strength(self): def test_raises_on_invalid_distributor(self): model_weights_type = type_conversions.type_from_tensors( - model_utils.ModelWeights.from_model(model_examples.LinearRegression())) + model_weights.ModelWeights.from_model( + model_examples.LinearRegression())) distributor = distributors.build_broadcast_process(model_weights_type) invalid_distributor = iterative_process.IterativeProcess( distributor.initialize, distributor.next) diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd.py b/tensorflow_federated/python/learning/algorithms/fed_sgd.py index d69f5fa55b..2b85774c69 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd.py @@ -43,10 +43,10 @@ from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import client_works @@ -73,7 +73,7 @@ def _build_client_update(model: model_lib.Model, @tf.function def client_update(initial_weights, dataset): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -161,7 +161,7 @@ def _build_fed_sgd_client_work( metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(), unfinalized_metrics_type) data_type = computation_types.SequenceType(model.input_spec) - weights_type = model_utils.weights_type_from_model(model) + weights_type = model_weights_lib.weights_type_from_model(model) @federated_computation.federated_computation def init_fn(): @@ -293,7 +293,7 @@ def ndarray_to_tensorspec(ndarray): # Wrap in a `ModelWeights` structure that is required by the `finalizer.` trainable_weights, non_trainable_weights = model.initial_weights - weights_type = model_utils.ModelWeights( + weights_type = model_weights_lib.ModelWeights( tuple(ndarray_to_tensorspec(w) for w in trainable_weights), tuple(ndarray_to_tensorspec(w) for w in non_trainable_weights)) @@ -411,7 +411,7 @@ def build_fed_sgd( @tensorflow_computation.tf_computation() def initial_model_weights_fn(): trainable_weights, non_trainable_weights = model_fn.initial_weights - return model_utils.ModelWeights( + return model_weights_lib.ModelWeights( tuple(tf.convert_to_tensor(w) for w in trainable_weights), tuple(tf.convert_to_tensor(w) for w in non_trainable_weights)) @@ -425,7 +425,7 @@ def initial_model_weights_fn(): raise TypeError('When `model_fn` is a callable, it returns instances of' ' tff.learning.Model. Instead callable returned type: ' f'{type(model)}') - return model_utils.ModelWeights.from_model(model) + return model_weights_lib.ModelWeights.from_model(model) model_weights_type = initial_model_weights_fn.type_signature.result diff --git a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py index 878dd8accf..2ad0eb0aca 100644 --- a/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py +++ b/tensorflow_federated/python/learning/algorithms/fed_sgd_test.py @@ -22,11 +22,11 @@ from tensorflow_federated.python.core.test import static_assert from tensorflow_federated.python.learning import model_examples from tensorflow_federated.python.learning import model_update_aggregator -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.algorithms import fed_sgd from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import test_models from tensorflow_federated.python.tensorflow_libs import tensorflow_test_utils @@ -52,8 +52,8 @@ def _build_functional_model() -> functional.FunctionalModel: return test_models.build_functional_linear_regression(feature_dim=2) -def _initial_weights() -> model_utils.ModelWeights: - return model_utils.ModelWeights( +def _initial_weights() -> model_weights.ModelWeights: + return model_weights.ModelWeights( trainable=[tf.constant([[0.0], [0.0]]), tf.constant(0.0)], non_trainable=[0.0]) diff --git a/tensorflow_federated/python/learning/algorithms/mime.py b/tensorflow_federated/python/learning/algorithms/mime.py index feec1a63ab..1ad1d7e3f0 100644 --- a/tensorflow_federated/python/learning/algorithms/mime.py +++ b/tensorflow_federated/python/learning/algorithms/mime.py @@ -46,10 +46,10 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer @@ -85,11 +85,11 @@ def client_update_fn(global_optimizer_state, initial_weights, data): dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn( use_experimental_simulation_loop) weight_tensor_specs = type_conversions.type_to_tf_tensor_specs( - model_utils.weights_type_from_model(model)) + model_weights_lib.weights_type_from_model(model)) @tf.function def client_update(global_optimizer_state, initial_weights, data): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -227,7 +227,7 @@ def _build_mime_lite_client_work( metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(), unfinalized_metrics_type) data_type = computation_types.SequenceType(model.input_spec) - weights_type = model_utils.weights_type_from_model(model) + weights_type = model_weights_lib.weights_type_from_model(model) weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(weights_type) full_gradient_aggregator = full_gradient_aggregator.create( @@ -299,7 +299,7 @@ def client_update_fn(global_optimizer_state, incoming_weights, data): @tf.function def client_update(global_optimizer_state: Any, - incoming_weights: model_utils.ModelWeights, + incoming_weights: model_weights_lib.ModelWeights, data: tf.data.Dataset) -> Any: trainable_weights, _ = incoming_weights @@ -380,8 +380,8 @@ def initial_training_weights(): update=client_weights_delta, update_weight=client_weight), unfinalized_metrics, full_gradient - # Convert `tff.learning.ModelWeights` type weights back into the initial - # shape used by the model. + # Convert `tff.learning.models.ModelWeights` type weights back into the + # initial shape used by the model. incoming_weights = (incoming_weights.trainable, incoming_weights.non_trainable) return client_update(global_optimizer_state, incoming_weights, data) @@ -440,7 +440,7 @@ def _build_mime_lite_functional_client_work( data_type = computation_types.SequenceType(model.input_spec) trainable_weights, non_trainable_weights = model.initial_weights weights_type = type_conversions.infer_type( - model_utils.ModelWeights( + model_weights_lib.ModelWeights( tuple(trainable_weights), tuple(non_trainable_weights))) weight_tensor_specs = type_conversions.type_to_tf_tensor_specs(weights_type) @@ -726,7 +726,7 @@ def build_weighted_mime_lite( @tensorflow_computation.tf_computation def initial_model_weights_fn(): trainable_weights, non_trainable_weights = model_fn.initial_weights - return model_utils.ModelWeights( + return model_weights_lib.ModelWeights( tuple(tf.convert_to_tensor(w) for w in trainable_weights), tuple(tf.convert_to_tensor(w) for w in non_trainable_weights)) @@ -739,7 +739,7 @@ def initial_model_weights_fn(): raise TypeError('When `model_fn` is a callable, it returns instances of' ' tff.learning.Model. Instead callable returned type: ' f'{type(model)}') - return model_utils.ModelWeights.from_model(model) + return model_weights_lib.ModelWeights.from_model(model) model_weights_type = initial_model_weights_fn.type_signature.result if model_distributor is None: @@ -1029,7 +1029,7 @@ def build_mime_lite_with_optimizer_schedule( @tensorflow_computation.tf_computation def initial_model_weights_fn(): trainable_weights, non_trainable_weights = model_fn.initial_weights - return model_utils.ModelWeights( + return model_weights_lib.ModelWeights( tuple(tf.convert_to_tensor(w) for w in trainable_weights), tuple(tf.convert_to_tensor(w) for w in non_trainable_weights)) @@ -1042,7 +1042,7 @@ def initial_model_weights_fn(): raise TypeError('When `model_fn` is a callable, it returns instances of' ' tff.learning.Model. Instead callable returned type: ' f'{type(model)}') - return model_utils.ModelWeights.from_model(model) + return model_weights_lib.ModelWeights.from_model(model) model_weights_type = initial_model_weights_fn.type_signature.result if model_distributor is None: diff --git a/tensorflow_federated/python/learning/algorithms/mime_test.py b/tensorflow_federated/python/learning/algorithms/mime_test.py index c61dd810fd..6e3b834bf6 100644 --- a/tensorflow_federated/python/learning/algorithms/mime_test.py +++ b/tensorflow_federated/python/learning/algorithms/mime_test.py @@ -34,13 +34,13 @@ from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model_examples from tensorflow_federated.python.learning import model_update_aggregator -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.algorithms import fed_avg from tensorflow_federated.python.learning.algorithms import mime from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator as metrics_aggregator from tensorflow_federated.python.learning.metrics import counters from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import test_models from tensorflow_federated.python.learning.optimizers import adagrad from tensorflow_federated.python.learning.optimizers import adam @@ -66,7 +66,7 @@ def test_type_properties(self, weighting): model_fn, optimizer, weighting) self.assertIsInstance(client_work_process, client_works.ClientWorkProcess) - mw_type = model_utils.ModelWeights( + mw_type = model_weights.ModelWeights( trainable=computation_types.to_type([(tf.float32, (2, 1)), tf.float32]), non_trainable=computation_types.to_type([tf.float32])) expected_param_model_weights_type = computation_types.at_clients(mw_type) @@ -133,7 +133,7 @@ def _create_model(): def _initial_weights(): - return model_utils.ModelWeights( + return model_weights.ModelWeights( trainable=[tf.zeros((2, 1)), tf.constant(0.0)], non_trainable=[0.0]) @@ -375,7 +375,8 @@ def test_raises_on_invalid_client_weighting(self): def test_raises_on_invalid_distributor(self): model_weights_type = type_conversions.type_from_tensors( - model_utils.ModelWeights.from_model(model_examples.LinearRegression())) + model_weights.ModelWeights.from_model( + model_examples.LinearRegression())) distributor = distributors.build_broadcast_process(model_weights_type) invalid_distributor = iterative_process.IterativeProcess( distributor.initialize, distributor.next) @@ -511,7 +512,8 @@ def test_raises_on_invalid_client_weighting(self): def test_raises_on_invalid_distributor(self): model_weights_type = type_conversions.type_from_tensors( - model_utils.ModelWeights.from_model(model_examples.LinearRegression())) + model_weights.ModelWeights.from_model( + model_examples.LinearRegression())) distributor = distributors.build_broadcast_process(model_weights_type) invalid_distributor = iterative_process.IterativeProcess( distributor.initialize, distributor.next) diff --git a/tensorflow_federated/python/learning/federated_evaluation.py b/tensorflow_federated/python/learning/federated_evaluation.py index e62585aa2a..38fd5aab3b 100644 --- a/tensorflow_federated/python/learning/federated_evaluation.py +++ b/tensorflow_federated/python/learning/federated_evaluation.py @@ -32,9 +32,9 @@ from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib # Convenience aliases. SequenceType = computation_types.SequenceType @@ -81,7 +81,7 @@ def client_eval(incoming_model_weights, dataset): """Returns local outputs after evaluting `model_weights` on `dataset`.""" with tf.init_scope(): model = model_fn() - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda v, t: v.assign(t), model_weights, incoming_model_weights) @@ -158,7 +158,7 @@ def build_federated_evaluation( # with some other mechanism. with tf.Graph().as_default(): model = model_fn() - model_weights_type = model_utils.weights_type_from_model(model) + model_weights_type = model_weights_lib.weights_type_from_model(model) batch_type = computation_types.to_type(model.input_spec) unfinalized_metrics_type = type_conversions.type_from_tensors( model.report_local_unfinalized_metrics()) diff --git a/tensorflow_federated/python/learning/federated_evaluation_test.py b/tensorflow_federated/python/learning/federated_evaluation_test.py index 0d11371b56..ebed9cd44b 100644 --- a/tensorflow_federated/python/learning/federated_evaluation_test.py +++ b/tensorflow_federated/python/learning/federated_evaluation_test.py @@ -32,9 +32,9 @@ from tensorflow_federated.python.learning import federated_evaluation from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.tensorflow_libs import tensorflow_test_utils from tensorflow_model_optimization.python.core.internal import tensor_encoding as te @@ -209,9 +209,9 @@ def _build_expected_broadcaster_next_signature(): (), ]), ('non_trainable', [])])) value_type = computation_types.at_server( - model_utils.weights_type_from_model(TestModelQuant)) + model_weights.weights_type_from_model(TestModelQuant)) result_type = computation_types.at_clients( - model_utils.weights_type_from_model(TestModelQuant)) + model_weights.weights_type_from_model(TestModelQuant)) measurements_type = computation_types.at_server(()) return computation_types.FunctionType( parameter=collections.OrderedDict(state=state_type, value=value_type), @@ -222,7 +222,7 @@ def _build_expected_broadcaster_next_signature(): def _build_expected_test_quant_model_eval_signature(): """Returns signature for build_federated_evaluation using TestModelQuant.""" weights_parameter_type = computation_types.at_server( - model_utils.weights_type_from_model(TestModelQuant)) + model_weights.weights_type_from_model(TestModelQuant)) data_parameter_type = computation_types.at_clients( computation_types.SequenceType( collections.OrderedDict( @@ -242,7 +242,7 @@ class FederatedEvaluationTest(parameterized.TestCase): @tensorflow_test_utils.skip_test_for_multi_gpu def test_local_evaluation(self): - model_weights_type = model_utils.weights_type_from_model(TestModel) + model_weights_type = model_weights.weights_type_from_model(TestModel) batch_type = computation_types.to_type(TestModel().input_spec) client_evaluate = federated_evaluation.build_local_evaluation( TestModel, model_weights_type, batch_type) @@ -277,7 +277,7 @@ def _temp_dict(temps): @tensorflow_test_utils.skip_test_for_multi_gpu def test_federated_evaluation(self): evaluate = federated_evaluation.build_federated_evaluation(TestModel) - model_weights_type = model_utils.weights_type_from_model(TestModel) + model_weights_type = model_weights.weights_type_from_model(TestModel) type_test_utils.assert_types_equivalent( evaluate.type_signature, FunctionType( @@ -347,7 +347,7 @@ def test_federated_evaluation_with_keras(self, simulation): _model_fn_from_keras, use_experimental_simulation_loop=simulation) initial_weights = tf.nest.map_structure( lambda x: x.read_value(), - model_utils.ModelWeights.from_model(_model_fn_from_keras())) + model_weights.ModelWeights.from_model(_model_fn_from_keras())) def _input_dict(temps): return collections.OrderedDict( @@ -377,7 +377,7 @@ def test_federated_evaluation_dataset_reduce(self, mock_method): _model_fn_from_keras, use_experimental_simulation_loop=False) initial_weights = tf.nest.map_structure( lambda x: x.read_value(), - model_utils.ModelWeights.from_model(_model_fn_from_keras())) + model_weights.ModelWeights.from_model(_model_fn_from_keras())) def _input_dict(temps): return collections.OrderedDict( @@ -402,7 +402,7 @@ def test_federated_evaluation_simulation_loop(self, mock_method): _model_fn_from_keras, use_experimental_simulation_loop=True) initial_weights = tf.nest.map_structure( lambda x: x.read_value(), - model_utils.ModelWeights.from_model(_model_fn_from_keras())) + model_weights.ModelWeights.from_model(_model_fn_from_keras())) def _input_dict(temps): return collections.OrderedDict( diff --git a/tensorflow_federated/python/learning/framework/BUILD b/tensorflow_federated/python/learning/framework/BUILD index a20ee364be..ae07635b44 100644 --- a/tensorflow_federated/python/learning/framework/BUILD +++ b/tensorflow_federated/python/learning/framework/BUILD @@ -27,7 +27,8 @@ py_library( visibility = ["//tensorflow_federated/python/learning:__pkg__"], deps = [ ":optimizer_utils", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/common_libs:deprecation", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -63,8 +64,8 @@ py_library( "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:aggregator", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:keras_optimizer", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/tensorflow_libs:tensor_utils", @@ -95,7 +96,7 @@ py_cpu_gpu_test( "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:model_examples", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/optimizers:sgdm", "//tensorflow_federated/python/tensorflow_libs:tensorflow_test_utils", diff --git a/tensorflow_federated/python/learning/framework/__init__.py b/tensorflow_federated/python/learning/framework/__init__.py index 28654754e1..f12cbbc04a 100644 --- a/tensorflow_federated/python/learning/framework/__init__.py +++ b/tensorflow_federated/python/learning/framework/__init__.py @@ -13,10 +13,16 @@ # limitations under the License. """Libraries for developing federated learning algorithms.""" +from tensorflow_federated.python.common_libs import deprecation from tensorflow_federated.python.learning.framework.optimizer_utils import build_model_delta_optimizer_process from tensorflow_federated.python.learning.framework.optimizer_utils import build_stateless_broadcaster from tensorflow_federated.python.learning.framework.optimizer_utils import ClientDeltaFn from tensorflow_federated.python.learning.framework.optimizer_utils import ClientOutput from tensorflow_federated.python.learning.framework.optimizer_utils import ServerState -from tensorflow_federated.python.learning.model_utils import ModelWeights -from tensorflow_federated.python.learning.model_utils import weights_type_from_model +from tensorflow_federated.python.learning.models.model_weights import ModelWeights +from tensorflow_federated.python.learning.models.model_weights import weights_type_from_model + +weights_type_from_model = deprecation.deprecated( + weights_type_from_model, + '`tff.learning.framework.weights_type_from_model` is deprecated, use ' + '`tff.learning.models.weights_type_from_model`.') diff --git a/tensorflow_federated/python/learning/framework/optimizer_utils.py b/tensorflow_federated/python/learning/framework/optimizer_utils.py index b9df029766..e46e49446e 100644 --- a/tensorflow_federated/python/learning/framework/optimizer_utils.py +++ b/tensorflow_federated/python/learning/framework/optimizer_utils.py @@ -40,8 +40,8 @@ from tensorflow_federated.python.core.templates import iterative_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import aggregator +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import keras_optimizer from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.tensorflow_libs import tensor_utils @@ -200,7 +200,7 @@ def assert_weight_lists_match(old_value, new_value): assert_weight_lists_match(server_state.model.non_trainable, non_trainable_weights) new_server_state = ServerState( - model=model_utils.ModelWeights( + model=model_weights_lib.ModelWeights( trainable=trainable_weights, non_trainable=non_trainable_weights), optimizer_state=server_state.optimizer_state, delta_aggregate_state=server_state.delta_aggregate_state, @@ -211,7 +211,7 @@ def assert_weight_lists_match(old_value, new_value): def _apply_delta( *, optimizer: tf.keras.optimizers.Optimizer, - model_variables: model_utils.ModelWeights, + model_variables: model_weights_lib.ModelWeights, delta, ) -> None: """Applies `delta` to `model` using `optimizer`.""" @@ -319,7 +319,7 @@ def _build_one_round_computation( # should re-evaluate what happens here. with tf.Graph().as_default(): whimsy_model_for_metadata = model_fn() - model_weights = model_utils.ModelWeights.from_model( + model_weights = model_weights_lib.ModelWeights.from_model( whimsy_model_for_metadata) model_weights_type = type_conversions.type_from_tensors(model_weights) @@ -588,9 +588,9 @@ def build_model_delta_optimizer_process( @tensorflow_computation.tf_computation def model_and_optimizer_init_fn( - ) -> Tuple[model_utils.ModelWeights, List[tf.Variable]]: + ) -> Tuple[model_weights_lib.ModelWeights, List[tf.Variable]]: """Returns initial model weights and state of the global optimizer.""" - model_variables = model_utils.ModelWeights.from_model(model_fn()) + model_variables = model_weights_lib.ModelWeights.from_model(model_fn()) optimizer = keras_optimizer.build_or_verify_tff_optimizer( server_optimizer_fn, model_variables.trainable, diff --git a/tensorflow_federated/python/learning/framework/optimizer_utils_test.py b/tensorflow_federated/python/learning/framework/optimizer_utils_test.py index 7ef23f1dad..79c398ab4c 100644 --- a/tensorflow_federated/python/learning/framework/optimizer_utils_test.py +++ b/tensorflow_federated/python/learning/framework/optimizer_utils_test.py @@ -35,8 +35,8 @@ from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import optimizer_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.tensorflow_libs import tensorflow_test_utils @@ -144,7 +144,7 @@ def test_state_with_model_weights_success(self): new_non_trainable = [np.array(3), b'bytes check', 6, 3.0] state = optimizer_utils.ServerState( - model=model_utils.ModelWeights( + model=model_weights.ModelWeights( trainable=trainable, non_trainable=non_trainable), optimizer_state=[], delta_aggregate_state=tf.constant(0), @@ -174,7 +174,7 @@ def test_state_with_new_model_weights_failure(self, new_trainable, trainable = [np.array([1.0, 2.0]), np.array([[1.0]]), np.int64(3)] non_trainable = [np.array(1), b'bytes type', 5, 2.0] state = optimizer_utils.ServerState( - model=model_utils.ModelWeights( + model=model_weights.ModelWeights( trainable=trainable, non_trainable=non_trainable), optimizer_state=[], delta_aggregate_state=tf.constant(0), @@ -285,7 +285,7 @@ def test_construction(self, weighted): server_state_type = computation_types.FederatedType( optimizer_utils.ServerState( - model=model_utils.ModelWeights( + model=model_weights.ModelWeights( trainable=[ computation_types.TensorType(tf.float32, [2, 1]), computation_types.TensorType(tf.float32) @@ -344,20 +344,20 @@ def test_initial_weights_pulled_from_model(self, server_optimizer): def _model_fn_with_zero_weights(): linear_regression_model = model_examples.LinearRegression - weights = model_utils.ModelWeights.from_model(linear_regression_model) + weights = model_weights.ModelWeights.from_model(linear_regression_model) zero_trainable = [tf.zeros_like(x) for x in weights.trainable] zero_non_trainable = [tf.zeros_like(x) for x in weights.non_trainable] - zero_weights = model_utils.ModelWeights( + zero_weights = model_weights.ModelWeights( trainable=zero_trainable, non_trainable=zero_non_trainable) zero_weights.assign_weights_to(linear_regression_model) return linear_regression_model def _model_fn_with_one_weights(): linear_regression_model = model_examples.LinearRegression - weights = model_utils.ModelWeights.from_model(linear_regression_model) + weights = model_weights.ModelWeights.from_model(linear_regression_model) ones_trainable = [tf.ones_like(x) for x in weights.trainable] ones_non_trainable = [tf.ones_like(x) for x in weights.non_trainable] - ones_weights = model_utils.ModelWeights( + ones_weights = model_weights.ModelWeights( trainable=ones_trainable, non_trainable=ones_non_trainable) ones_weights.assign_weights_to(linear_regression_model) return linear_regression_model @@ -439,7 +439,7 @@ def test_construction_with_tff_sgdm_optimizer(self, momentum, state_len): ('keras_optimizer', _keras_optimizer_fn), ]) def test_construction_with_aggregation_process(self, server_optimizer): - model_update_type = model_utils.weights_type_from_model( + model_update_type = model_weights.weights_type_from_model( model_examples.LinearRegression).trainable model_update_aggregator = TestMeasuredMeanFactory() iterative_process = optimizer_utils.build_model_delta_optimizer_process( @@ -476,7 +476,7 @@ def test_construction_with_aggregation_process(self, server_optimizer): @parameterized.named_parameters([('tff_optimizer', _tff_optimizer), ('keras_optimizer', _keras_optimizer_fn)]) def test_construction_with_broadcast_process(self, server_optimizer): - model_weights_type = model_utils.weights_type_from_model( + model_weights_type = model_weights.weights_type_from_model( model_examples.LinearRegression) broadcast_process = _build_test_measured_broadcast(model_weights_type) iterative_process = optimizer_utils.build_model_delta_optimizer_process( @@ -506,7 +506,7 @@ def test_construction_with_broadcast_process(self, server_optimizer): ('keras_optimizer', _keras_optimizer_fn)]) @tensorflow_test_utils.skip_test_for_multi_gpu def test_orchestration_execute_measured_process(self, server_optimizer): - model_weights_type = model_utils.weights_type_from_model( + model_weights_type = model_weights.weights_type_from_model( model_examples.LinearRegression) learning_rate = 1.0 server_optimizer_fn = server_optimizer(learning_rate) @@ -569,7 +569,7 @@ def test_orchestration_execute_measured_process(self, server_optimizer): @tensorflow_test_utils.skip_test_for_multi_gpu def test_execute_measured_process_with_custom_metrics_aggregator( self, server_optimizer): - model_weights_type = model_utils.weights_type_from_model( + model_weights_type = model_weights.weights_type_from_model( model_examples.LinearRegression) learning_rate = 1.0 server_optimizer_fn = server_optimizer(learning_rate) diff --git a/tensorflow_federated/python/learning/keras_utils_test.py b/tensorflow_federated/python/learning/keras_utils_test.py index 856ba74b1a..25d917d5ce 100644 --- a/tensorflow_federated/python/learning/keras_utils_test.py +++ b/tensorflow_federated/python/learning/keras_utils_test.py @@ -31,9 +31,9 @@ from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model as model_lib from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.metrics import counters +from tensorflow_federated.python.learning.models import model_weights def _create_whimsy_types(feature_dims): @@ -486,7 +486,7 @@ def test_keras_model_multiple_inputs(self): self.assertEqual(m['mean_absolute_error'][1], 4) # Ensure we can assign the FL trained model weights to a new model. - tff_weights = model_utils.ModelWeights.from_model(tff_model) + tff_weights = model_weights.ModelWeights.from_model(tff_model) keras_model = model_examples.build_multiple_inputs_keras_model() tff_weights.assign_weights_to(keras_model) loaded_model = keras_utils.from_keras_model( @@ -538,7 +538,7 @@ def test_keras_model_using_batch_norm_gets_warning(self): self.assertEqual(m['mean_absolute_error'][1], 4) # Ensure we can assign the FL trained model weights to a new model. - tff_weights = model_utils.ModelWeights.from_model(tff_model) + tff_weights = model_weights.ModelWeights.from_model(tff_model) keras_model = model_examples.build_conv_batch_norm_keras_model() tff_weights.assign_weights_to(keras_model) @@ -587,7 +587,7 @@ def _train_loop(): optimizer.apply_gradients( zip(gradients, tff_model.trainable_variables)) return (tff_model.report_local_unfinalized_metrics(), - model_utils.ModelWeights.from_model(tff_model)) + model_weights.ModelWeights.from_model(tff_model)) return _train_loop() @@ -935,7 +935,7 @@ def test_keras_model_lookup_table(self): self.assertEqual(metrics['mean_absolute_error'][1], 6) # Ensure we can assign the FL trained model weights to a new model. - tff_weights = model_utils.ModelWeights.from_model(tff_model) + tff_weights = model_weights.ModelWeights.from_model(tff_model) keras_model = model_examples.build_lookup_table_keras_model() tff_weights.assign_weights_to(keras_model) loaded_model = keras_utils.from_keras_model( @@ -978,7 +978,7 @@ def test_keras_model_preprocessing(self): self.assertEqual(metrics['mean_absolute_error'][1], 2) # Ensure we can assign the FL trained model weights to a new model. - tff_weights = model_utils.ModelWeights.from_model(tff_model) + tff_weights = model_weights.ModelWeights.from_model(tff_model) keras_model = model_examples.build_lookup_table_keras_model() tff_weights.assign_weights_to(keras_model) loaded_model = keras_utils.from_keras_model( diff --git a/tensorflow_federated/python/learning/models/BUILD b/tensorflow_federated/python/learning/models/BUILD index 5f6b62a31d..25771942fd 100644 --- a/tensorflow_federated/python/learning/models/BUILD +++ b/tensorflow_federated/python/learning/models/BUILD @@ -14,6 +14,30 @@ package_group( licenses(["notice"]) +py_library( + name = "model_weights", + srcs = ["model_weights.py"], + srcs_version = "PY3", + visibility = [ + ":models_packages", + "//tensorflow_federated/python/learning:__pkg__", + "//tensorflow_federated/python/learning:learning_users", + "//tensorflow_federated/python/learning/algorithms:__pkg__", + "//tensorflow_federated/python/learning/algorithms:algorithms_packages", + "//tensorflow_federated/python/learning/framework:__pkg__", + "//tensorflow_federated/python/learning/reconstruction:__pkg__", + "//tensorflow_federated/python/learning/templates:__pkg__", + "//tensorflow_federated/python/learning/templates:templates_packages", + ], + deps = [ + "//tensorflow_federated/python/common_libs:py_typecheck", + "//tensorflow_federated/python/common_libs:structure", + "//tensorflow_federated/python/core/impl/types:computation_types", + "//tensorflow_federated/python/core/impl/types:type_conversions", + "//tensorflow_federated/python/learning:model", + ], +) + py_library( name = "models", srcs = ["__init__.py"], @@ -21,6 +45,7 @@ py_library( visibility = ["//tensorflow_federated/python/learning:learning_packages"], deps = [ ":functional", + ":model_weights", ":serialization", ], ) @@ -72,6 +97,19 @@ py_library( ], ) +py_test( + name = "model_weights_test", + srcs = ["model_weights_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":model_weights", + "//tensorflow_federated/python/common_libs:structure", + "//tensorflow_federated/python/core/impl/types:computation_types", + "//tensorflow_federated/python/learning:model", + ], +) + py_test( name = "serialization_test", srcs = ["serialization_test.py"], diff --git a/tensorflow_federated/python/learning/models/__init__.py b/tensorflow_federated/python/learning/models/__init__.py index 5f1f57930a..f0b55d6756 100644 --- a/tensorflow_federated/python/learning/models/__init__.py +++ b/tensorflow_federated/python/learning/models/__init__.py @@ -16,6 +16,8 @@ from tensorflow_federated.python.learning.models.functional import functional_model_from_keras from tensorflow_federated.python.learning.models.functional import FunctionalModel from tensorflow_federated.python.learning.models.functional import model_from_functional +from tensorflow_federated.python.learning.models.model_weights import ModelWeights +from tensorflow_federated.python.learning.models.model_weights import weights_type_from_model from tensorflow_federated.python.learning.models.serialization import load from tensorflow_federated.python.learning.models.serialization import load_functional_model from tensorflow_federated.python.learning.models.serialization import save diff --git a/tensorflow_federated/python/learning/model_utils.py b/tensorflow_federated/python/learning/models/model_weights.py similarity index 100% rename from tensorflow_federated/python/learning/model_utils.py rename to tensorflow_federated/python/learning/models/model_weights.py diff --git a/tensorflow_federated/python/learning/model_utils_test.py b/tensorflow_federated/python/learning/models/model_weights_test.py similarity index 88% rename from tensorflow_federated/python/learning/model_utils_test.py rename to tensorflow_federated/python/learning/models/model_weights_test.py index 84d1295405..ff9e49d06c 100644 --- a/tensorflow_federated/python/learning/model_utils_test.py +++ b/tensorflow_federated/python/learning/models/model_weights_test.py @@ -21,7 +21,7 @@ from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights class TestModel(model_lib.Model): @@ -79,13 +79,13 @@ class WeightsTypeFromModelTest(absltest.TestCase): def test_fails_not_callable_or_model(self): with self.assertRaises(TypeError): - model_utils.weights_type_from_model(0) + model_weights.weights_type_from_model(0) with self.assertRaises(TypeError): - model_utils.weights_type_from_model(lambda: 0) + model_weights.weights_type_from_model(lambda: 0) def test_returns_model_weights_for_model(self): model = TestModel() - weights_type = model_utils.weights_type_from_model(model) + weights_type = model_weights.weights_type_from_model(model) self.assertEqual( computation_types.StructWithPythonType( [('trainable', @@ -96,10 +96,10 @@ def test_returns_model_weights_for_model(self): ('non_trainable', computation_types.StructWithPythonType([ computation_types.TensorType(tf.int32), - ], list))], model_utils.ModelWeights), weights_type) + ], list))], model_weights.ModelWeights), weights_type) def test_returns_model_weights_for_model_callable(self): - weights_type = model_utils.weights_type_from_model(TestModel) + weights_type = model_weights.weights_type_from_model(TestModel) self.assertEqual( computation_types.StructWithPythonType( [('trainable', @@ -110,13 +110,13 @@ def test_returns_model_weights_for_model_callable(self): ('non_trainable', computation_types.StructWithPythonType([ computation_types.TensorType(tf.int32), - ], list))], model_utils.ModelWeights), weights_type) + ], list))], model_weights.ModelWeights), weights_type) class ConvertVariablesToArraysTest(tf.test.TestCase): def test_raises_exception_in_graph_context(self): - w = model_utils.ModelWeights(0.0, 0.0) + w = model_weights.ModelWeights(0.0, 0.0) with tf.Graph().as_default(): with self.assertRaisesRegex(ValueError, 'eager'): w.convert_variables_to_arrays() @@ -127,7 +127,7 @@ def test_raises_exception_in_tf_function(self): def a_tf_function(w): return w.convert_variables_to_arrays() - w = model_utils.ModelWeights(0.0, 0.0) + w = model_weights.ModelWeights(0.0, 0.0) with self.assertRaisesRegex(ValueError, r'tf\.function'): a_tf_function(w) @@ -138,14 +138,14 @@ def test_raises_exception_in_tf_function_and_graph_context(self): def a_tf_function(w): return w.convert_variables_to_arrays() - w = model_utils.ModelWeights(0.0, 0.0) + w = model_weights.ModelWeights(0.0, 0.0) with tf.Graph().as_default(): with self.assertRaisesRegex(ValueError, 'eager'): a_tf_function(w) def test_converts_int(self): - w = model_utils.ModelWeights(1, 2) + w = model_weights.ModelWeights(1, 2) converted = w.convert_variables_to_arrays() self.assertIsInstance(converted.trainable, np.ndarray) self.assertIsInstance(converted.non_trainable, np.ndarray) @@ -153,7 +153,7 @@ def test_converts_int(self): self.assertEqual(converted.non_trainable, 2) def test_converts_float(self): - w = model_utils.ModelWeights(1.0, 2.0) + w = model_weights.ModelWeights(1.0, 2.0) converted = w.convert_variables_to_arrays() self.assertIsInstance(converted.trainable, np.ndarray) self.assertIsInstance(converted.non_trainable, np.ndarray) @@ -161,7 +161,7 @@ def test_converts_float(self): self.assertEqual(converted.non_trainable, 2.0) def test_converts_tensor(self): - w = model_utils.ModelWeights(tf.constant(1.0), tf.constant(2.0)) + w = model_weights.ModelWeights(tf.constant(1.0), tf.constant(2.0)) converted = w.convert_variables_to_arrays() self.assertIsInstance(converted.trainable, np.ndarray) self.assertIsInstance(converted.non_trainable, np.ndarray) @@ -169,7 +169,7 @@ def test_converts_tensor(self): self.assertEqual(converted.non_trainable, 2.0) def test_converts_variable(self): - w = model_utils.ModelWeights(tf.Variable(1.0), tf.Variable(2.0)) + w = model_weights.ModelWeights(tf.Variable(1.0), tf.Variable(2.0)) converted = w.convert_variables_to_arrays() self.assertIsInstance(converted.trainable, np.ndarray) self.assertIsInstance(converted.non_trainable, np.ndarray) @@ -177,7 +177,7 @@ def test_converts_variable(self): self.assertEqual(converted.non_trainable, 2.0) def test_converts_ndarray(self): - w = model_utils.ModelWeights(np.array([1.0]), np.array([2.0, 3.0])) + w = model_weights.ModelWeights(np.array([1.0]), np.array([2.0, 3.0])) converted = w.convert_variables_to_arrays() self.assertIsInstance(converted.trainable, np.ndarray) self.assertIsInstance(converted.non_trainable, np.ndarray) @@ -185,7 +185,7 @@ def test_converts_ndarray(self): self.assertAllEqual(converted.non_trainable, [2.0, 3.0]) def test_converts_heterogeneous_types(self): - w = model_utils.ModelWeights( + w = model_weights.ModelWeights( [1, 2.0, tf.constant(3), tf.Variable(4)], [np.zeros([2, 3])]) converted = w.convert_variables_to_arrays() tf.nest.map_structure(lambda item: self.assertIsInstance(item, np.ndarray), @@ -194,7 +194,7 @@ def test_converts_heterogeneous_types(self): converted.non_trainable) def test_converts_struct(self): - w = model_utils.ModelWeights( + w = model_weights.ModelWeights( structure.Struct.unnamed(1.0), structure.Struct.unnamed(2.0, 3.0)) converted = w.convert_variables_to_arrays() structure.map_structure( @@ -210,7 +210,7 @@ def test_converts_struct(self): [(None, np.array([2.0])), (None, np.array([3.0]))]) def test_converts_heterogeneous_struct(self): - w = model_utils.ModelWeights( + w = model_weights.ModelWeights( structure.Struct.named( a=1, b=2.0, diff --git a/tensorflow_federated/python/learning/personalization_eval.py b/tensorflow_federated/python/learning/personalization_eval.py index 527a1160c3..cdb5e4e959 100644 --- a/tensorflow_federated/python/learning/personalization_eval.py +++ b/tensorflow_federated/python/learning/personalization_eval.py @@ -30,7 +30,7 @@ from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import placements -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib def build_personalization_eval(model_fn, @@ -89,7 +89,7 @@ def build_personalization_eval(model_fn, A federated `tff.Computation` with the functional type signature `( -> personalization_metrics@SERVER)`: - * `model_weights` is a `tff.learning.ModelWeights`. + * `model_weights` is a `tff.learning.models.ModelWeights`. * Each client's input is an `OrderedDict` of two required keys `train_data` and `test_data`; each key is mapped to an unbatched `tf.data.Dataset`. If extra context (e.g., extra datasets) is used in @@ -118,7 +118,7 @@ def build_personalization_eval(model_fn, with tf.Graph().as_default(): py_typecheck.check_callable(model_fn) model = model_fn() - model_weights_type = model_utils.weights_type_from_model(model) + model_weights_type = model_weights_lib.weights_type_from_model(model) batch_tff_type = computation_types.to_type(model.input_spec) # Define the `tff.Type` of each client's input. Since batching (as well as @@ -225,7 +225,7 @@ def _compute_baseline_metrics(model_fn, initial_model_weights, test_data, baseline_evaluate_fn): """Evaluate the model with weights being the `initial_model_weights`.""" model = model_fn() - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) @tf.function def assign_and_compute(): @@ -241,7 +241,7 @@ def _compute_p13n_metrics(model_fn, initial_model_weights, train_data, test_data, personalize_fn_dict, context): """Train and evaluate the personalized models.""" model = model_fn() - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) # Construct the `personalize_fn` (and the associated `tf.Variable`s) here. # This ensures that the new variables are created in the graphs that TFF # controls. This is the key reason why we need `personalize_fn_dict` to diff --git a/tensorflow_federated/python/learning/personalization_eval_test.py b/tensorflow_federated/python/learning/personalization_eval_test.py index 0ce2be31ab..084f53907a 100644 --- a/tensorflow_federated/python/learning/personalization_eval_test.py +++ b/tensorflow_federated/python/learning/personalization_eval_test.py @@ -23,9 +23,9 @@ from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning import personalization_eval as p13n_eval from tensorflow_federated.python.learning.framework import dataset_reduce +from tensorflow_federated.python.learning.models import model_weights # TODO(b/160896627): Switch to `dataset.reduce` once multi-GPU supports it. dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(simulation_flag=True) @@ -142,7 +142,7 @@ def _create_zero_model_weights(model_fn): """Creates the model weights with all zeros.""" whimsy_model = model_fn() return tf.nest.map_structure( - tf.zeros_like, model_utils.ModelWeights.from_model(whimsy_model)) + tf.zeros_like, model_weights.ModelWeights.from_model(whimsy_model)) class PersonalizationEvalTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_federated/python/learning/reconstruction/BUILD b/tensorflow_federated/python/learning/reconstruction/BUILD index 06f5c584b8..a57a01982b 100644 --- a/tensorflow_federated/python/learning/reconstruction/BUILD +++ b/tensorflow_federated/python/learning/reconstruction/BUILD @@ -38,7 +38,7 @@ py_library( srcs_version = "PY3", deps = [ ":model", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -50,7 +50,7 @@ py_test( deps = [ ":keras_utils", ":reconstruction_utils", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -134,8 +134,8 @@ py_test( "//tensorflow_federated/python/core/templates:iterative_process", "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/metrics:counters", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", ], ) diff --git a/tensorflow_federated/python/learning/reconstruction/reconstruction_utils.py b/tensorflow_federated/python/learning/reconstruction/reconstruction_utils.py index 67b4840b52..b760557e8d 100644 --- a/tensorflow_federated/python/learning/reconstruction/reconstruction_utils.py +++ b/tensorflow_federated/python/learning/reconstruction/reconstruction_utils.py @@ -22,7 +22,7 @@ import tensorflow as tf -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.reconstruction import model as model_lib # Type alias for a function that takes in a TF dataset and produces two TF @@ -149,16 +149,16 @@ def dataset_split_fn( return dataset_split_fn -def get_global_variables(model: model_lib.Model) -> model_utils.ModelWeights: +def get_global_variables(model: model_lib.Model) -> model_weights.ModelWeights: """Gets global variables from a `Model` as `ModelWeights`.""" - return model_utils.ModelWeights( + return model_weights.ModelWeights( trainable=model.global_trainable_variables, non_trainable=model.global_non_trainable_variables) -def get_local_variables(model: model_lib.Model) -> model_utils.ModelWeights: +def get_local_variables(model: model_lib.Model) -> model_weights.ModelWeights: """Gets local variables from a `Model` as `ModelWeights`.""" - return model_utils.ModelWeights( + return model_weights.ModelWeights( trainable=model.local_trainable_variables, non_trainable=model.local_non_trainable_variables) diff --git a/tensorflow_federated/python/learning/reconstruction/reconstruction_utils_test.py b/tensorflow_federated/python/learning/reconstruction/reconstruction_utils_test.py index 4bfeb58be5..6a6f50cc66 100644 --- a/tensorflow_federated/python/learning/reconstruction/reconstruction_utils_test.py +++ b/tensorflow_federated/python/learning/reconstruction/reconstruction_utils_test.py @@ -17,7 +17,7 @@ import tensorflow as tf -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.reconstruction import keras_utils from tensorflow_federated.python.learning.reconstruction import reconstruction_utils @@ -228,7 +228,7 @@ def test_get_global_variables(self): global_weights = reconstruction_utils.get_global_variables(model) - self.assertIsInstance(global_weights, model_utils.ModelWeights) + self.assertIsInstance(global_weights, model_weights.ModelWeights) # The last layer of the Keras model, which is a local Dense layer, contains # 2 trainable variables for the weights and bias. self.assertEqual(global_weights.trainable, @@ -246,7 +246,7 @@ def test_get_local_variables(self): local_weights = reconstruction_utils.get_local_variables(model) - self.assertIsInstance(local_weights, model_utils.ModelWeights) + self.assertIsInstance(local_weights, model_weights.ModelWeights) # The last layer of the Keras model, which is a local Dense layer, contains # 2 trainable variables for the weights and bias. self.assertEqual(local_weights.trainable, diff --git a/tensorflow_federated/python/learning/reconstruction/training_process_test.py b/tensorflow_federated/python/learning/reconstruction/training_process_test.py index f6aec8f1e1..d50c657d61 100644 --- a/tensorflow_federated/python/learning/reconstruction/training_process_test.py +++ b/tensorflow_federated/python/learning/reconstruction/training_process_test.py @@ -38,8 +38,8 @@ from tensorflow_federated.python.core.templates import iterative_process as iterative_process_lib from tensorflow_federated.python.core.templates import measured_process as measured_process_lib from tensorflow_federated.python.learning import client_weight_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.metrics import counters +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.reconstruction import keras_utils from tensorflow_federated.python.learning.reconstruction import model as model_lib @@ -988,14 +988,14 @@ def test_get_model_weights(self): state = it_process.initialize() self.assertIsInstance( - it_process.get_model_weights(state), model_utils.ModelWeights) + it_process.get_model_weights(state), model_weights.ModelWeights) self.assertAllClose(state.model.trainable, it_process.get_model_weights(state).trainable) for _ in range(3): state, _ = it_process.next(state, federated_data) self.assertIsInstance( - it_process.get_model_weights(state), model_utils.ModelWeights) + it_process.get_model_weights(state), model_weights.ModelWeights) self.assertAllClose(state.model.trainable, it_process.get_model_weights(state).trainable) diff --git a/tensorflow_federated/python/learning/templates/BUILD b/tensorflow_federated/python/learning/templates/BUILD index fa389c3c78..ad0fc40ce4 100644 --- a/tensorflow_federated/python/learning/templates/BUILD +++ b/tensorflow_federated/python/learning/templates/BUILD @@ -43,7 +43,7 @@ py_library( "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/impl/types:type_conversions", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:keras_optimizer", "//tensorflow_federated/python/learning/optimizers:optimizer", ], @@ -61,7 +61,7 @@ py_test( "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/impl/types:type_test_utils", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/learning/optimizers:sgdm", ], @@ -98,7 +98,7 @@ py_test( "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -124,7 +124,7 @@ py_library( "//tensorflow_federated/python/core/templates:aggregation_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", ], ) @@ -157,7 +157,7 @@ py_test( "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:keras_utils", "//tensorflow_federated/python/learning:model_examples", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", ], ) @@ -227,7 +227,7 @@ py_test( "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/templates:errors", "//tensorflow_federated/python/core/templates:measured_process", - "//tensorflow_federated/python/learning:model_utils", + "//tensorflow_federated/python/learning/models:model_weights", ], ) @@ -303,10 +303,10 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/tensorflow_libs:tensor_utils", ], @@ -332,10 +332,10 @@ py_test( "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:keras_utils", "//tensorflow_federated/python/learning:model_examples", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:counters", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:sgdm", ], ) @@ -357,10 +357,10 @@ py_library( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/metrics:aggregator", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/optimizers:optimizer", "//tensorflow_federated/python/tensorflow_libs:tensor_utils", ], @@ -384,9 +384,9 @@ py_test( "//tensorflow_federated/python/core/templates:measured_process", "//tensorflow_federated/python/learning:client_weight_lib", "//tensorflow_federated/python/learning:model_examples", - "//tensorflow_federated/python/learning:model_utils", "//tensorflow_federated/python/learning/framework:dataset_reduce", "//tensorflow_federated/python/learning/models:functional", + "//tensorflow_federated/python/learning/models:model_weights", "//tensorflow_federated/python/learning/models:test_models", "//tensorflow_federated/python/learning/optimizers:sgdm", ], diff --git a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py index e7ac144a7b..02e77c5468 100644 --- a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py +++ b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer.py @@ -30,7 +30,7 @@ from tensorflow_federated.python.core.impl.types import type_analysis from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import keras_optimizer from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import finalizers @@ -111,7 +111,7 @@ def build_apply_optimizer_finalizer( """Builds finalizer that applies a step of an optimizer. The provided `model_weights_type` must be a non-federated `tff.Type` with the - `tff.learning.ModelWeights` container. + `tff.learning.models.ModelWeights` container. The 2nd input argument of the created `FinalizerProcess.next` expects a value matching `model_weights_type` and its 3rd argument expects value matching @@ -126,13 +126,14 @@ def build_apply_optimizer_finalizer( that returns a `tf.keras.optimizers.Optimizer`. This optimizer is used to apply client updates to the server model. model_weights_type: A non-federated `tff.Type` of the model weights to be - optimized, which must have a `tff.learning.ModelWeights` container. + optimized, which must have a `tff.learning.models.ModelWeights` container. Returns: A `FinalizerProcess` that applies the `optimizer`. Raises: - TypeError: If `value_type` does not have a `tff.learning.ModelWeights` + TypeError: If `value_type` does not have a + `tff.learning.model.sModelWeights` Python container, or contains a `tff.types.FederatedType`. """ if not isinstance(optimizer_fn, optimizer_base.Optimizer): @@ -145,11 +146,11 @@ def build_apply_optimizer_finalizer( 'a no-arg callable returning a `tf.keras.optimizers.Optimizer`.') if (not model_weights_type.is_struct_with_python() or - model_weights_type.python_container != model_utils.ModelWeights or + model_weights_type.python_container != model_weights.ModelWeights or type_analysis.contains_federated_types(model_weights_type)): raise TypeError( f'Provided value_type must be a tff.types.StructType with its python ' - f'container being tff.learning.ModelWeights, not containing a ' + f'container being tff.learning.models.ModelWeights, not containing a ' f'tff.types.FederatedType, but found: {model_weights_type}') if isinstance(optimizer_fn, optimizer_base.Optimizer): @@ -171,7 +172,8 @@ def next_fn(state, weights, update): optimizer_state, new_trainable_weights = intrinsics.federated_map( next_tf, (state, weights.trainable, update)) new_weights = intrinsics.federated_zip( - model_utils.ModelWeights(new_trainable_weights, weights.non_trainable)) + model_weights.ModelWeights(new_trainable_weights, + weights.non_trainable)) empty_measurements = intrinsics.federated_value((), placements.SERVER) return measured_process.MeasuredProcessOutput(optimizer_state, new_weights, empty_measurements) diff --git a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py index e4a75aedcc..89bad622ea 100644 --- a/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py +++ b/tensorflow_federated/python/learning/templates/apply_optimizer_finalizer_test.py @@ -23,14 +23,14 @@ from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.impl.types import type_test_utils from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer SERVER_FLOAT = computation_types.FederatedType(tf.float32, placements.SERVER) MODEL_WEIGHTS_TYPE = computation_types.at_server( - computation_types.to_type(model_utils.ModelWeights(tf.float32, ()))) + computation_types.to_type(model_weights.ModelWeights(tf.float32, ()))) MeasuredProcessOutput = measured_process.MeasuredProcessOutput @@ -39,7 +39,7 @@ class ApplyOptimizerFinalizerComputationTest(tf.test.TestCase, def test_initialize_has_expected_type_with_keras_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) optimizer_fn = lambda: tf.keras.optimizers.legacy.SGD(learning_rate=1.0) @@ -53,7 +53,7 @@ def test_initialize_has_expected_type_with_keras_optimizer(self): def test_next_has_expected_type_with_keras_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) optimizer_fn = lambda: tf.keras.optimizers.legacy.SGD(learning_rate=1.0) @@ -77,7 +77,7 @@ def test_next_has_expected_type_with_keras_optimizer(self): def test_get_hparams_has_expected_type_with_keras_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) optimizer_fn = lambda: tf.keras.optimizers.legacy.SGD(learning_rate=1.0) @@ -93,7 +93,7 @@ def test_get_hparams_has_expected_type_with_keras_optimizer(self): def test_set_hparams_has_expected_type_with_keras_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) optimizer_fn = lambda: tf.keras.optimizers.legacy.SGD(learning_rate=1.0) @@ -112,7 +112,7 @@ def test_set_hparams_has_expected_type_with_keras_optimizer(self): def test_initialize_has_expected_type_with_tff_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( @@ -129,7 +129,7 @@ def test_initialize_has_expected_type_with_tff_optimizer(self): def test_next_has_expected_type_with_tff_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( @@ -155,7 +155,7 @@ def test_next_has_expected_type_with_tff_optimizer(self): def test_get_hparams_has_expected_type_with_tff_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( @@ -172,7 +172,7 @@ def test_get_hparams_has_expected_type_with_tff_optimizer(self): def test_set_hparams_has_expected_type_with_tff_optimizer(self): mw_type = computation_types.to_type( - model_utils.ModelWeights( + model_weights.ModelWeights( trainable=(tf.float32, tf.float32), non_trainable=tf.float32)) finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( @@ -195,7 +195,7 @@ def test_set_hparams_has_expected_type_with_tff_optimizer(self): ('federated_type', MODEL_WEIGHTS_TYPE), ('model_weights_of_federated_types', computation_types.to_type( - model_utils.ModelWeights(SERVER_FLOAT, SERVER_FLOAT))), + model_weights.ModelWeights(SERVER_FLOAT, SERVER_FLOAT))), ('not_model_weights', computation_types.to_type( (tf.float32, tf.float32))), ('function_type', computation_types.FunctionType(None, @@ -220,7 +220,7 @@ def test_execution_with_stateless_tff_optimizer(self): finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( sgdm.build_sgdm(1.0), MODEL_WEIGHTS_TYPE.member) - weights = model_utils.ModelWeights(1.0, ()) + weights = model_weights.ModelWeights(1.0, ()) update = 0.1 optimizer_state = finalizer.initialize() for i in range(5): @@ -238,7 +238,7 @@ def test_execution_with_keras_sgd_optimizer(self): finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( server_optimizer_fn, MODEL_WEIGHTS_TYPE.member) - weights = model_utils.ModelWeights(1.0, ()) + weights = model_weights.ModelWeights(1.0, ()) update = 0.1 optimizer_state = finalizer.initialize() for i in range(5): @@ -255,7 +255,7 @@ def test_execution_with_stateful_tff_optimizer(self): finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( sgdm.build_sgdm(1.0, momentum=momentum), MODEL_WEIGHTS_TYPE.member) - weights = model_utils.ModelWeights(1.0, ()) + weights = model_weights.ModelWeights(1.0, ()) update = 0.1 expected_velocity = 0.0 optimizer_state = finalizer.initialize() @@ -278,7 +278,7 @@ def server_optimizer_fn(): finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer( server_optimizer_fn, MODEL_WEIGHTS_TYPE.member) - weights = model_utils.ModelWeights(1.0, ()) + weights = model_weights.ModelWeights(1.0, ()) update = 0.1 expected_velocity = 0.0 optimizer_state = finalizer.initialize() diff --git a/tensorflow_federated/python/learning/templates/client_works_test.py b/tensorflow_federated/python/learning/templates/client_works_test.py index 03036daa86..eb23e637d5 100644 --- a/tensorflow_federated/python/learning/templates/client_works_test.py +++ b/tensorflow_federated/python/learning/templates/client_works_test.py @@ -24,7 +24,7 @@ from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.templates import client_works SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER) @@ -34,7 +34,7 @@ CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS) CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS) MODEL_WEIGHTS_TYPE = computation_types.at_clients( - computation_types.to_type(model_utils.ModelWeights(tf.float32, ()))) + computation_types.to_type(model_weights.ModelWeights(tf.float32, ()))) HPARAMS_TYPE = computation_types.to_type(collections.OrderedDict(a=tf.int32)) MeasuredProcessOutput = measured_process.MeasuredProcessOutput diff --git a/tensorflow_federated/python/learning/templates/composers.py b/tensorflow_federated/python/learning/templates/composers.py index b09bedb2c7..280428ccdf 100644 --- a/tensorflow_federated/python/learning/templates/composers.py +++ b/tensorflow_federated/python/learning/templates/composers.py @@ -34,7 +34,7 @@ from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import client_works @@ -331,7 +331,7 @@ def client_udpate(m, eta): @tensorflow_computation.tf_computation() def initial_model_weights_fn(): - return model_utils.ModelWeights.from_model(model_fn()) + return model_weights_lib.ModelWeights.from_model(model_fn()) model_weights_type = initial_model_weights_fn.type_signature.result diff --git a/tensorflow_federated/python/learning/templates/composers_test.py b/tensorflow_federated/python/learning/templates/composers_test.py index 9ea71fd93e..43ad98ce94 100644 --- a/tensorflow_federated/python/learning/templates/composers_test.py +++ b/tensorflow_federated/python/learning/templates/composers_test.py @@ -31,7 +31,7 @@ from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer from tensorflow_federated.python.learning.templates import client_works @@ -43,7 +43,7 @@ FLOAT_TYPE = computation_types.TensorType(tf.float32) MODEL_WEIGHTS_TYPE = computation_types.to_type( - model_utils.ModelWeights(FLOAT_TYPE, ())) + model_weights_lib.ModelWeights(FLOAT_TYPE, ())) CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.at_clients( computation_types.SequenceType(FLOAT_TYPE)) @@ -59,7 +59,8 @@ def empty_init_fn(): @tensorflow_computation.tf_computation() def test_init_model_weights_fn(): - return model_utils.ModelWeights(trainable=tf.constant(1.0), non_trainable=()) + return model_weights_lib.ModelWeights( + trainable=tf.constant(1.0), non_trainable=()) def test_distributor(): @@ -109,7 +110,7 @@ def next_fn(state, weights, updates): tensorflow_computation.tf_computation(lambda x, y: x + y), (weights.trainable, updates)) new_weights = intrinsics.federated_zip( - model_utils.ModelWeights(new_weights, ())) + model_weights_lib.ModelWeights(new_weights, ())) return measured_process.MeasuredProcessOutput(state, new_weights, empty_at_server()) @@ -149,7 +150,7 @@ def test_one_arg_computation_init_raises(self): @tensorflow_computation.tf_computation( computation_types.TensorType(tf.float32)) def init_model_weights_fn(x): - return model_utils.ModelWeights(trainable=x, non_trainable=()) + return model_weights_lib.ModelWeights(trainable=x, non_trainable=()) with self.assertRaisesRegex(TypeError, 'Computation'): composers.compose_learning_process(init_model_weights_fn, @@ -159,7 +160,7 @@ def init_model_weights_fn(x): def test_not_tff_computation_init_raises(self): def init_model_weights_fn(): - return model_utils.ModelWeights( + return model_weights_lib.ModelWeights( trainable=tf.constant(1.0), non_trainable=()) with self.assertRaisesRegex(TypeError, 'Computation'): @@ -237,7 +238,8 @@ def _test_data(self): def _test_batch_loss(self, model, weights): tf.nest.map_structure(lambda w, v: w.assign(v), - model_utils.ModelWeights.from_model(model), weights) + model_weights_lib.ModelWeights.from_model(model), + weights) for batch in self._test_data().take(1): batch_output = model.forward_pass(batch, training=False) return batch_output.loss @@ -291,7 +293,7 @@ def model_fn(): ) keras_model.compile(optimizer='adam', loss='mse') keras_model.fit(self._test_data().map(lambda d: (d['x'], d['y']))) - pretrained_weights = model_utils.ModelWeights.from_model(keras_model) + pretrained_weights = model_weights_lib.ModelWeights.from_model(keras_model) # Assert the initial state weights are not the same as the pretrained model. initial_weights = fedavg.get_model_weights(state) self.assertNotAllClose( diff --git a/tensorflow_federated/python/learning/templates/finalizers_test.py b/tensorflow_federated/python/learning/templates/finalizers_test.py index 7d1ac84a9c..230a267a6f 100644 --- a/tensorflow_federated/python/learning/templates/finalizers_test.py +++ b/tensorflow_federated/python/learning/templates/finalizers_test.py @@ -23,7 +23,7 @@ from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import errors from tensorflow_federated.python.core.templates import measured_process -from tensorflow_federated.python.learning import model_utils +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.templates import finalizers SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER) @@ -31,7 +31,7 @@ CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS) CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS) MODEL_WEIGHTS_TYPE = computation_types.at_server( - computation_types.to_type(model_utils.ModelWeights(tf.float32, ()))) + computation_types.to_type(model_weights.ModelWeights(tf.float32, ()))) MeasuredProcessOutput = measured_process.MeasuredProcessOutput @@ -52,7 +52,7 @@ def test_initialize_fn(): def test_finalizer_result(weights, update): return intrinsics.federated_zip( - model_utils.ModelWeights(federated_add(weights.trainable, update), ())) + model_weights.ModelWeights(federated_add(weights.trainable, update), ())) @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE, @@ -166,10 +166,10 @@ def test_non_federated_init_next_raises(self): @tensorflow_computation.tf_computation( tf.int32, - computation_types.to_type(model_utils.ModelWeights(tf.float32, - ())), tf.float32) + computation_types.to_type(model_weights.ModelWeights(tf.float32, + ())), tf.float32) def next_fn(state, weights, update): - new_weigths = model_utils.ModelWeights(weights.trainable + update, ()) + new_weigths = model_weights.ModelWeights(weights.trainable + update, ()) return MeasuredProcessOutput(state, new_weigths, 0) with self.assertRaises(errors.TemplateNotFederatedError): diff --git a/tensorflow_federated/python/learning/templates/model_delta_client_work.py b/tensorflow_federated/python/learning/templates/model_delta_client_work.py index 40075730e4..090608a216 100644 --- a/tensorflow_federated/python/learning/templates/model_delta_client_work.py +++ b/tensorflow_federated/python/learning/templates/model_delta_client_work.py @@ -40,10 +40,10 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.tensorflow_libs import tensor_utils @@ -79,7 +79,7 @@ def build_model_delta_update_with_tff_optimizer( @tf.function def client_update(optimizer, initial_weights, data, optimizer_hparams=None): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -165,7 +165,7 @@ def build_model_delta_update_with_keras_optimizer( @tf.function def client_update(optimizer, initial_weights, data): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -291,7 +291,7 @@ def build_model_delta_client_work( metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(), unfinalized_metrics_type) data_type = computation_types.SequenceType(model.input_spec) - weights_type = model_utils.weights_type_from_model(model) + weights_type = model_weights_lib.weights_type_from_model(model) if isinstance(optimizer, optimizer_base.Optimizer): @@ -482,7 +482,7 @@ def ndarray_to_tensorspec(ndarray): shape=ndarray.shape, dtype=tf.dtypes.as_dtype(ndarray.dtype)) # Wrap in a `ModelWeights` structure that is required by the `finalizer.` - weights_type = model_utils.ModelWeights( + weights_type = model_weights_lib.ModelWeights( tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[0]), tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[1])) diff --git a/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py b/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py index cb1b5b4034..012b6e7d1e 100644 --- a/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py +++ b/tensorflow_federated/python/learning/templates/model_delta_client_work_test.py @@ -30,10 +30,10 @@ from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import counters from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.learning.templates import model_delta_client_work @@ -70,7 +70,7 @@ def test_next_has_expected_type_signature_with_keras_optimizer( client_work_process = model_delta_client_work.build_model_delta_client_work( model_fn, optimizer_fn, weighting) - mw_type = model_utils.ModelWeights( + mw_type = model_weights_lib.ModelWeights( trainable=computation_types.to_type([(tf.float32, (2, 1)), tf.float32]), non_trainable=computation_types.to_type([tf.float32])) expected_param_model_weights_type = computation_types.at_clients(mw_type) @@ -125,7 +125,7 @@ def test_next_has_expected_type_signature_with_tff_optimizer(self, weighting): client_work_process = model_delta_client_work.build_model_delta_client_work( model_fn, optimizer, weighting) - mw_type = model_utils.ModelWeights( + mw_type = model_weights_lib.ModelWeights( trainable=computation_types.to_type([(tf.float32, (2, 1)), tf.float32]), non_trainable=computation_types.to_type([tf.float32])) expected_param_model_weights_type = computation_types.at_clients(mw_type) @@ -263,8 +263,8 @@ def create_test_dataset() -> tf.data.Dataset: return dataset.repeat(2).batch(3) -def create_test_initial_weights() -> model_utils.ModelWeights: - return model_utils.ModelWeights( +def create_test_initial_weights() -> model_weights_lib.ModelWeights: + return model_weights_lib.ModelWeights( trainable=[tf.zeros((2, 1)), tf.constant(0.0)], non_trainable=[0.0]) @@ -538,7 +538,7 @@ def model_fn(): client_update_model_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer( model_fn=model_fn, weighting=weighting) model_fn_optimizer = sgdm.build_sgdm(learning_rate=0.1) - model_fn_weights = model_utils.ModelWeights.from_model(model_fn()) + model_fn_weights = model_weights_lib.ModelWeights.from_model(model_fn()) functional_model_weights = functional_model.initial_weights for _ in range(10): diff --git a/tensorflow_federated/python/learning/templates/proximal_client_work.py b/tensorflow_federated/python/learning/templates/proximal_client_work.py index 7e218e1772..8a71df222e 100644 --- a/tensorflow_federated/python/learning/templates/proximal_client_work.py +++ b/tensorflow_federated/python/learning/templates/proximal_client_work.py @@ -40,10 +40,10 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model as model_lib -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.metrics import aggregator from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights as model_weights_lib from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.tensorflow_libs import tensor_utils @@ -82,7 +82,7 @@ def build_model_delta_update_with_tff_optimizer( @tf.function def client_update(optimizer, initial_weights, data): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -172,7 +172,7 @@ def build_model_delta_update_with_keras_optimizer( @tf.function def client_update(optimizer, initial_weights, data): - model_weights = model_utils.ModelWeights.from_model(model) + model_weights = model_weights_lib.ModelWeights.from_model(model) tf.nest.map_structure(lambda a, b: a.assign(b), model_weights, initial_weights) @@ -396,7 +396,7 @@ def build_model_delta_client_work( metrics_aggregation_fn = metrics_aggregator(model.metric_finalizers(), unfinalized_metrics_type) data_type = computation_types.SequenceType(model.input_spec) - weights_type = model_utils.weights_type_from_model(model) + weights_type = model_weights_lib.weights_type_from_model(model) if isinstance(optimizer, optimizer_base.Optimizer): @@ -491,7 +491,7 @@ def ndarray_to_tensorspec(ndarray): return tf.TensorSpec(shape=ndarray.shape, dtype=ndarray.dtype) # Wrap in a `ModelWeights` structure that is required by the `finalizer.` - weights_type = model_utils.ModelWeights( + weights_type = model_weights_lib.ModelWeights( tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[0]), tuple(ndarray_to_tensorspec(w) for w in model.initial_weights[1])) diff --git a/tensorflow_federated/python/learning/templates/proximal_client_work_test.py b/tensorflow_federated/python/learning/templates/proximal_client_work_test.py index 9ec2d8bb55..1246ec963d 100644 --- a/tensorflow_federated/python/learning/templates/proximal_client_work_test.py +++ b/tensorflow_federated/python/learning/templates/proximal_client_work_test.py @@ -27,9 +27,9 @@ from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import client_weight_lib from tensorflow_federated.python.learning import model_examples -from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.framework import dataset_reduce from tensorflow_federated.python.learning.models import functional +from tensorflow_federated.python.learning.models import model_weights from tensorflow_federated.python.learning.models import test_models from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import client_works @@ -54,7 +54,7 @@ def test_type_properties(self, optimizer, weighting): model_fn, optimizer, weighting, delta_l2_regularizer=0.1) self.assertIsInstance(client_work_process, client_works.ClientWorkProcess) - mw_type = model_utils.ModelWeights( + mw_type = model_weights.ModelWeights( trainable=computation_types.to_type([(tf.float32, (2, 1)), tf.float32]), non_trainable=computation_types.to_type([tf.float32])) expected_param_model_weights_type = computation_types.at_clients(mw_type) @@ -133,8 +133,8 @@ def create_test_dataset() -> tf.data.Dataset: return dataset.repeat(2).batch(3) -def create_test_initial_weights() -> model_utils.ModelWeights: - return model_utils.ModelWeights( +def create_test_initial_weights() -> model_weights.ModelWeights: + return model_weights.ModelWeights( trainable=[tf.zeros((2, 1)), tf.constant(0.0)], non_trainable=[0.0])