Hi, I am trying to use the Levanter image, but I get the following error: `ModuleNotFoundError: No module named 'jax.experimental.maps'`.
Was the module renamed? It worked fine yesterday.
Thanks!
The complete error log:
```
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/levanter/src/levanter/main/train_lm.py", line 220, in <module>
    levanter.config.main(main)()
  File "/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/levanter/src/levanter/main/train_lm.py", line 119, in main
    Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 597, in round_axis_for_partitioning
    size = physical_axis_size(axis, mapping)
  File "/opt/haliax/src/haliax/partitioning.py", line 566, in physical_axis_size
    mesh = _get_mesh()
  File "/opt/haliax/src/haliax/partitioning.py", line 606, in _get_mesh
    from jax.experimental.maps import thread_resources
ModuleNotFoundError: No module named 'jax.experimental.maps'
```
Problem:
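This problem comes from the Haliax package, not from Levanter itself. Haliax's `_get_mesh` (`haliax/partitioning.py`, line 606 in the traceback) still imports `thread_resources` from `jax.experimental.maps`, a module that recent JAX releases removed as part of a refactoring. `thread_resources` itself still exists inside JAX (it is defined internally in `jax._src.mesh`), so Haliax needs to update that import.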
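Since the image "worked fine yesterday", a likely cause is that a newer JAX got installed. A quick way to check is to run something like the following inside the container: it reports the installed JAX version and whether `jax.experimental.maps` can still be found. This is a minimal sketch; the helper names `module_available` and `report_jax_environment` are hypothetical.

```python
import importlib.metadata
import importlib.util


def module_available(name):
    """Return True if the module `name` can be located in this environment."""
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. `jax` itself) is not installed.
        return False


def report_jax_environment():
    """Report the installed JAX version and whether the old module still exists."""
    try:
        version = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        version = None  # JAX is not installed at all
    return {
        "jax_version": version,
        "jax.experimental.maps": module_available("jax.experimental.maps"),
    }


if __name__ == "__main__":
    print(report_jax_environment())
```

If the report shows `"jax.experimental.maps": False`, the installed JAX no longer ships that module and Haliax's import will fail as in the traceback.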
Solution:
You can fork the Haliax repo yourself, fix the import, and replace the Haliax link in the Levanter Docker image with a link to your fork.
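One way to "fix the import" in a fork is to try several locations in order, so the same code works on both older and newer JAX. The sketch below is illustrative, not Haliax's actual patch: `import_first` and `THREAD_RESOURCES_CANDIDATES` are hypothetical names, the `jax.interpreters.pxla` location is an assumption to verify against the JAX version in your image, and the patched `_get_mesh` assumes the mesh is read from `thread_resources.env.physical_mesh`.

```python
import importlib


def import_first(candidates):
    """Return the first attribute importable from a list of (module, attr) pairs.

    Tries each candidate in order and raises ImportError listing every failure
    if none of them works.
    """
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("no candidate import succeeded:\n" + "\n".join(errors))


# Possible homes of JAX's thread-local mesh state. The first entry is the import
# that fails in the traceback; the second is an ASSUMPTION about newer JAX
# releases and should be checked against the installed version.
THREAD_RESOURCES_CANDIDATES = [
    ("jax.experimental.maps", "thread_resources"),
    ("jax.interpreters.pxla", "thread_resources"),
]


def _get_mesh():
    # Hypothetical patched version of Haliax's _get_mesh; assumes the active
    # mesh is exposed as thread_resources.env.physical_mesh.
    thread_resources = import_first(THREAD_RESOURCES_CANDIDATES)
    return thread_resources.env.physical_mesh
```

Collecting every failure into one `ImportError` keeps the error message as informative as the original traceback if neither location exists in a future JAX release.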