Skip to content

Commit

Permalink
Update faq.md
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger authored Dec 26, 2023
1 parent eeb2ca6 commit 8d3175a
Showing 1 changed file with 9 additions and 30 deletions.
39 changes: 9 additions & 30 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,44 +142,23 @@ class Model(eqx.Module):

## I think my function is being recompiled each time it is run.

You can check each time your function is compiled by adding a print statement:
Use [`equinox.debug.assert_max_traces`][], for example
```python
@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def your_function(x, y, z):
print("Compiling!")
... # rest of your code here
...
```
JAX calls your function each time it needs to compile it. Afterwards, it never actually calls it -- indeed it doesn't use Python at all! Instead, it uses its compiled copy of your function, which only performs array operations. Thus, a print statement is an easy way to check each time JAX is compiling your function.

A function will be recompiled every time the shape or dtype of its arrays changes, or if any of its static (non-array) inputs change (as measured by `__eq__`).
will raise an error if it is compiled more than once, and tell you which argment caused the recompilation. (A function will be recompiled every time the shape or dtype of one of its array-valued inputs change, or if any of its static (non-array) inputs change (as measured by `__eq__`).)

If you want to check which argument is causing an undesired recompilation, then this can be done by checking each argument in turn:
As an alternative, a quick check for announcing each time your function is compiled can be achieved with a print statement:
```python
@eqx.filter_jit
def check_arg(arg, name):
print(f"Argument {name} is triggering a compile.")


for step, (x, y, z) in enumerate(...): # e.g. a training loop
print(f"Step is {step}")
check_arg(x, "x")
check_arg(y, "y")
check_arg(z, "z")
your_function(x, y, z)
```
for which you'll often see output like
```
Step is 0
Argument x is triggering a compile.
Argument y is triggering a compile.
Argument z is triggering a compile.
Step is 1
Argument y is triggering a compile.
Step is 2
Argument y is triggering a compile.
...
def your_function(x, y, z):
print("Compiling!")
... # rest of your code here
```
On the very first step, none of the arguments have been seen before, so they all trigger a compile. On later steps, just the problematic argument will trigger a recompilation of `check_arg` -- this will be the one that is triggering a recompilation of `your_function` as well!
JAX calls your function each time it needs to compile it. Afterwards, it never actually calls it; indeed it doesn't use Python at all! (Instead, it just follows the computation graph of array operations that it has already traced and compiled -- doing this is the point of JIT compilation.) Thus, a print statement is an easy way to check each time JAX is compiling your function.

## How does Equinox compare to...?

Expand Down

0 comments on commit 8d3175a

Please sign in to comment.