Skip to content

Commit

Permalink
Move ModelWeights class into the learning/models/ sub-directory.
Browse files Browse the repository at this point in the history
Deprecated the `tff.learning.ModelWeights` and `tff.learning.framework.ModelWeights` APIs, replaced by `tff.learning.models.ModelWeights`.

PiperOrigin-RevId: 481704926
  • Loading branch information
ZacharyGarrett authored and tensorflow-copybara committed Oct 17, 2022
1 parent 282e4bd commit 3e8cf29
Show file tree
Hide file tree
Showing 41 changed files with 269 additions and 237 deletions.
39 changes: 7 additions & 32 deletions tensorflow_federated/python/learning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ py_library(
":keras_utils",
":model",
":model_update_aggregator",
":model_utils",
":personalization_eval",
"//tensorflow_federated/python/common_libs:deprecation",
"//tensorflow_federated/python/learning/algorithms",
"//tensorflow_federated/python/learning/framework",
"//tensorflow_federated/python/learning/framework:optimizer_utils",
"//tensorflow_federated/python/learning/metrics",
"//tensorflow_federated/python/learning/models",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers",
"//tensorflow_federated/python/learning/reconstruction",
"//tensorflow_federated/python/learning/templates",
Expand Down Expand Up @@ -84,7 +85,6 @@ py_library(
srcs_version = "PY3",
deps = [
":model",
":model_utils",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
Expand All @@ -95,6 +95,7 @@ py_library(
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_weights",
],
)

Expand All @@ -107,7 +108,6 @@ py_cpu_gpu_test(
":federated_evaluation",
":keras_utils",
":model",
":model_utils",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/federated_context:federated_computation",
"//tensorflow_federated/python/core/impl/federated_context:intrinsics",
Expand All @@ -119,6 +119,7 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/tensorflow_libs:tensorflow_test_utils",
],
)
Expand Down Expand Up @@ -150,14 +151,14 @@ py_test(
":keras_utils",
":model",
":model_examples",
":model_utils",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:counters",
"//tensorflow_federated/python/learning/models:model_weights",
],
)

Expand Down Expand Up @@ -187,19 +188,6 @@ py_test(
],
)

py_library(
name = "model_utils",
srcs = ["model_utils.py"],
srcs_version = "PY3",
deps = [
":model",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_conversions",
],
)

py_library(
name = "model_update_aggregator",
srcs = ["model_update_aggregator.py"],
Expand Down Expand Up @@ -248,7 +236,6 @@ py_library(
srcs = ["personalization_eval.py"],
srcs_version = "PY3",
deps = [
":model_utils",
"//tensorflow_federated/python/aggregators:sampling",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
Expand All @@ -257,19 +244,7 @@ py_library(
"//tensorflow_federated/python/core/impl/tensorflow_context:tensorflow_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
],
)

py_test(
name = "model_utils_test",
srcs = ["model_utils_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":model",
":model_utils",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning/models:model_weights",
],
)

Expand All @@ -282,10 +257,10 @@ py_cpu_gpu_test(
deps = [
":keras_utils",
":model_examples",
":model_utils",
":personalization_eval",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/models:model_weights",
],
)
3 changes: 2 additions & 1 deletion tensorflow_federated/python/learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
`tff.learning.models` for related model classes.
"""

from tensorflow_federated.python.common_libs import deprecation
from tensorflow_federated.python.learning import algorithms
from tensorflow_federated.python.learning import framework
from tensorflow_federated.python.learning import metrics
Expand All @@ -59,5 +60,5 @@
from tensorflow_federated.python.learning.model_update_aggregator import entropy_compression_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import robust_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import secure_aggregator
from tensorflow_federated.python.learning.model_utils import ModelWeights
from tensorflow_federated.python.learning.models.model_weights import ModelWeights
from tensorflow_federated.python.learning.personalization_eval import build_personalization_eval
20 changes: 10 additions & 10 deletions tensorflow_federated/python/learning/algorithms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ py_library(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer",
"//tensorflow_federated/python/learning/templates:composers",
Expand Down Expand Up @@ -94,8 +94,8 @@ py_library(
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer",
"//tensorflow_federated/python/learning/templates:client_works",
Expand Down Expand Up @@ -138,9 +138,9 @@ py_library(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer",
"//tensorflow_federated/python/learning/templates:composers",
Expand All @@ -165,9 +165,9 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/models:test_models",
"//tensorflow_federated/python/learning/optimizers:sgdm",
"//tensorflow_federated/python/learning/templates:distributors",
Expand All @@ -191,10 +191,10 @@ py_library(
"//tensorflow_federated/python/core/impl/types:type_conversions",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer",
"//tensorflow_federated/python/learning/templates:client_works",
Expand All @@ -217,10 +217,10 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/models:test_models",
"//tensorflow_federated/python/tensorflow_libs:tensorflow_test_utils",
],
Expand Down Expand Up @@ -279,10 +279,10 @@ py_library(
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/optimizers:optimizer",
"//tensorflow_federated/python/learning/optimizers:sgdm",
"//tensorflow_federated/python/learning/templates:apply_optimizer_finalizer",
Expand Down Expand Up @@ -319,11 +319,11 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/learning:keras_utils",
"//tensorflow_federated/python/learning:model_examples",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/framework:dataset_reduce",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:counters",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/models:test_models",
"//tensorflow_federated/python/learning/optimizers:adagrad",
"//tensorflow_federated/python/learning/optimizers:adam",
Expand Down Expand Up @@ -354,8 +354,8 @@ py_library(
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:federated_evaluation",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/metrics:aggregation_factory",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/templates:client_works",
"//tensorflow_federated/python/learning/templates:composers",
"//tensorflow_federated/python/learning/templates:distributors",
Expand Down Expand Up @@ -383,9 +383,9 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:model",
"//tensorflow_federated/python/learning:model_utils",
"//tensorflow_federated/python/learning/metrics:aggregation_factory",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/templates:composers",
"//tensorflow_federated/python/learning/templates:distributors",
"//tensorflow_federated/python/learning/templates:learning_process",
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_federated/python/learning/algorithms/fed_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator
from tensorflow_federated.python.learning.models import functional
from tensorflow_federated.python.learning.models import model_weights
from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base
from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer
from tensorflow_federated.python.learning.templates import composers
Expand Down Expand Up @@ -172,7 +172,7 @@ def build_weighted_fed_avg(
@tensorflow_computation.tf_computation()
def initial_model_weights_fn():
trainable_weights, non_trainable_weights = model_fn.initial_weights
return model_utils.ModelWeights(
return model_weights.ModelWeights(
tuple(tf.convert_to_tensor(w) for w in trainable_weights),
tuple(tf.convert_to_tensor(w) for w in non_trainable_weights))

Expand All @@ -186,7 +186,7 @@ def initial_model_weights_fn():
raise TypeError('When `model_fn` is a callable, it return instances of '
'tff.learning.Model. Instead callable returned type: '
f'{type(model)}')
return model_utils.ModelWeights.from_model(model)
return model_weights.ModelWeights.from_model(model)

model_weights_type = initial_model_weights_fn.type_signature.result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.algorithms import fed_avg
from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator
from tensorflow_federated.python.learning.models import model_weights
from tensorflow_federated.python.learning.optimizers import optimizer as optimizer_base
from tensorflow_federated.python.learning.templates import apply_optimizer_finalizer
from tensorflow_federated.python.learning.templates import client_works
Expand Down Expand Up @@ -101,7 +101,7 @@ def build_scheduled_client_work(
metrics_aggregation_fn = metrics_aggregator(
whimsy_model.metric_finalizers(), unfinalized_metrics_type)
data_type = computation_types.SequenceType(whimsy_model.input_spec)
weights_type = model_utils.weights_type_from_model(whimsy_model)
weights_type = model_weights.weights_type_from_model(whimsy_model)

if isinstance(whimsy_optimizer, optimizer_base.Optimizer):
build_client_update_fn = model_delta_client_work.build_model_delta_update_with_tff_optimizer
Expand Down Expand Up @@ -249,7 +249,7 @@ def build_weighted_fed_avg_with_optimizer_schedule(

@tensorflow_computation.tf_computation()
def initial_model_weights_fn():
return model_utils.ModelWeights.from_model(model_fn())
return model_weights.ModelWeights.from_model(model_fn())

model_weights_type = initial_model_weights_fn.type_signature.result

Expand Down
4 changes: 2 additions & 2 deletions tensorflow_federated/python/learning/algorithms/fed_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import federated_evaluation
from tensorflow_federated.python.learning import model as model_lib
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.metrics import aggregation_factory
from tensorflow_federated.python.learning.models import model_weights as model_weights_lib
from tensorflow_federated.python.learning.templates import client_works
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
Expand Down Expand Up @@ -198,7 +198,7 @@ def build_fed_eval(

@tensorflow_computation.tf_computation()
def initial_model_weights_fn():
return model_utils.ModelWeights.from_model(model_fn())
return model_weights_lib.ModelWeights.from_model(model_fn())

model_weights_type = initial_model_weights_fn.type_signature.result

Expand Down
12 changes: 7 additions & 5 deletions tensorflow_federated/python/learning/algorithms/fed_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.algorithms import fed_eval
from tensorflow_federated.python.learning.metrics import aggregation_factory
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.models import model_weights as model_weights_lib
from tensorflow_federated.python.learning.templates import composers
from tensorflow_federated.python.learning.templates import distributors
from tensorflow_federated.python.learning.templates import learning_process
Expand Down Expand Up @@ -171,7 +171,7 @@ class FedEvalProcessTest(tf.test.TestCase):
def test_fed_eval_process_type_properties(self):
model_fn = TestModel
test_model = model_fn()
model_weights_type = model_utils.weights_type_from_model(test_model)
model_weights_type = model_weights_lib.weights_type_from_model(test_model)
metric_finalizers = test_model.metric_finalizers()
unfinalized_metrics = test_model.report_local_unfinalized_metrics()
local_unfinalized_metrics_type = type_conversions.type_from_tensors(
Expand Down Expand Up @@ -239,9 +239,11 @@ def test_fed_eval_process_execution(self):
# Update the state with the model weights to be evaluated, and verify that
# the `get_model_weights` method returns the same model weights.
state = eval_process.initialize()
model_weights = model_utils.ModelWeights(trainable=[5.0], non_trainable=[])
model_weights = model_weights_lib.ModelWeights(
trainable=[5.0], non_trainable=[])
new_state = eval_process.set_model_weights(
state, model_utils.ModelWeights(trainable=[5.0], non_trainable=[]))
state,
model_weights_lib.ModelWeights(trainable=[5.0], non_trainable=[]))
tf.nest.map_structure(self.assertAllEqual, model_weights,
eval_process.get_model_weights(new_state))

Expand Down Expand Up @@ -269,7 +271,7 @@ def _temp_dict(temps):
total_rounds_metrics=collections.OrderedDict(num_over=9.0))))

def test_fed_eval_with_model_distributor(self):
model_weights_type = model_utils.weights_type_from_model(TestModel)
model_weights_type = model_weights_lib.weights_type_from_model(TestModel)

def test_distributor():

Expand Down
Loading

0 comments on commit 3e8cf29

Please sign in to comment.