Skip to content

Commit

Permalink
[Pallas] Add state discharge rule for pallas_call
Browse files Browse the repository at this point in the history
This enables us to avoid spurious copies in the cases outlined in [the async operations design note](https://jax.readthedocs.io/en/latest/pallas/async_note.html), but not in general: because JAX arrays have value semantics, JAX and/or XLA could still introduce copies. For a proper solution, we need to introduce some notion of buffer semantics to XLA/HLO and preserve it through the lowering of stateful JAX (maybe by avoiding state discharge altogether).

PiperOrigin-RevId: 676631102
  • Loading branch information
sharadmv authored and Google-ML-Automation committed Oct 1, 2024
1 parent afed9f4 commit 13e15ba
Show file tree
Hide file tree
Showing 7 changed files with 608 additions and 160 deletions.
6 changes: 4 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2745,8 +2745,10 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree,
def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree,
device_id_type: tpu_primitives.DeviceIdType):
del device_id_type
sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args)
sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in)
(_, _, ref, transforms, sem, sem_transforms, _, _, _) = tree_util.tree_unflatten(
tree, args)
(_, _, ref_aval, _, sem_aval, _, _, _, _) = tree_util.tree_unflatten(
tree, ctx.avals_in)
block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes)
ref_block_shape = block_shapes[2]
ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms)
Expand Down
110 changes: 69 additions & 41 deletions jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,18 +431,34 @@ def __post_init__(self):
def is_remote(self):
return self.src_sem is not None

def _get_args_and_tree(self, swap_src_and_dst: bool = False):
  """Flatten this descriptor's operands into `(flat_args, tree)`.

  The flattened pytree is always a 9-tuple laid out as
  `(ref_a, transforms_a, ref_b, transforms_b, sem_b, sem_b_transforms,
  sem_a, sem_a_transforms, device_id)`. By default `a` is the source and
  `b` is the destination; with `swap_src_and_dst=True` the two roles are
  exchanged, so the source ref/semaphore occupy the slots that normally
  hold the destination ones.

  Args:
    swap_src_and_dst: Whether to exchange the source and destination
      operands (refs, transforms and semaphores) in the layout.

  Returns:
    The result of `tree_util.tree_flatten` on the 9-tuple described above.
  """
  first = (self.src_ref, self.src_transforms)
  second = (self.dst_ref, self.dst_transforms)
  second_sem = (self.dst_sem, self.dst_sem_transforms)
  first_sem = (self.src_sem, self.src_sem_transforms)
  if swap_src_and_dst:
    # Exchange both the ref pair and the semaphore pair so that every slot
    # keeps its meaning relative to the (now swapped) roles.
    first, second = second, first
    first_sem, second_sem = second_sem, first_sem
  operands = (*first, *second, *second_sem, *first_sem, self.device_id)
  return tree_util.tree_flatten(operands)

def start(self):
flat_args, tree = tree_util.tree_flatten((
self.src_ref,
self.src_transforms,
self.dst_ref,
self.dst_transforms,
self.dst_sem,
self.dst_sem_transforms,
self.src_sem,
self.src_sem_transforms,
self.device_id,
))
flat_args, tree = self._get_args_and_tree()
dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type)

def wait(self):
Expand All @@ -451,27 +467,20 @@ def wait(self):
self.wait_recv()

def wait_recv(self):
wait_args, tree = tree_util.tree_flatten((
self.dst_sem,
self.dst_sem_transforms,
self.dst_ref,
self.dst_transforms,
))
flat_args, tree = self._get_args_and_tree()
dma_wait_p.bind(
*wait_args, tree=tree, device_id_type=self.device_id_type
*flat_args, tree=tree, device_id_type=self.device_id_type
)

def wait_send(self):
if not self.is_remote:
raise ValueError("Cannot `wait_send` on a local copy.")
wait_args, tree = tree_util.tree_flatten((
self.src_sem,
self.src_sem_transforms,
self.src_ref,
self.src_transforms,
))
# We swap src and dst since by default dma_wait_p waits on the dst_sem.
# As a cleanup, maybe we could modify the primitive to have a
# `wait_on_send` bool.
flat_args, tree = self._get_args_and_tree(swap_src_and_dst=True)
dma_wait_p.bind(
*wait_args, tree=tree, device_id_type=self.device_id_type
*flat_args, tree=tree, device_id_type=self.device_id_type
)


Expand Down Expand Up @@ -689,7 +698,17 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,
del settings
invars = eqn.invars
tree = eqn.params["tree"]
sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars)
(
_,
_,
ref,
transforms,
sem,
sem_transforms,
_,
_,
_,
) = tree_util.tree_unflatten(tree, invars)
return pp.concat([
pp.text("dma_wait"),
pp.text(" "),
Expand All @@ -702,29 +721,38 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn,

def dma_wait_discharge_rule(in_avals, out_avals,
*args, tree, device_id_type):
# TODO(b/370563115): perform ref update in dma_wait discharge rule instead of dma_start
del out_avals, device_id_type
(sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten(
tree, args
)
(
sem_aval,
sem_transforms_avals,
_, _, dst_ref, dst_ref_transforms, dst_sem, dst_sem_transforms, _, _, _ = (
tree_util.tree_unflatten(tree, args))
(_,
src_ref_transforms_avals,
_,
ref_transforms_avals,
dst_ref_transforms_avals,
dst_sem_aval,
dst_sem_transforms_avals,
src_sem_aval,
src_sem_transforms_avals,
device_id_aval,
) = tree_util.tree_unflatten(tree, in_avals)
num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(ref_transforms_avals))
updates = state_discharge.transform_array(ref, ref_transforms)
num_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals))
num_transforms = len(tree_util.tree_leaves(dst_ref_transforms_avals))
updates = state_discharge.transform_array(dst_ref, dst_ref_transforms)
copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE)
copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE)
sem_value = _transform_semaphore(sem, sem_transforms, sem_aval)
sem_value = _transform_semaphore(dst_sem, dst_sem_transforms, dst_sem_aval)
_, new_sem = state_discharge.transform_swap_array(
sem, sem_transforms, sem_value - copy_size
dst_sem, dst_sem_transforms, sem_value - copy_size
)
new_vals = (new_sem,) # sem
new_vals += (None,) * num_sem_transforms
new_vals = (None,) # src_ref
new_vals += (None,) * len(tree_util.tree_leaves(src_ref_transforms_avals))
new_vals += (None,) # ref
new_vals += (None,) * num_transforms
new_vals += (None,) * num_transforms # ref_transforms
new_vals += (new_sem,) # sem
new_vals += (None,) * num_sem_transforms
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_aval)) # src_sem
new_vals += (None,) * len(tree_util.tree_leaves(src_sem_transforms_avals))
new_vals += (None,) * len(tree_util.tree_leaves(device_id_aval)) # device_id
return new_vals, []
state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule)

Expand Down
151 changes: 143 additions & 8 deletions jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
Expand Down Expand Up @@ -208,6 +209,7 @@ def _pallas_call_impl_interpret(
print(discharged_jaxpr)
out = _initialize_output_vals(grid_mapping.block_mappings_output,
args, input_output_aliases)
# TODO(b/370563936): Fix correctness issue w/ io aliasing
scalars = args[grid_mapping.slice_index_ops]
block_args = args[len(scalars):]
# invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
Expand Down Expand Up @@ -936,7 +938,7 @@ def g():
with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()):
kernel_src_info: pallas_core.SrcInfoStr = "<Wrapped outer kernel>"

jaxpr = _trace_kernel_to_jaxpr(
jaxpr, consts = _trace_kernel_to_jaxpr(
when_wrapped_kernel,
kernel_src_info,
batched_grid_mapping,
Expand All @@ -945,6 +947,8 @@ def g():
tuple(() for _ in flat_kernel_avals),
interpret=interpret,
)
if consts:
raise NotImplementedError("consts not supported in pallas_call")

assert ragged_axis_length is not None
args = (ragged_axis_length, *args)
Expand Down Expand Up @@ -1160,7 +1164,7 @@ def _trace_kernel_to_jaxpr(
kernel_in_tree: tree_util.PyTreeDef,
kernel_in_transforms: tuple[tuple[pallas_core.Transform, ...], ...],
interpret: bool,
) -> jax_core.ClosedJaxpr:
) -> tuple[jax_core.ClosedJaxpr, tuple[jax.Array, ...]]:
if interpret:
kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval,
kernel_avals))
Expand All @@ -1174,17 +1178,18 @@ def _trace_kernel_to_jaxpr(
if consts:
consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
for c in consts]
raise ValueError(
f"The kernel function in the pallas_call {name_and_src_info} "
f"captures constants {consts_avals}. "
"You should pass them as inputs")
if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
raise ValueError(
f"The kernel function in the pallas_call {name_and_src_info} "
f"captures constants {consts_avals}. "
"You should pass them as inputs")

kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
f"The kernel function in the pallas_call {name_and_src_info} "
f"should return None. It returns a PyTree: {kernel_out_tree}")
return jaxpr
return jaxpr, tuple(consts)


_PALLAS_USE_MOSAIC_GPU = config.bool_flag(
Expand All @@ -1209,6 +1214,8 @@ def _unsupported_lowering_error(platform: str) -> Exception:
def _pallas_call_lowering(
ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
):
if params['jaxpr'].constvars:
raise ValueError('Cannot lower a pallas_call with constants.')
if interpret:
# If we are in interpret mode, we don't care what platform we are on.
impl = partial(_pallas_call_impl_interpret, **params)
Expand Down Expand Up @@ -1286,6 +1293,133 @@ def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue:
return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype)


def _get_memory_space_from_ref(ref_aval: state.AbstractRef) -> Any:
  """Return the memory space associated with a ref abstract value.

  A `pallas_core.AbstractMemoryRef` reports its own `memory_space`; any
  other ref aval is treated as living in `pallas_core.MemorySpace.ANY`.
  """
  has_memory_space = isinstance(ref_aval, pallas_core.AbstractMemoryRef)
  return (
      ref_aval.memory_space
      if has_memory_space
      else pallas_core.MemorySpace.ANY
  )


@state_discharge.register_discharge_rule(pallas_call_p)
def _pallas_call_state_discharge_rule(
    avals_in,
    avals_out,
    *args,
    jaxpr: jax_core.Jaxpr,
    input_output_aliases: tuple[tuple[int, int], ...],
    name_and_src_info: pallas_core.NameAndSrcInfo,
    grid_mapping: GridMapping,
    debug: bool,
    interpret: bool,
    compiler_params: Any,
    cost_estimate: CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
):
  """State discharge rule for a `pallas_call` whose kernel closes over refs.

  The kernel jaxpr's constvars must all be refs. Each such ref is turned into
  an extra input *and* an extra output of a new `pallas_call`, both placed in
  `MemorySpace.ANY`, with the new input aliased to the new output. The kernel
  is re-traced so that the extra input refs are fed to the original jaxpr as
  its consts.

  Args:
    avals_in: Abstract values of `args`; the first `len(jaxpr.constvars)`
      entries must be ref avals.
    avals_out: Unused.
    *args: Operands of the original `pallas_call`: the closed-over ref values
      first, then dynamic grid bounds, index operands and remaining operands.
    jaxpr: The kernel jaxpr, whose constvars are the closed-over refs.
    input_output_aliases: Aliasing pairs of the original call; they are
      shifted by the number of refs in the new call.
    name_and_src_info, debug, interpret, compiler_params, cost_estimate:
      Forwarded unchanged to the new `pallas_call`.
    grid_mapping: Grid mapping of the original call; extended with block
      mappings for the refs.
    out_avals: Output avals of the original call; the refs' inner avals are
      prepended in the new call.

  Returns:
    A pair `(updated_vals_in, outs)`: `updated_vals_in` holds the new value
    for each closed-over ref followed by `None` for every other input, and
    `outs` holds the outputs of the original call.
  """
  del avals_out
  assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
  num_refs = len(jaxpr.constvars)
  ref_avals, rest_in_avals = split_list(avals_in, [num_refs])
  assert all(isinstance(ref_aval, state.AbstractRef) for ref_aval in ref_avals)
  # Inside the new kernel every discharged ref is an ANY-memory-space ref.
  ref_avals = [
      pallas_core.AbstractMemoryRef(
          ref_aval.inner_aval, pallas_core.MemorySpace.ANY
      )
      for ref_aval in ref_avals
  ]
  ref_block_specs = [
      pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)
  ] * num_refs
  ref_block_mappings = [
      block_spec.to_block_mapping(
          origin="",  # TODO(sharadmv): enable origins for refs
          array_aval=ref_aval.inner_aval,
          index_map_avals=grid_mapping.index_map_avals,
          index_map_tree=grid_mapping.index_map_tree,
          grid=grid_mapping.grid,
          mapped_dims=grid_mapping.mapped_dims,
      )
      for ref_aval, block_spec in zip(ref_avals, ref_block_specs)
  ]
  in_block_mappings, out_block_mappings = split_list(
      grid_mapping.block_mappings, [grid_mapping.num_inputs]
  )
  # The refs appear twice: once as leading inputs and once as leading outputs.
  new_block_mappings = (
      *ref_block_mappings,
      *in_block_mappings,
      *ref_block_mappings,
      *out_block_mappings,
  )
  new_grid_mapping = grid_mapping.replace(
      block_mappings=new_block_mappings,
      num_inputs=grid_mapping.num_inputs + num_refs,
      num_outputs=grid_mapping.num_outputs + num_refs,
  )
  # Built as a tuple (hashable, and matching the annotated type of the
  # `input_output_aliases` parameter) rather than a mutable list.
  new_input_output_aliases = (
      # Each ref input (placed right after the index operands) aliases the
      # ref output with the same index.
      *((i + grid_mapping.num_index_operands, i) for i in range(num_refs)),
      # Pre-existing aliases shift by the number of refs inserted before them.
      *((i + num_refs, o + num_refs) for i, o in input_output_aliases),
  )
  ref_out_avals = [ref_aval.inner_aval for ref_aval in ref_avals]
  new_out_avals = (*ref_out_avals, *out_avals)
  ref_args, dynamic_grid_bounds, index_operands, rest_args = split_list(
      args,
      [
          num_refs,
          grid_mapping.num_dynamic_grid_bounds,
          grid_mapping.num_index_operands,
      ],
  )

  def _rewritten_body(*body_args):
    # Kernel wrapper: peel off the ref inputs/outputs and run the original
    # jaxpr with the ref inputs supplied as its consts.
    index_args, in_args, out_args, body_rest_args = split_list(
        body_args,
        [
            new_grid_mapping.num_index_operands,
            new_grid_mapping.num_inputs,
            new_grid_mapping.num_outputs,
        ],
    )
    ref_in_args, in_args = split_list(in_args, [num_refs])
    ref_out_args, out_args = split_list(out_args, [num_refs])
    # We don't care about ref_out_args because they are aliased to ref_in_args
    del ref_out_args
    jax_core.eval_jaxpr(
        jaxpr, ref_in_args, *index_args, *in_args, *out_args, *body_rest_args
    )
    return []

  jaxpr_index_avals, jaxpr_in_avals, jaxpr_out_avals, jaxpr_rest_avals = (
      split_list(
          [v.aval for v in jaxpr.invars],
          [
              grid_mapping.num_index_operands,
              grid_mapping.num_inputs,
              grid_mapping.num_outputs,
          ],
      )
  )
  new_jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(_rewritten_body),
      [
          *jaxpr_index_avals,
          *ref_avals,
          *jaxpr_in_avals,
          *ref_avals,
          *jaxpr_out_avals,
          *jaxpr_rest_avals,
      ],
  )
  out_flat = pallas_call_p.bind(
      *consts,
      *dynamic_grid_bounds,
      *index_operands,
      *ref_args,
      *rest_args,
      jaxpr=new_jaxpr,
      input_output_aliases=new_input_output_aliases,
      grid_mapping=new_grid_mapping,
      name_and_src_info=name_and_src_info,
      debug=debug,
      interpret=interpret,
      compiler_params=compiler_params,
      cost_estimate=cost_estimate,
      out_avals=new_out_avals,
  )
  refs_out, rest = split_list(out_flat, [num_refs])
  # Only the discharged refs receive new values; all other inputs map to None.
  updated_vals_in = [*refs_out, *([None] * len(rest_in_avals))]
  return updated_vals_in, rest


def pallas_call(
kernel: Callable[..., None],
out_shape: Any,
Expand Down Expand Up @@ -1440,7 +1574,7 @@ def wrapped(*args):
for x in flat_kernel_args
)
with pallas_core.interpret_mode_env(interpret):
jaxpr = _trace_kernel_to_jaxpr(
jaxpr, consts = _trace_kernel_to_jaxpr(
kernel, kernel_src_info, grid_mapping, tuple(flat_kernel_avals),
kernel_in_tree, kernel_arg_transforms, interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
Expand All @@ -1467,6 +1601,7 @@ def wrapped(*args):
index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
with pallas_core.interpret_mode_env(interpret):
out_flat = pallas_call_p.bind(
*consts,
*dynamic_grid_bounds,
*index_args,
*rest_args,
Expand Down
Loading

0 comments on commit 13e15ba

Please sign in to comment.