diff --git a/tests/test_random.py b/tests/test_random.py index 7c4b115..d0940df 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from nanodl import ( time_rng_key, uniform, normal, bernoulli, categorical, randint, - permutation, gumbel, choice, binomial, bits, exponential, + permutation, gumbel, choice, bits, exponential, triangular, truncated_normal, poisson, geometric, gamma, chisquare )