Graphcast fails at xarray_jax #2

Open
EricLeer opened this issue Oct 21, 2023 · 2 comments

@EricLeer

Trying to run graphcast with the following command:

ai-models --input cds --date 20231001 --time 0000 graphcast

but I get the following error:

Traceback (most recent call last):
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/bin/ai-models", line 8, in <module>
    sys.exit(main())
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 291, in main
    _main()
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models/__main__.py", line 264, in _main
    model.run()
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 237, in run
    output = self.model(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 114, in <lambda>
    return lambda **kw: fn(**kw)[0]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 255, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 161, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 485, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 962, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 348, in memoized_fun
    ans = call(fun, *args)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/pjit.py", line 915, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2203, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2225, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/jax/_src/linear_util.py", line 190, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/haiku/_src/transform.py", line 457, in apply_fn
    out = f(*args, **kwargs)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/ai_models_graphcast/model.py", line 168, in run_forward
    return predictor(
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/autoregressive.py", line 169, in __call__
    target_template = targets_template.isel(time=[0])
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/dataset.py", line 2920, in isel
    var = var.isel(var_indexers)
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 1135, in isel
    return self[key]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/variable.py", line 811, in __getitem__
    data = as_indexable(self._data)[indexer]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/xarray/core/indexing.py", line 1336, in __getitem__
    return array[key]
  File "/Users/ericleer/.pyenv/versions/ec-graphcast-env/lib/python3.10/site-packages/graphcast/xarray_jax.py", line 352, in wrapped_func
    args, kwargs = tree.map_structure(unwrap, (args, kwargs))
AttributeError: module 'tree' has no attribute 'map_structure'

Running on Python 3.10 with the following package versions:

ai-models           0.2.14
ai-models-graphcast 0.0.4
jax                 0.4.19
jaxlib              0.4.19
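To collect the same version list in one step, a small sketch using the standard library's importlib.metadata could be used. Including the "dm-tree" and "tree" distributions is an assumption based on the maintainer's diagnosis further down the thread (a wrong tree dependency); they were not part of the original report.

```python
from importlib import metadata

# Distributions relevant to this report. "dm-tree" and "tree" are included on the
# assumption that the error comes from the wrong package providing the `tree` module.
PACKAGES = ["ai-models", "ai-models-graphcast", "graphcast", "jax", "jaxlib", "dm-tree", "tree"]


def installed_versions(names):
    """Return {distribution name: version string, or None if it is not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name:<20} {version or 'not installed'}")
```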

Any idea what the problem might be? It seems to originate in xarray_jax, which tries to call an attribute (tree.map_structure) that doesn't exist.
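The failing line in xarray_jax, `tree.map_structure(unwrap, (args, kwargs))`, presumably expects the `tree` module from DeepMind's dm-tree package, whose `map_structure` applies a function to every leaf of a nested structure and returns a structure of the same shape. To illustrate what that call is meant to do, here is a simplified pure-Python stand-in. It is an illustration only, not dm-tree's implementation, and is not intended as a patch for graphcast.

```python
from typing import Any, Callable


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Simplified stand-in for dm-tree's tree.map_structure (single structure only).

    Recurses into dicts, lists, tuples and namedtuples and applies `fn` to every
    leaf, returning a new structure of the same shape. dm-tree supports more
    (multiple structures, attrs classes, etc.); this sketch does not.
    """
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    if isinstance(structure, tuple) and hasattr(structure, "_fields"):
        # namedtuple: rebuild it with positional arguments
        return type(structure)(*(map_structure(fn, item) for item in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)  # leaf


# Mirrors the shape of the failing call: an (args, kwargs) pair.
args, kwargs = map_structure(lambda x: x * 10, ((1, [2, 3]), {"scale": 4}))
print(args, kwargs)  # (10, [20, 30]) {'scale': 40}
```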

@Dadoof

Dadoof commented Nov 1, 2023

In my build, I did this:
sudo pip3 install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I'm not sure how you installed JAX, but if you didn't install it this way, it might be worth a try. This is only a guess on my part.
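Whichever way JAX is installed, a quick way to confirm which version and backend it actually picked up is sketched below, using jax.default_backend() and jax.devices(). It only inspects JAX; it does not check the `tree` package named in the traceback.

```python
import importlib.util


def jax_backend_summary():
    """Return JAX's version, default backend and devices, or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return {
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_summary())
```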

@mjwillson
Collaborator

Hello, it looks like this may be related to an issue where graphcast was pulling in the wrong tree library dependency. That should now be resolved, as the following PR has been merged: google-deepmind/graphcast#25
Would you mind reinstalling graphcast from git and checking whether you still see the problem?
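After reinstalling, one way to check whether a different package is still providing the `tree` module is to ask Python which module it resolves and whether that module has a callable `map_structure`. This is a minimal diagnostic sketch; reading a missing `map_structure` as "the installed `tree` is not dm-tree's" is an assumption based on the wrong-dependency explanation in this thread.

```python
import importlib
import importlib.util


def describe_tree_module():
    """Report which top-level `tree` module Python resolves and whether it
    provides `map_structure`, the function xarray_jax calls."""
    spec = importlib.util.find_spec("tree")
    if spec is None:
        return {"origin": None, "has_map_structure": False}
    module = importlib.import_module("tree")
    return {
        "origin": spec.origin,  # file the `tree` module is loaded from
        "has_map_structure": callable(getattr(module, "map_structure", None)),
    }


if __name__ == "__main__":
    info = describe_tree_module()
    print(info)
    if not info["has_map_structure"]:
        print("`tree` is missing or lacks map_structure; dm-tree may not be the package installed.")
```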
