Skip to content

Commit

Permalink
Enable T5x to take native_serialization_platforms for jax2tf
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559827754
  • Loading branch information
T5X Team authored and t5-copybara committed Aug 28, 2023
1 parent 828e910 commit 2746023
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions t5x/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def create_inference_function(
enable_xla: bool = True,
polymorphic_shapes_inputs: Optional[Any] = None,
native_lowering: bool = False,
native_lowering_platforms: Sequence[str] = (),
model_fn_extra_kwargs: Optional[Mapping[str, Any]] = None,
) -> Callable[[Mapping[str, Any], Any], PyTree]:
"""Fetches a model and returns the inference function based on inference_mode."""
Expand Down Expand Up @@ -355,6 +356,7 @@ def model_fn(
model_fn,
polymorphic_shapes=[None, polymorphic_shapes_inputs],
native_serialization=native_lowering,
native_serialization_platforms=native_lowering_platforms,
enable_xla=enable_xla,
)

Expand Down Expand Up @@ -1297,6 +1299,7 @@ def save(
mixture_or_task_name: Optional[str] = None,
validation_examples: Optional[List[Any]] = None,
native_lowering: bool = False,
native_lowering_platforms: Sequence[str] = (),
enable_xla: bool = True,
decode_outputs: Optional[bool] = None,
trailing_shapes: Optional[Mapping[str, Tuple[int, ...]]] = None,
Expand Down Expand Up @@ -1344,6 +1347,11 @@ def save(
model.
native_lowering: for experimental purposes only -- if True, don't convert
Jax fns to TF fns.
native_lowering_platforms: In conjunction with `native_lowering`, specify
the platform(s) for which to lower the code. Must be a tuple of strings,
including a subset of: 'cpu', 'cuda', 'rocm', 'tpu'. The default
(empty tuple), specifies the JAX default backend on the machine where the
lowering is done.
enable_xla: Defaults to true. If false, jax2tf conversion only emits non-XLA
ops.
decode_outputs: Optional bool. If provided, determines whether to decode the
Expand Down Expand Up @@ -1417,6 +1425,7 @@ def save(
input_signature, preprocessor
),
native_lowering=native_lowering,
native_lowering_platforms=native_lowering_platforms,
)

logging.info('Loading parameters from checkpoint...')
Expand Down

0 comments on commit 2746023

Please sign in to comment.