Skip to content

Commit

Permalink
fix(tests/sampling): use float32 for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar committed Oct 9, 2024
1 parent 12dd0da commit 55f378b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_get_noise(self):
num_samples=1,
height=height,
width=width,
dtype=jax.dtypes.bfloat16,
dtype=jax.numpy.float32,
seed=jax.random.PRNGKey(seed=42),
)
x_torch = torch_get_noise(
num_samples=1,
height=height,
width=width,
dtype=torch.bfloat16,
dtype=torch.float32,
seed=42,
device="cuda" if torch.cuda.is_available() else "cpu",
)
Expand Down

0 comments on commit 55f378b

Please sign in to comment.