Thank you for your interest in contributing to CAX! We deeply appreciate you taking the time to help make CAX better. Whether you're contributing code, suggesting new features, opening an issue, improving documentation, or writing tutorials, all contributions are valuable and welcome.
We also appreciate it if you spread the word, for instance by starring the CAX GitHub repository or by referencing CAX in projects that use it.
We do all of our development using git, so basic familiarity with git is assumed.
Follow these steps to contribute code:
- Fork the CAX repository by clicking the Fork button on the repository page. This creates a copy of the CAX repository in your own account.
- Install Python >= 3.10 locally in order to run tests.
- Install your fork from source with `pip`. This lets you modify the code and test your changes immediately:
```bash
git clone https://github.com/YOUR_USERNAME/cax
cd cax
pip install -e ".[dev]"  # Installs CAX from the current directory in editable mode.
```
- Add the CAX repository as an upstream remote so you can use it to sync your changes:
```bash
git remote add upstream https://github.com/maxencefaldor/cax
```
- Create a branch to develop on:
```bash
git checkout -b name-of-change
```
Then implement your changes using your favorite editor.
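- Make sure your code is correctly formatted and passes CAX's lint checks by running the following from the top of the repository (`ruff format` rewrites files in place; `ruff check` reports lint errors):
```bash
ruff format .
ruff check .
```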
- Make sure the tests pass by running the following command from the top of the repository:
```bash
pytest tests/
```
- Once you are satisfied with your change, create a commit as follows (see guidance on how to write a good commit message):
```bash
git add file1.py file2.py ...
git commit -m "Your commit message"
```
Then sync your code with the main repository:
```bash
git fetch upstream
git rebase upstream/main
```
Finally, push your commit to your development branch; this creates a remote branch in your fork from which you can open a pull request:
```bash
git push --set-upstream origin name-of-change
```
- Open a pull request from your branch against the CAX repository and send it for review.
To report a bug, go to https://github.com/maxencefaldor/cax/issues and click "New issue".
Informative bug reports tend to include:
- A quick summary
- Steps to reproduce
  - Be specific!
  - Give sample code if you can.
- What you expected to happen
- What actually happens
- Additional notes
- Every CA in CAX inherits from `nnx.Module` and follows the perceive/update architecture:
  - The perceive module defines how cells observe their neighborhood (e.g., `ConvPerceive`).
  - The update module specifies how cells update their state based on these observations (e.g., `ResidualUpdate`, `NCAUpdate`, `LeniaUpdate`).
- Vectorization: Use vectorized array operations and JAX's `vmap` for computations applied to all cells.
- Hardware Acceleration: Leverage Flax components (e.g., `nnx.Conv`, `nnx.Linear`) rather than writing custom operations.
- Batching: Design your CA to handle batched inputs from the start.
- JIT Compilation: Ensure your CA is compatible with `jit` by avoiding Python control flow that depends on array values; use `jnp.where`, `jax.lax.cond`, or `jax.lax.scan` instead.
- Random Number Handling: Use `nnx.Rngs` to manage random state consistently.
Design your perceive and update modules so that they are readily compatible with the core `CA` class:
```python
perceive = MyPerceive(...)
update = MyUpdate(...)
ca = CA(perceive, update)
```
A CA step corresponds to:
```python
@nnx.jit
def step(self, state: State, input: Input | None = None) -> State:
    """Perform a single step of the CA.

    Args:
        state: Current state.
        input: Optional input.

    Returns:
        Updated state.

    """
    perception = self.perceive(state)
    state = self.update(state, perception, input)
    return state
```
and a full forward pass corresponds to a `jax.lax.scan` of this function over the desired number of steps.
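The scan-based rollout can be sketched with a toy, hypothetical step function (simple diffusion with periodic boundaries) standing in for a real CA step; CAX's own `CA` class may structure its rollout differently. `jax.lax.scan` carries the state from step to step and stacks each intermediate state into a trajectory, and `num_steps` is marked static because it sets the scan length.

```python
import jax
import jax.numpy as jnp


def step(state):
    """Hypothetical CA step: diffusion on a 2D grid with periodic boundaries."""
    neighbors = (
        jnp.roll(state, 1, axis=0)
        + jnp.roll(state, -1, axis=0)
        + jnp.roll(state, 1, axis=1)
        + jnp.roll(state, -1, axis=1)
    )
    # Keep half of each cell's value and spread the other half to its 4 neighbors.
    return 0.5 * state + 0.125 * neighbors


def rollout(state, num_steps):
    def body(carry, _):
        next_state = step(carry)
        return next_state, next_state  # (new carry, value stacked into the trajectory)

    final_state, trajectory = jax.lax.scan(body, state, xs=None, length=num_steps)
    return final_state, trajectory


rollout_jit = jax.jit(rollout, static_argnames="num_steps")

state0 = jnp.zeros((8, 8)).at[4, 4].set(1.0)
final_state, trajectory = rollout_jit(state0, num_steps=5)  # trajectory: (5, 8, 8)
```

Because the loop lives inside `scan`, JAX compiles the step once instead of unrolling a Python loop.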
For neural cellular automata training, a typical training loop follows this structure:
```python
@nnx.jit
def train_step(ca, optimizer, state, target, key):
    # Forward pass with value and gradient computation.
    # loss_fn(ca, state, target) is defined by you and returns (loss, state);
    # pass `key` to it if your CA uses stochastic updates.
    (loss, state), grad = nnx.value_and_grad(loss_fn, has_aux=True)(
        ca, state, target
    )

    # Update model parameters.
    # Depending on your Flax version, this may be optimizer.update(ca, grad).
    optimizer.update(grad)
    return loss, state


# Main training loop
for i in range(num_train_steps):
    key, subkey = jax.random.split(key)
    loss, state = train_step(ca, optimizer, state, target, subkey)
```
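- Avoid Python loops over cells; use vectorized operations instead.
- Don't mix NumPy and JAX arrays: inside `jit`-compiled code, NumPy functions cannot operate on traced JAX arrays.
- Keep track of random key usage for stochastic updates: split keys with `jax.random.split` and never reuse a key for independent draws.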
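The first and last points can be illustrated with neighbor counting on a binary grid, as in Conway's Game of Life. This is a sketch with hypothetical helper names: `neighbor_count_loop` visits every cell in Python (the pattern to avoid), while `neighbor_count_vectorized` shifts the whole grid once per neighbor offset with `jnp.roll` and gives the same result. Random grids are drawn from freshly split keys; reusing a key reproduces the same sample.

```python
import jax
import jax.numpy as jnp


def neighbor_count_loop(grid):
    """Slow reference: Python loops over every cell (avoid in CAX code)."""
    height, width = grid.shape
    counts = [[0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            total = 0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    if di == 0 and dj == 0:
                        continue
                    total += int(grid[(i + di) % height, (j + dj) % width])
            counts[i][j] = total
    return jnp.array(counts)


@jax.jit
def neighbor_count_vectorized(grid):
    """Same counts with array ops; the Python loop runs over 8 offsets, not cells."""
    offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1) if (di, dj) != (0, 0)]
    return sum(jnp.roll(grid, shift=(di, dj), axis=(0, 1)) for di, dj in offsets)


def random_grid(key, shape=(8, 8)):
    return jax.random.bernoulli(key, p=0.3, shape=shape).astype(jnp.int32)


key = jax.random.key(42)
key, k1, k2 = jax.random.split(key, 3)  # one fresh subkey per independent draw
grid_a = random_grid(k1)
grid_b = random_grid(k2)
```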
For an extensive list of common JAX gotchas, please read "JAX - The Sharp Bits" in the JAX documentation.
CAX uses the Flax NNX API; please read the Flax NNX documentation.
By submitting a contribution to CAX, you agree to license your work under the same MIT License that covers the project. This helps keep the codebase open and accessible to everyone. If you have any questions about the licensing terms, please don't hesitate to reach out to the maintainers.