Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delete jax.lib.xla_client.execute_with_python_values. #24040

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion jax/lib/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from jax._src.lib import xla_client as _xc

dtype_to_etype = _xc.dtype_to_etype
execute_with_python_values = _xc.execute_with_python_values
get_topology_for_devices = _xc.get_topology_for_devices
heap_profile = _xc.heap_profile
mlir_api_version = _xc.mlir_api_version
Expand Down
13 changes: 4 additions & 9 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
import numpy as np
Expand Down Expand Up @@ -177,15 +176,11 @@ def test_put_executable(self):
executable_retrieved, compile_time_retrieved = cc.get_executable_and_time(
key, compile_options, backend)
inputs_to_executable = (
np.array(1, dtype=np.int32),
np.array(2, dtype=np.int32),
)
expected = xla_client.execute_with_python_values(
executable, inputs_to_executable, backend
)
actual = xla_client.execute_with_python_values(
executable_retrieved, inputs_to_executable, backend
jnp.array(1, dtype=np.int32),
jnp.array(2, dtype=np.int32),
)
expected = executable.execute(inputs_to_executable)
actual = executable_retrieved.execute(inputs_to_executable)
self.assertEqual(expected, actual)
self.assertEqual(FAKE_COMPILE_TIME, compile_time_retrieved)

Expand Down
7 changes: 3 additions & 4 deletions tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5590,7 +5590,6 @@ def test_isdtype(self, dtype, kind):
self.assertEqual(jax_result, numpy_result)


from jaxlib import xla_client
@unittest.skipIf(metal_plugin == None, "Tests require jax-metal plugin.")
class ReportedIssuesTests(jtu.JaxTestCase):
def dispatchOn(self, args, func, device=jax.devices('cpu')[0]):
Expand All @@ -5602,10 +5601,10 @@ def dispatchOn(self, args, func, device=jax.devices('cpu')[0]):
@staticmethod
def compile_and_exec(module, args, run_on_cpu=False):
backend = jax.lib.xla_bridge.get_backend('METAL')
if (run_on_cpu):
if run_on_cpu:
backend = jax.lib.xla_bridge.get_backend('cpu')
executables = backend.compile(module)
return xla_client.execute_with_python_values(executables, args, backend)
executable = backend.compile(module)
return executable.execute(args)

@staticmethod
def jax_metal_supported(target_ver):
Expand Down