Skip to content

Commit

Permalink
Now displaying PRNGKey seed in failed tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 13, 2023
1 parent 5d8a06e commit 92c76e4
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
import dataclasses
import random
import typing

import jax.random as jr
import pytest
from jaxtyping import PRNGKeyArray


typing.TESTING = True # pyright: ignore


@pytest.fixture()
def getkey():
def _getkey():
# Not sure what the maximum actually is but this will do
return jr.PRNGKey(random.randint(0, 2**31 - 1))
# This offers reproducability -- the initial seed is printed in the repr so we can see
# it when a test fails.
# Note the `eq=False`, which means that `_GetKey `objects have `__eq__` and `__hash__`
# based on object identity.
@dataclasses.dataclass(eq=False)
class _GetKey:
seed: int
call: int
key: PRNGKeyArray

def __init__(self, seed: int):
self.seed = seed
self.call = 0
self.key = jr.PRNGKey(seed)

def __call__(self):
self.call += 1
return jr.fold_in(self.key, self.call)

return _getkey

@pytest.fixture
def getkey():
return _GetKey(random.randint(0, 2**31 - 1))

0 comments on commit 92c76e4

Please sign in to comment.