From 23bd6c475b9444800b20c4ecb5e24aaa272b5067 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Wed, 2 Oct 2024 14:01:18 -0700 Subject: [PATCH] Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on. PiperOrigin-RevId: 681582933 --- haiku/_src/dot.py | 63 +++++++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index 8c80e6357..d525a2d6b 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -152,9 +152,9 @@ def method_hook(mod: module.Module, method_name: str): graph_stack.peek().subgraphs.append(subg.evolve(title=title)) with graph_stack(graph), \ - module.hook_methods(method_hook), \ - jax.core.new_main(DotTrace) as main: - out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat) + module.hook_methods(method_hook): + tag = jax.core.TraceTag() + out_flat = _interpret_subtrace(flat_fun, tag).call_wrapped(*args_flat) out = jax.tree.unflatten(out_tree(), out_flat) return graph, args, out @@ -163,20 +163,20 @@ def method_hook(mod: module.Module, method_name: str): @lu.transformation -def _interpret_subtrace(main, *in_vals): - trace = DotTrace(main, jax.core.cur_sublevel()) - in_tracers = [DotTracer(trace, val) for val in in_vals] - outs = yield in_tracers, {} - out_tracers = map(trace.full_raise, outs) - out_vals = [t.val for t in out_tracers] - yield out_vals +def _interpret_subtrace(tag, *in_vals): + with jax.core.take_current_trace() as parent_trace: + trace = DotTrace(parent_trace, tag) + with jax.core.set_current_trace(trace): + in_tracers = [DotTracer(trace, val) for val in in_vals] + outs = yield in_tracers, {} + yield [trace.to_val(t) for t in outs] class DotTracer(jax.core.Tracer): """JAX tracer used in DotTrace.""" def __init__(self, trace, val): - super().__init__(trace) + self._trace = trace self.val = val @property @@ -190,30 +190,31 @@ def full_lower(self): class DotTrace(jax.core.Trace): """Traces a JAX function to dot.""" - def pure(self, val): - return DotTracer(self, val) + def __init__(self, parent_trace, tag): + self.parent_trace = parent_trace + self.tag = tag - def lift(self, val): - return DotTracer(self, val) - - def sublift(self, val): - return DotTracer(self, val.val) + def to_val(self, val): + if isinstance(val, DotTracer) and val._trace.tag is self.tag: # pylint:disable=protected-access + return val.val + else: + return val def process_primitive(self, primitive, tracers, params): - val_out = primitive.bind(*[t.val for t in tracers], **params) + vals = [self.to_val(t) for t in tracers] + val_out = primitive.bind_with_trace(self.parent_trace, vals, params) if primitive is pjit.pjit_p: f = jax.core.jaxpr_as_fun(params['jaxpr']) f.__name__ = params['name'] fun = lu.wrap_init(f) return self.process_call(primitive, fun, tracers, params) - inputs = [t.val for t in tracers] outputs = list(jax.tree.leaves(val_out)) graph = graph_stack.peek() node = Node(id=outputs[0], title=str(primitive), outputs=outputs) graph.nodes.append(node) - graph.edges.extend([(i, outputs[0]) for i in inputs]) + graph.edges.extend([(i, outputs[0]) for i in vals]) return jax.tree.map(lambda v: DotTracer(self, v), val_out) @@ -221,16 +222,18 @@ def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results if (call_primitive in (pjit.pjit_p,) and params.get('inline', False)): - f = _interpret_subtrace(f, self.main) - vals_out = f.call_wrapped(*[t.val for t in tracers]) - return [DotTracer(self, v) for v in vals_out] + f = _interpret_subtrace(f, self.tag) + with jax.core.set_current_trace(self.parent_trace): + vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers]) + return [DotTracer(self, v) for v in vals_out] graph = Graph.create(title=f'{call_primitive} ({name_or_str(f.f)})') graph_stack.peek().subgraphs.append(graph) with graph_stack(graph): - f = _interpret_subtrace(f, self.main) - vals_out = f.call_wrapped(*[t.val for t in tracers]) - return [DotTracer(self, v) for v in vals_out] + f = _interpret_subtrace(f, self.tag) + with jax.core.set_current_trace(self.parent_trace): + vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers]) + return [DotTracer(self, v) for v in vals_out] process_map = process_call @@ -238,13 +241,15 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros): # Drop the custom differentiation rule. del primitive, jvp, symbolic_zeros # Unused. - return fun.call_wrapped(*tracers) + with jax.core.set_current_trace(self.parent_trace): + return fun.call_wrapped(*tracers) def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): # Drop the custom differentiation rule. del primitive, fwd, bwd, out_trees, symbolic_zeros # Unused. - return fun.call_wrapped(*tracers) + with jax.core.set_current_trace(self.parent_trace): + return fun.call_wrapped(*tracers) def _format_val(val):