Skip to content

Commit

Permalink
Delete jax.lib.xla_client.execute_with_python_values.
Browse files Browse the repository at this point in the history
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
  • Loading branch information
hawkinsp authored and copybara-github committed Oct 1, 2024
1 parent 22e7489 commit b3a3075
Showing 1 changed file with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from jax.lib import xla_client
from jax.lib import xla_extension
import jax.numpy as jnp
import numpy as np

from tensorflow_federated.proto.v0 import computation_pb2 as pb
Expand Down Expand Up @@ -166,12 +167,13 @@ def __call__(self, *args, **kwargs):
flat_py_args = structure.flatten(positional_arg)

reordered_flat_py_args = [
flat_py_args[idx] for idx in self._inverted_parameter_tensor_indexes
jnp.asarray(flat_py_args[idx])
for idx in self._inverted_parameter_tensor_indexes
]

unordered_result = xla_client.execute_with_python_values(
self._executable, reordered_flat_py_args, self._backend
)
unordered_result = [
np.asarray(x) for x in self._executable.execute(reordered_flat_py_args)
]
py_typecheck.check_type(unordered_result, list)
result = [unordered_result[idx] for idx in self._result_tensor_indexes]
result_type = self.type_signature.result
Expand Down

0 comments on commit b3a3075

Please sign in to comment.