You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Traceback (most recent call last):
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/bin/ai-models", line 8, in <module>
sys.exit(main())
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 291, in main
_main()
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 264, in _main
model.run()
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 237, in run
output = self.model(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 114, in <lambda>
return lambda **kw: fn(**kw)[0]
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 255, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/api.py", line 325, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 485, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 962, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 348, in memoized_fun
ans = call(fun, *args)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 915, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2203, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2225, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 190, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/haiku/_src/transform.py", line 457, in apply_fn
out = f(*args, **kwargs)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 168, in run_forward
return predictor(
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/autoregressive.py", line 169, in __call__
target_template = targets_template.isel(time=[0])
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/dataset.py", line 2920, in isel
var = var.isel(var_indexers)
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 1135, in isel
return self[key]
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 811, in __getitem__
data = as_indexable(self._data)[indexer]
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/indexing.py", line 1336, in __getitem__
return array[key]
File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/xarray_jax.py", line 352, in wrapped_func
args, kwargs = tree.map_structure(unwrap, (args, kwargs))
AttributeError: module 'tree' has no attribute 'map_structure'
Running on Python 3.10 with the following package versions:
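Hello, it looks like this may relate to an issue where graphcast was pulling in the wrong `tree` library dependency: the unrelated PyPI package named `tree` was installed instead of DeepMind's `dm-tree`, which is also imported as `tree` and is the one that provides `tree.map_structure`. That should now be resolved, as the following PR has been merged: google-deepmind/graphcast#25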
Would you mind reinstalling graphcast from git and checking whether you still see the problem?
I'm trying to run GraphCast with the following command:
But I get the following error:
Running on Python 3.10 with the following package versions:
Any idea what the problem might be? It seems to originate in `graphcast/xarray_jax.py`, which is trying to access an attribute (`tree.map_structure`) that doesn't exist.
The text was updated successfully, but these errors were encountered: