Skip to content

Commit

Permalink
touched up Ijepa
Browse files Browse the repository at this point in the history
  • Loading branch information
HMUNACHI committed Mar 25, 2024
1 parent 8967b1d commit 4fd0092
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 71 deletions.
52 changes: 3 additions & 49 deletions nanodl/__src/models/ijepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ class PatchEmbedding(nn.Module):
patch_size (int): Size of square patches from image.
embed_dim (int): Dimension of the embeddings for the patches.
Methods:
setup(): Calculates `num_patches` and initialises Conv layer.
__call__(x: jnp.ndarray): Passes the image through the Conv layer, which extracts patches and projects them into the embedding space.
"""
image_size:int
patch_size:int
Expand All @@ -43,7 +40,6 @@ def setup(self):
def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
x = self.proj(x)
x = jnp.reshape(x, (x.shape[0], -1, self.embed_dim)) # (batch_size, num_patches, embed_dim)

return x


Expand All @@ -57,9 +53,6 @@ class PositionalEmbedding(nn.Module):
embed_dim (int): Patch embedding dimensions.
num_patches (int): Number of patches in an image which is dependent on the patch size.
Methods:
setup(): Initialises embedding layer
__call__(x: jnp.ndarray): Passes a tensor of positions through the positional embedding and adds the positional embeddings to the patch embeddings.
"""
embed_dim:int
num_patches:int
Expand All @@ -71,12 +64,9 @@ def setup(self):
)

def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
# assuming x of shape (batch_size, num_tokens, embed_dim)
positions = jnp.arange(x.shape[1])[jnp.newaxis, :].repeat(x.shape[0], axis=0)
embed = self.embedding(positions)

x = x + embed

return x


Expand All @@ -90,9 +80,6 @@ class MultiHeadedAttention(nn.Module):
embed_dim (int): Dimensionality of the input and output features.
num_heads (int): Number of attention heads.
Methods:
setup(): Initializes projection matrices for queries, keys, values, and the output projection.
__call__(x: jnp.ndarray): Processes the input tensor through the multi-head self-attention mechanism.
"""
embed_dim:int
num_heads:int
Expand All @@ -107,25 +94,17 @@ def setup(self):
def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
qkv = self.attn_proj(x)
query, key, value = jnp.array_split(qkv, 3, axis=-1)

query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, -1))
key = jnp.reshape(key, (key.shape[0], key.shape[1], self.num_heads, -1))
value = jnp.reshape(value, (value.shape[0], value.shape[1], self.num_heads, -1))

# permute to (batch_size, num_heads, seq_len, head_dim), where head_dim is the per-head slice of embed_dim

query = jnp.permute_dims(query, (0, 2, 1, 3))
key = jnp.permute_dims(key, (0, 2, 1, 3))
value = jnp.permute_dims(value, (0, 2, 1, 3))

attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / (self.embed_dim **.5)
attn_weights = nn.softmax(attn_weights, -1)

attn = jnp.matmul(attn_weights, value)
attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) # convert back to (batch_size, seq_len, embed_dim)

attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim))
attn = self.out_proj(attn)

return attn, attn_weights


Expand All @@ -141,9 +120,6 @@ class TransformerEncoderBlock(nn.Module):
feed_forward_dim (int): Dimension of the feed-forward network.
dropout_p (float): Dropout rate.
Methods:
setup(): Initializes the attention layer, feed forward layers and norm layers.
__call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block.
"""
embed_dim:int
num_heads:int
Expand Down Expand Up @@ -171,10 +147,8 @@ def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray:
x_, attn_weights = self.attn(self.norm1(x))
x = x + x_
x = self.dropout(x, deterministic=not training)

x = x + self.ff(self.norm2(x))
x = self.dropout(x, deterministic=not training)

return x, attn_weights


Expand All @@ -190,10 +164,7 @@ class TransformerEncoder(nn.Module):
embed_dim (int): Dimensionality of inputs and outputs.
num_layers (int): Number of encoder blocks.
feed_forward_dim (int): Dimension of the feed-forward network.
Methods:
setup(): Initializes the stack of transformer encoder blocks.
__call__(x: jnp.ndarray, training: bool): Passes the input tensor through each encoder block in turn and returns the output together with the attention weights collected from every block.
"""
dropout:float
num_heads:int
Expand All @@ -214,11 +185,9 @@ def setup(self):

def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray:
attn_maps = []

for layer in self.layers:
x, attn_weights = layer(x, training=training)
attn_maps.append(attn_weights)

return x, jnp.array(attn_maps)


Expand All @@ -240,11 +209,6 @@ class IJEPA(nn.Module):
predictor_num_heads (int): Number of transformer encoder heads for embedding predictor.
share_patch_embedding (bool): Whether or not to share the patch embeddings across the context and target encoders.
Methods:
setup(): Initializes the attention layer, feed forward layers and norm layers.
__call__(x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray): Applies the context and target masks to the image to get the context and target blocks, then obtains the predicted representations of the target blocks.
Example usage:
```py
import jax
Expand Down Expand Up @@ -390,11 +354,8 @@ def setup(self):
self.to_encoder_embed = nn.Dense(self.embed_dim)

def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray, training:bool=False) -> Tuple[List[Tuple[jnp.ndarray, jnp.ndarray]], List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]:
# context & target masks of shape (N, M, num_patches)

x_context = self.patch_embedding["context"](x)
x_context = self.positional_embedding(x_context)

x_target = self.patch_embedding["target"](x)
x_target = self.positional_embedding(x_target)

Expand All @@ -404,7 +365,6 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar
for m in range(context_mask.shape[1]):
context, context_attn_weights = self.context_encoder(x_context, training=training)
context = context * jnp.expand_dims(context_mask[:, m], -1) # (N, num_patches, E)

target, target_attn_weights = self.target_encoder(x_target, training=training)
target = target * jnp.expand_dims(target_mask[:, m], -1) # (N, num_patches, E)

Expand All @@ -415,14 +375,10 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar

predicted_embeddings = self.to_encoder_embed(predicted_embeddings)
predicted_embeddings = predicted_embeddings * jnp.expand_dims(target_mask[:, m], -1)

outputs.append((predicted_embeddings, target))
attn_weights.append((context_attn_weights, target_attn_weights, embed_attn_weights))

return (
outputs,
attn_weights
)
return (outputs, attn_weights)


class IJEPADataSampler:
Expand Down Expand Up @@ -543,10 +499,8 @@ def create_train_state(self,

rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
context_mask, target_mask = self.data_sampler()

context_mask = jnp.repeat(context_mask[jnp.newaxis], input_shape[0], axis=0)
target_mask = jnp.repeat(target_mask[jnp.newaxis], input_shape[0], axis=0)

params = self.model.init(rngs, jnp.ones(input_shape), context_mask, target_mask)['params']

if self.params_path is not None:
Expand Down
22 changes: 0 additions & 22 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,27 +418,5 @@ def test_ijepa_model_initialization_and_processing(self):
self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim))
self.assertEqual(outputs[0][0].shape, outputs[0][1].shape)


def test_ijepa_training(self):
x = jax.random.normal(
jax.random.PRNGKey(0),
(9, self.image_size, self.image_size, self.num_channels)
)

dataset = ArrayDataset(x)

dataloader = DataLoader(dataset,
batch_size=3,
shuffle=True,
drop_last=False)

data_sampler = IJEPADataSampler(
image_size=self.image_size,
patch_size=self.patch_size
)

trainer = IJEPADataParallelTrainer(self.model, x.shape, 'params.pkl', data_sampler=data_sampler)
trainer.train(dataloader, 10, dataloader)

if __name__ == '__main__':
unittest.main()

0 comments on commit 4fd0092

Please sign in to comment.