diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index a51625eb072e..dbf07b8afa3e 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -15,7 +15,6 @@ from jax._src.lib import xla_client as _xc dtype_to_etype = _xc.dtype_to_etype -execute_with_python_values = _xc.execute_with_python_values get_topology_for_devices = _xc.get_topology_for_devices heap_profile = _xc.heap_profile mlir_api_version = _xc.mlir_api_version diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 75c52822a223..e5222814fb02 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -40,7 +40,6 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface -from jax._src.lib import xla_client from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -177,15 +176,11 @@ def test_put_executable(self): executable_retrieved, compile_time_retrieved = cc.get_executable_and_time( key, compile_options, backend) inputs_to_executable = ( - np.array(1, dtype=np.int32), - np.array(2, dtype=np.int32), - ) - expected = xla_client.execute_with_python_values( - executable, inputs_to_executable, backend - ) - actual = xla_client.execute_with_python_values( - executable_retrieved, inputs_to_executable, backend + jnp.array(1, dtype=np.int32), + jnp.array(2, dtype=np.int32), ) + expected = executable.execute(inputs_to_executable) + actual = executable_retrieved.execute(inputs_to_executable) self.assertEqual(expected, actual) self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index d3dada0d750a..e403daba5254 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -5590,7 +5590,6 @@ def test_isdtype(self, dtype, kind): self.assertEqual(jax_result, numpy_result) -from jaxlib import xla_client @unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.") class ReportedIssuesTests(jtu.JaxTestCase): def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): @@ -5602,10 +5601,10 @@ def dispatchOn(self, args, func, device=jax.devices('cpu')[0]): @staticmethod def compile_and_exec(module, args, run_on_cpu=False): backend = jax.lib.xla_bridge.get_backend('METAL') - if (run_on_cpu): + if run_on_cpu: backend = jax.lib.xla_bridge.get_backend('cpu') - executables = backend.compile(module) - return xla_client.execute_with_python_values(executables, args, backend) + executable = backend.compile(module) + return executable.execute(args) @staticmethod def jax_metal_supported(target_ver):