diff --git a/haiku/_src/base.py b/haiku/_src/base.py index c23579540..60d56c54c 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -70,13 +70,6 @@ class JaxTraceLevel(NamedTuple): @classmethod def current(cls): - if jax.__version_info__ <= (0, 4, 33): - trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack - top_type = trace_stack[0].trace_type - level = trace_stack[-1].level - sublevel = jax_core.cur_sublevel() - return JaxTraceLevel(opaque=(top_type, level, sublevel)) - ts = jax_core.get_opaque_trace_state(convention="haiku") return JaxTraceLevel(opaque=ts) diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index cefa1738a..62509850c 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -151,9 +151,9 @@ def method_hook(mod: module.Module, method_name: str): graph_stack.peek().subgraphs.append(subg.evolve(title=title)) with graph_stack(graph), \ - module.hook_methods(method_hook), \ - jax.core.new_main(DotTrace) as main: - out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat) + module.hook_methods(method_hook): + tag = jax.core.TraceTag() + out_flat = _interpret_subtrace(flat_fun, tag).call_wrapped(*args_flat) out = jax.tree_util.tree_unflatten(out_tree(), out_flat) return graph, args, out @@ -162,44 +162,41 @@ def method_hook(mod: module.Module, method_name: str): @lu.transformation -def _interpret_subtrace(main, *in_vals): - trace = DotTrace(main, jax.core.cur_sublevel()) - in_tracers = [DotTracer(trace, val) for val in in_vals] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals = [t.val for t in out_tracers] - yield out_vals +def _interpret_subtrace(tag, *in_vals): + trace = DotTrace(tag) + with jax.core.set_current_trace(trace): + in_tracers = [DotTracer(trace, val) for val in in_vals] + outs = yield in_tracers, {} + return [trace.to_val(t) for t in outs] class DotTracer(jax.core.Tracer): """JAX tracer used in DotTrace.""" def __init__(self, trace, val): - super().__init__(trace) + self._trace = trace self.val = val @property def aval(self): return jax.core.get_aval(self.val) - def full_lower(self): - return self - class DotTrace(jax.core.Trace): """Traces a JAX function to dot.""" - def pure(self, val): - return DotTracer(self, val) - - def lift(self, val): - return DotTracer(self, val) + def __init__(self, tag): + self.tag = tag - def sublift(self, val): - return DotTracer(self, val.val) + def to_val(self, val): + if isinstance(val, DotTracer) and val._trace.tag is self.tag: # pylint:disable=protected-access + return val.val + else: + return val def process_primitive(self, primitive, tracers, params): - val_out = primitive.bind(*[t.val for t in tracers], **params) + with jax.core.concrete_eval(): + val_out = primitive.bind(*[self.to_val(t) for t in tracers], **params) if primitive is pjit.pjit_p: f = jax.core.jaxpr_as_fun(params['jaxpr']) f.__name__ = params['name'] @@ -220,14 +217,14 @@ def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results if (call_primitive in (pjit.pjit_p,) and params.get('inline', False)): - f = _interpret_subtrace(f, self.main) + f = _interpret_subtrace(f, self.tag) vals_out = f.call_wrapped(*[t.val for t in tracers]) return [DotTracer(self, v) for v in vals_out] graph = Graph.create(title=f'{call_primitive} ({name_or_str(f.f)})') graph_stack.peek().subgraphs.append(graph) with graph_stack(graph): - f = _interpret_subtrace(f, self.main) + f = _interpret_subtrace(f, self.tag) vals_out = f.call_wrapped(*[t.val for t in tracers]) return [DotTracer(self, v) for v in vals_out]