Bug: Jax 0.4.13 support #279

Open
arnon-1 opened this issue Oct 23, 2024 · 4 comments

arnon-1 commented Oct 23, 2024

Hello,

While trying to use kfac_jax with JAX 0.4.13 (which is supported, if I am not mistaken), I had to fix some errors.
I installed commit a4531e9, which is fairly recent.

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/utils/types.py", line 27, in <module>
    DType = jax.typing.DTypeLike
AttributeError: module 'jax.typing' has no attribute 'DTypeLike'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/curvature_blocks/curvature_block.py", line 123, in parameters_shapes
    return tuple(jax.tree.map(
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tracer.py", line 790, in forward
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tag_graph_matcher.py", line 68, in eval_jaxpr_eqn
    user_context = jax_extend.source_info_util.user_context
AttributeError: module 'jax.extend' has no attribute 'source_info_util'

arnon-1 changed the title from "Jax 0.4.13 support" to "Bug: Jax 0.4.13 support" on Oct 28, 2024

james-martens (Collaborator)

I would recommend using the latest version of JAX. It's possible that changes made to support newer versions of JAX have broken compatibility with older versions, e.g., the switch from jax.tree_util.tree_map to jax.tree.map. The 'source_info_util' error is strange to me, since we have a version check for that. Maybe you have a corrupted installation of the JAX library, or maybe our version check on that line is wrong.

arnon-1 (Author) commented Nov 9, 2024

Thank you for your answer.
I would use the latest version if I could... I don't think the JAX library is the problem, since I received the same errors on multiple systems with different installation methods. Increasing the version check to 0.4.13 indeed seemed to fix that error (I wasn't sure how much information I was allowed to provide, since PRs have been explicitly disallowed).
I guess the version bump solves this issue.

james-martens (Collaborator)

> Increasing the version check to 0.4.13 indeed seemed to fix that error

Sorry, which error was this? The 'source_info_util' one?

How did you fix the other errors if you are still using 0.4.13?

arnon-1 (Author) commented Nov 9, 2024

For the first error: I copied the DTypeLike definition from a later JAX version.

+from jax._src.typing import SupportsDType
-DType = jax.typing.DTypeLike
+DType = Union[
+  str,            # like 'float32', 'int32'
+  type[Any],      # like np.float32, np.int32, float, int
+  np.dtype,       # like np.dtype('float32'), np.dtype('int32')
+  SupportsDType,  # like jnp.float32, jnp.int32
+]
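The patch above imports `SupportsDType` from the private `jax._src.typing` module. A minimal sketch of an alternative (not kfac_jax's code), assuming NumPy is available and that a local structural protocol matching any object with a `dtype` attribute is an acceptable stand-in for JAX's own `SupportsDType`, uses the public `jax.typing.DTypeLike` when the installed JAX exports it and only falls back to the copied definition otherwise:

```python
from typing import Any, Protocol, Union, runtime_checkable

import numpy as np


@runtime_checkable
class SupportsDType(Protocol):
    """Local stand-in (assumption) for objects exposing a .dtype attribute."""

    @property
    def dtype(self) -> np.dtype: ...


try:
    # Newer JAX releases export DTypeLike publicly.
    from jax.typing import DTypeLike as DType
except ImportError:
    # JAX 0.4.13 has jax.typing but no DTypeLike, so the import fails here;
    # this branch also covers environments without JAX installed.
    DType = Union[
        str,            # like 'float32', 'int32'
        type[Any],      # like np.float32, np.int32, float, int
        np.dtype,       # like np.dtype('float32'), np.dtype('int32')
        SupportsDType,  # like jnp.float32, jnp.int32
    ]
```

Keeping the fallback behind `except ImportError` means the private import is never needed, and newer JAX versions keep using the upstream alias unchanged.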

For the second error:
Most of the functionality in jax.tree used to be located in jax.tree_util (now a legacy API).
For example, I changed jax.tree.map to jax.tree_util.tree_map.
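Instead of editing every call site, the function can be chosen once at import time. A minimal sketch using feature detection; the helper name `resolve_tree_map` is hypothetical and not part of kfac_jax or JAX:

```python
from typing import Any, Callable


def resolve_tree_map(jax_module: Any) -> Callable[..., Any]:
    """Return jax.tree.map if the installed JAX has it, else jax.tree_util.tree_map."""
    # On older releases such as 0.4.13, accessing `jax.tree` raises
    # AttributeError (see the traceback above), so getattr() returns None.
    tree_module = getattr(jax_module, "tree", None)
    if tree_module is not None and hasattr(tree_module, "map"):
        return tree_module.map
    # Legacy location used by older releases such as 0.4.13.
    return jax_module.tree_util.tree_map
```

With JAX installed, `tree_map = resolve_tree_map(jax)` after `import jax` gives a `tree_map(f, tree)` that works on either side of the change. Passing the module in as an argument also makes the choice easy to test without a specific JAX version.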

For the third error:
I updated the previously mentioned version check; this seems to work:

-  if jax_version > (0, 4, 11):
+  if jax_version > (0, 4, 13):
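The check compares version tuples, so `(0, 4, 13) > (0, 4, 11)` is true and JAX 0.4.13 presumably took the branch that accesses `jax_extend.source_info_util` (the line that fails in the traceback above); with the threshold raised to `(0, 4, 13)`, 0.4.13 falls through to the other branch. A minimal sketch of that comparison, assuming the installed version string is parsed into an integer tuple; the `parse_version` helper is hypothetical, and kfac_jax may compute `jax_version` differently:

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Turn a version string such as '0.4.13' into (0, 4, 13)."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only leading digits so suffixes like '26rc1' don't break int().
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


jax_version = parse_version("0.4.13")

# Original threshold: 0.4.13 passes, so the jax_extend.source_info_util branch runs.
print(jax_version > (0, 4, 11))  # True
# Raised threshold: 0.4.13 no longer passes, so the other branch runs.
print(jax_version > (0, 4, 13))  # False
```

A possibly more robust option is to test for the attribute directly, e.g. `hasattr(jax_extend, "source_info_util")`, so the threshold does not have to match the exact release that introduced it.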
