
Commit

Replace usage of jax.xla_computation with JAX AOT APIs.

`jax.xla_computation` is deprecated and will be deleted soon.

PiperOrigin-RevId: 644434191
jkr26 authored and copybara-github committed Jun 18, 2024
1 parent adeba2d commit c6a120c
Showing 2 changed files with 31 additions and 14 deletions.
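
For orientation, the change swaps JAX's deprecated jax.xla_computation tracer for the ahead-of-time (AOT) lowering APIs. A minimal sketch of the before/after pattern, mirroring the diff below (the add_one function, its argument, and the version threshold are illustrative, not part of the commit):

import jax
import numpy as np


def add_one(x):
  return x + 1.0


arg = np.float32(1.0)

if jax.__version_info__ > (0, 4, 29):
  # New AOT path: lower the jitted function, then extract the HLO module and
  # the output shape/dtype metadata.
  lowered = jax.jit(add_one).lower(arg)
  compiled_xla = lowered.compiler_ir('hlo')
  out_info = lowered.out_info
else:
  # Deprecated path that this commit retains only behind the version guard.
  compiled_xla, out_info = jax.xla_computation(add_one, return_shape=True)(arg)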
@@ -443,7 +443,6 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeNoArgCallSingleTensor) {
   xla::XlaOp constant = xla::ConstantR0<float>(&builder, 2.0);
   // To mimic the Python tracing which always returns tuples, even for single
   // element results, after passing through MLIR
-  // (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958),
   // results are always in tuples.
   xla::Tuple(&builder, {constant});
   tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
@@ -523,7 +522,6 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityScalar) {
       &builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
   // To mimic the Python tracing which always returns tuples, even for single
   // element results, after passing through MLIR
-  // (https://github.com/google/jax/blob/38f91bdaa564a4de1e06bde7d191af0bff610bbf/jax/_src/api.py#L958),
   // results are always in tuples.
   xla::Tuple(&builder, {parameter});
   tensorflow::StatusOr<xla::XlaComputation> xla_computation = builder.Build();
@@ -165,7 +165,6 @@ def _jax_shape_dtype_struct_to_tff_tensor(
   Raises:
     TypeError: if arg type mismatches.
   """
-  py_typecheck.check_type(val, jax.ShapeDtypeStruct)
   return computation_types.TensorType(val.dtype, val.shape)


@@ -214,18 +213,38 @@ def serialize_jax_computation(
   tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))
 
   context = jax_computation_context.JaxComputationContext()
-  with context_stack.install(context):
-    tracer_callable = jax.xla_computation(fn, return_shape=True)
-    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)
-
-  if isinstance(returned_shape, jax.ShapeDtypeStruct):
-    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
+  # TODO: b/347811116 - Remove this version check when the JAX version can be
+  # upgraded.
+  if jax.__version_info__ > (0, 4, 29):
+    with context_stack.install(context):
+      lowered = jax.jit(fn).lower(*args, **kwargs)
+      compiled_xla = lowered.compiler_ir('hlo')
+
+    if isinstance(lowered.out_info, jax.stages.OutInfo):
+      returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
+          jax.ShapeDtypeStruct(
+              shape=lowered.out_info.shape, dtype=lowered.out_info.dtype
+          )
+      )
+    else:
+      returned_type_spec = computation_types.to_type(
+          jax.tree_util.tree_map(
+              _jax_shape_dtype_struct_to_tff_tensor, lowered.out_info
+          )
+      )
   else:
-    returned_type_spec = computation_types.to_type(
-        jax.tree_util.tree_map(
-            _jax_shape_dtype_struct_to_tff_tensor, returned_shape
-        )
-    )
+    with context_stack.install(context):
+      tracer_callable = jax.xla_computation(fn, return_shape=True)
+      compiled_xla, returned_shape = tracer_callable(*args, **kwargs)
+
+    if isinstance(returned_shape, jax.ShapeDtypeStruct):
+      returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
+    else:
+      returned_type_spec = computation_types.to_type(
+          jax.tree_util.tree_map(
+              _jax_shape_dtype_struct_to_tff_tensor, returned_shape
+          )
+      )
 
   computation_type = computation_types.FunctionType(
       parameter_type, returned_type_spec
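
As a usage sketch only (assuming JAX newer than 0.4.29, per the version check above; the fn below and its structured result are hypothetical), the AOT objects the new branch relies on look like this:

import collections

import jax
import numpy as np


def fn(x):
  # Hypothetical traceable function with a structured (pytree) result.
  return collections.OrderedDict(doubled=x * 2.0, squared=x * x)


lowered = jax.jit(fn).lower(np.float32(3.0))
compiled_xla = lowered.compiler_ir('hlo')  # XLA HLO module that gets serialized
out_info = lowered.out_info  # per-leaf shape/dtype records mirroring the output
# For a structured result like this one, tree-mapping the shape/dtype ->
# TensorType conversion over out_info yields the structured TFF result type.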
