chex.Dimensions API enhancement #231

Open
wbrenton opened this issue Jan 27, 2023 · 1 comment

@wbrenton

I would like to propose an API enhancement that allows chex.Dimensions to be used inside function annotations. If there is interest, I'd like to contribute. Example below:

```python
import chex

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
    chex.assert_shape(arr, dims['BTE'])
    # fn logic

### turns into ###

def foo(arr: chex.Array(dims['BTE'])):  # behind the scenes assert on function call
    # fn logic
    ...
```
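
Something close to the "behind the scenes assert on function call" can be approximated today without new chex API by storing the expected shape in a `typing.Annotated` hint and checking it in a decorator. This is a minimal sketch under that assumption: `check_shapes`, `BTE` and `foo` are hypothetical names, not part of chex, and the shape comparison is written by hand so the sketch only needs NumPy (in chex code it could call `chex.assert_shape` instead).

```python
import functools
import inspect
from typing import Annotated, get_type_hints

import numpy as np


def check_shapes(fn):
    """Hypothetical decorator: asserts shapes stored in Annotated metadata."""
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            # Plain hints (or missing ones) have no __metadata__ and are skipped.
            for expected in getattr(hints.get(name), "__metadata__", ()):
                if not isinstance(expected, tuple):
                    continue
                # In chex code this could be chex.assert_shape(value, expected).
                actual = tuple(np.shape(value))
                if actual != expected:
                    raise AssertionError(
                        f"{name}: expected shape {expected}, got {actual}")
        return fn(*args, **kwargs)

    return wrapper


# Stand-in for chex.Dimensions(B=2, T=3, E=4)['BTE'], which gives (2, 3, 4).
BTE = (2, 3, 4)


@check_shapes
def foo(arr: Annotated[np.ndarray, BTE]) -> float:
    # fn logic
    return float(arr.sum())
```

Calling `foo(np.ones((2, 3, 4)))` runs normally, while `foo(np.ones((2, 3)))` raises an `AssertionError` before the body executes. Because the shape lives in the signature rather than in an assert inside the body, it is also visible to anyone reading the function definition.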

This is particularly useful for dataclasses, e.g.:

```python
dims = chex.Dimensions(B=batch_size, T=rollout_len)

# asserts are run on instantiation
@chex.dataclass
class TimeStep:
    q_values: chex.Array(dims['BT'])
    discounts: chex.Array(dims['BT'])
    rewards: chex.Array(dims['BT'])
```
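
For dataclasses, the same `Annotated` idea can be wired into `__post_init__`, which the generated `__init__` calls, so the checks run on instantiation as described above. A minimal sketch assuming fixed sizes `B=4`, `T=8`; `validate_shapes` and `BT` are hypothetical names, and the standard-library `dataclasses` module is used so the example stays self-contained.

```python
import dataclasses
from typing import Annotated, get_type_hints

import numpy as np

B, T = 4, 8  # assumed batch size and rollout length
BT = (B, T)  # what chex.Dimensions(B=4, T=8)['BT'] would return


def validate_shapes(obj) -> None:
    """Hypothetical helper: checks every Annotated field against its shape."""
    hints = get_type_hints(type(obj), include_extras=True)
    for field in dataclasses.fields(obj):
        for expected in getattr(hints[field.name], "__metadata__", ()):
            if not isinstance(expected, tuple):
                continue
            actual = tuple(np.shape(getattr(obj, field.name)))
            if actual != expected:
                raise AssertionError(
                    f"{field.name}: expected shape {expected}, got {actual}")


@dataclasses.dataclass
class TimeStep:
    q_values: Annotated[np.ndarray, BT]
    discounts: Annotated[np.ndarray, BT]
    rewards: Annotated[np.ndarray, BT]

    def __post_init__(self):
        # Called by the generated __init__, i.e. on every instantiation.
        validate_shapes(self)
```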

Pros:

  • reduces the clutter that explicit asserts add to function bodies
  • lets users see the shape a function or class expects directly in the editor (e.g. the hover popup in VS Code)
    • example: with RLax, to know what shape each argument of a loss fn expects, you currently have to either read the source code or wait for the fn call to raise an assert

Cons:

  • increased API complexity
  • ...?
@KristianHolsheimer
Contributor

Thanks for your interest in chex!

This suggestion is very interesting. Many of us who work with arrays in Python on a daily basis have been eagerly awaiting PEP 646 (variadic generics), which was accepted and shipped in Python 3.11.

Once Python 3.11 becomes more mainstream, we will definitely consider incorporating shape annotations into chex. Perhaps we could augment or fork chex.Dimensions to return TypeVarTuples, along the lines of your suggestion.

For the time being, however, we will not implement such a change. In particular, mixing runtime checks with static type annotations is out of scope, at least for now.

P.S. If you're interested in runtime validation driven by type annotations, you might find the pydantic project useful: https://docs.pydantic.dev/
