Commit
feat: add tests for embedding layer
SauravMaheshkar committed Sep 15, 2024
1 parent 8a64274 commit 3d9dc3f
Showing 4 changed files with 335 additions and 7 deletions.
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -14,14 +14,19 @@ dependencies = [
"jax>=0.4.31",
"mypy>=1.11.2",
"pillow>=10.4.0",
"pytest>=8.3.2",
"ruff>=0.6.3",
"transformers>=4.44.2",
]

[project.scripts]
jflux = "jflux.cli:app"

[tool.uv]
dev-dependencies = [
"pytest>=8.3.3",
"torch>=2.4.1",
]

[tool.uv.sources]
flux-jax = { workspace = true }

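This hunk appears to move pytest from the runtime dependencies into uv's dev-dependency group (bumped to >=8.3.3) and to add torch there, since PyTorch is only needed as a reference implementation inside the new tests. Assuming a standard uv setup, the suite can then be run with uv sync followed by uv run pytest from the repository root.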
60 changes: 60 additions & 0 deletions tests/test_layers.py
@@ -0,0 +1,60 @@
import chex
import jax.numpy as jnp
import torch
import torch.nn as nn
from einops import rearrange

from jflux.layers import Embed


def torch_rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [
                torch_rope(ids[..., i], self.axes_dim[i], self.theta)
                for i in range(n_axes)
            ],
            dim=-3,
        )

        return emb.unsqueeze(1)


class EmbedTestCase(chex.TestCase):
    def test_embed(self):
        # Initialize layers
        pytorch_embed_layer = EmbedND(512, 10000, [64, 64, 64, 64])
        jax_embed_layer = Embed(512, 10000, [64, 64, 64, 64])

        # Generate random inputs
        torch_ids = torch.randint(0, 10000, (1, 32, 4), dtype=torch.float64)
        jax_ids = jnp.asarray(torch_ids.numpy())

        # Forward pass
        jax_output = jax_embed_layer(jax_ids)
        pytorch_output = pytorch_embed_layer(torch_ids)

        # Assertions
        chex.assert_equal_shape([jax_output, jnp.asarray(pytorch_output.numpy())])
        chex.assert_trees_all_close(
            jax_output, jnp.asarray(pytorch_output.numpy()), rtol=1e-3, atol=1e-3
        )
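For reference, below is a minimal jax.numpy sketch of the rotary-embedding math that Embed is expected to reproduce; it simply mirrors torch_rope above. The actual implementation lives in jflux (jflux.layers.Embed, building on the rope function from jflux.math) and may differ in structure, so treat this as an illustration of what the parity test compares, not as project code.

# Assumed JAX counterpart of torch_rope (illustrative only).
import jax.numpy as jnp
from einops import rearrange

def jax_rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    # JAX defaults to float32, whereas the torch reference runs in float64.
    scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    omega = 1.0 / (theta**scale)
    out = jnp.einsum("...n,d->...nd", pos, omega)
    out = jnp.stack(
        [jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.astype(jnp.float32)

The float64 (PyTorch) versus float32 (JAX default) mismatch is presumably why the test compares outputs with rtol=1e-3 and atol=1e-3 rather than requiring exact equality.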
3 changes: 2 additions & 1 deletion tests/test_math.py
@@ -5,7 +5,6 @@
from jflux.math import attention, rope, apply_rope


@pytest.mark.xfail
class TestAttentionMechanism(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
@@ -29,6 +28,7 @@ def test_rope(self):
            rope_output.shape, expected_shape, "rope function output shape is incorrect"
        )

    @pytest.mark.xfail
    def test_apply_rope(self):
        pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
        pos = jnp.repeat(pos, self.batch_size, axis=0)
@@ -43,6 +43,7 @@
            xk_out.shape, self.k.shape, "apply_rope xk output shape is incorrect"
        )

    @pytest.mark.xfail
    def test_attention(self):
        pos = jnp.expand_dims(jnp.arange(self.seq_len), axis=0)
        pos = jnp.repeat(pos, self.batch_size, axis=0)
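This change appears to narrow the expected-failure scope: the -5,7 +5,6 hunk drops the class-level @pytest.mark.xfail, and the two later hunks mark only test_apply_rope and test_attention, so a regression in test_rope would once again fail the suite. A minimal illustration of the pattern (example code, not part of the jflux tests) is shown below; pytest's xfail mark can be applied directly to unittest.TestCase methods.

import unittest

import pytest


class Example(unittest.TestCase):
    def test_enforced(self):
        # Runs normally; a failure here fails the run.
        self.assertEqual(1 + 1, 2)

    @pytest.mark.xfail
    def test_known_gap(self):
        # Expected to fail; reported as xfail instead of a failure.
        self.assertEqual(1 + 1, 3)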
