Skip to content

Commit

Permalink
Merge pull request #24930 from hawkinsp:dedup
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697544119
  • Loading branch information
Google-ML-Automation committed Nov 18, 2024
2 parents ed250b8 + 626aea0 commit f7ae0f9
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,38 @@ def _emit_lowering_rule_as_fun(lowering_rule,
return func_op


class HashableLiteral:
"""Hashable wrapper of core.Literal, used for deduplicating IR constants."""

__slots__ = ["value"]

value: core.Literal

def __init__(self, value):
self.value = value

def __hash__(self):
h = self.value.hash
return id(self.value.val) if h is None else h

def __eq__(self, other):
if self is other:
return True
if type(self.value.val) != type(other.value.val):
return False
if self.value.aval != other.value.aval:
return False
if isinstance(self.value.val, (bool, int, float, complex)):
return self.value == other.value
if isinstance(self.value.val, (np.generic, np.ndarray)):
return np.array_equal(
self.value.val, other.value.val,
equal_nan=np.issubdtype(self.value.val.dtype, np.inexact))
# Since the use case is constant deduplication, it's safe to return
# False in unhandled cases.
return False


def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
name_stack: source_info_util.NameStack,
tokens: TokenSet,
Expand All @@ -1767,9 +1799,16 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert "gpu" not in ctx.platforms
cached_ir_consts: dict[HashableLiteral, IrValues] = {}

def read(v: core.Atom) -> IrValues:
if type(v) is core.Literal:
return ir_constant(xla.canonicalize_dtype(v.val))
h = HashableLiteral(v)
c = cached_ir_consts.get(h)
if c is None:
c = ir_constant(xla.canonicalize_dtype(v.val))
cached_ir_consts[h] = c
return c
else:
assert isinstance(v, core.Var)
return env[v]
Expand Down

0 comments on commit f7ae0f9

Please sign in to comment.