Skip to content

Commit

Permalink
Stackless yashful
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681582933
  • Loading branch information
dougalm authored and copybara-github committed Oct 3, 2024
1 parent af844e9 commit 89bef09
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
7 changes: 0 additions & 7 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,6 @@ class JaxTraceLevel(NamedTuple):

@classmethod
def current(cls):
if jax.__version_info__ <= (0, 4, 33):
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
top_type = trace_stack[0].trace_type
level = trace_stack[-1].level
sublevel = jax_core.cur_sublevel()
return JaxTraceLevel(opaque=(top_type, level, sublevel))

ts = jax_core.get_opaque_trace_state(convention="haiku")
return JaxTraceLevel(opaque=ts)

Expand Down
45 changes: 21 additions & 24 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ def method_hook(mod: module.Module, method_name: str):
graph_stack.peek().subgraphs.append(subg.evolve(title=title))

with graph_stack(graph), \
module.hook_methods(method_hook), \
jax.core.new_main(DotTrace) as main:
out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat)
module.hook_methods(method_hook):
tag = jax.core.TraceTag()
out_flat = _interpret_subtrace(flat_fun, tag).call_wrapped(*args_flat)
out = jax.tree_util.tree_unflatten(out_tree(), out_flat)

return graph, args, out
Expand All @@ -162,44 +162,41 @@ def method_hook(mod: module.Module, method_name: str):


@lu.transformation
def _interpret_subtrace(main, *in_vals):
trace = DotTrace(main, jax.core.cur_sublevel())
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals = [t.val for t in out_tracers]
yield out_vals
def _interpret_subtrace(tag, *in_vals):
trace = DotTrace(tag)
with jax.core.set_current_trace(trace):
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
return [trace.to_val(t) for t in outs]


class DotTracer(jax.core.Tracer):
"""JAX tracer used in DotTrace."""

def __init__(self, trace, val):
super().__init__(trace)
self._trace = trace
self.val = val

@property
def aval(self):
return jax.core.get_aval(self.val)

def full_lower(self):
return self


class DotTrace(jax.core.Trace):
"""Traces a JAX function to dot."""

def pure(self, val):
return DotTracer(self, val)

def lift(self, val):
return DotTracer(self, val)
def __init__(self, tag):
self.tag = tag

def sublift(self, val):
return DotTracer(self, val.val)
def to_val(self, val):
if isinstance(val, DotTracer) and val._trace.tag is self.tag: # pylint:disable=protected-access
return val.val
else:
return val

def process_primitive(self, primitive, tracers, params):
val_out = primitive.bind(*[t.val for t in tracers], **params)
with jax.core.concrete_eval():
val_out = primitive.bind(*[self.to_val(t) for t in tracers], **params)
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
Expand All @@ -220,14 +217,14 @@ def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if (call_primitive in (pjit.pjit_p,) and
params.get('inline', False)):
f = _interpret_subtrace(f, self.main)
f = _interpret_subtrace(f, self.tag)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]

graph = Graph.create(title=f'{call_primitive} ({name_or_str(f.f)})')
graph_stack.peek().subgraphs.append(graph)
with graph_stack(graph):
f = _interpret_subtrace(f, self.main)
f = _interpret_subtrace(f, self.tag)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]

Expand Down

0 comments on commit 89bef09

Please sign in to comment.