Skip to content

Commit

Permalink
migrate from deprecated jax.linear_util to jax.extend.linear_util
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576905766
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Oct 26, 2023
1 parent 402a701 commit e82294e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ hk_py_library(
":module",
":utils",
# pip: jax
# pip: jax:extend
# pip: tree
],
)
Expand Down
7 changes: 4 additions & 3 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax
import jax.core
from jax.experimental import pjit
from jax.extend import linear_util as lu


# Import tree if available, but only throw error at runtime.
Expand Down Expand Up @@ -134,7 +135,7 @@ def to_graph(fun):
@functools.wraps(fun)
def wrapped_fun(*args):
"""See `fun`."""
f = jax.linear_util.wrap_init(fun)
f = lu.wrap_init(fun)
args_flat, in_tree = jax.tree_util.tree_flatten((args, {}))
flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree)
graph = Graph.create(title=name_or_str(fun))
Expand All @@ -160,7 +161,7 @@ def method_hook(mod: module.Module, method_name: str):
return wrapped_fun


@jax.linear_util.transformation
@lu.transformation
def _interpret_subtrace(main, *in_vals):
trace = DotTrace(main, jax.core.cur_sublevel())
in_tracers = [DotTracer(trace, val) for val in in_vals]
Expand Down Expand Up @@ -202,7 +203,7 @@ def process_primitive(self, primitive, tracers, params):
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = jax.linear_util.wrap_init(f)
fun = lu.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)

inputs = [t.val for t in tracers]
Expand Down

0 comments on commit e82294e

Please sign in to comment.