From 987d438f0c2b739b5147ecd8cffce8b8c9da7dd7 Mon Sep 17 00:00:00 2001 From: = <=> Date: Sun, 24 Mar 2024 14:50:13 +0000 Subject: [PATCH] IJEPA implementation --- README.md | 103 ++--- nanodl/__init__.py | 9 + nanodl/__src/models/ijepa.py | 704 +++++++++++++++++++++++++++++++++++ tests/test_models.py | 89 +++++ 4 files changed, 837 insertions(+), 68 deletions(-) create mode 100644 nanodl/__src/models/ijepa.py diff --git a/README.md b/README.md index 5ddbaa6..23c32f2 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ We provide various example usages of the nanodl API. ```py import jax import jax.numpy as jnp +from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import GPT4, GPTDataParallelTrainer @@ -69,25 +70,15 @@ from nanodl import GPT4, GPTDataParallelTrainer batch_size = 8 max_length = 10 -# Replace with actual tokenised data +# Replace with actual list of tokenised texts data = jnp.ones((101, max_length), dtype=jnp.int32) # Shift to create next-token prediction dataset -dummy_inputs = data[:, :-1] -dummy_targets = data[:, 1:] +dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:] # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) - -# How to loop through dataloader -for batch in dataloader: - x, y = batch - print(x.shape, y.shape) - break +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # model parameters hyperparams = { @@ -103,25 +94,17 @@ hyperparams = { 'end_token': 50, } -# Initialize model +# Initialize inferred GPT4 model model = GPT4(**hyperparams) -rngs = jax.random.PRNGKey(0) -rngs, dropout_rng = jax.random.split(rngs) -params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] - -# Call as you would a Jax/Flax model -outputs = model.apply({'params': params}, - dummy_inputs, - rngs={'dropout': dropout_rng}) -print(outputs.shape) +params = model.init( + {'params': time_rng_key(), + 'dropout': time_rng_key() + }, + dummy_inputs)['params'] # Training on data trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') -trainer.train(train_loader=dataloader, - num_epochs=2, - val_loader=dataloader) - -print(trainer.evaluate(dataloader)) +trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader) # Generating from a start token start_tokens = jnp.array([[123, 456]]) @@ -130,9 +113,8 @@ start_tokens = jnp.array([[123, 456]]) params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': time_rng_key()}, method=model.generate) -print(outputs) ``` Vision example @@ -140,6 +122,7 @@ Vision example ```py import jax import jax.numpy as jnp +from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import DiffusionModel, DiffusionDataParallelTrainer @@ -147,16 +130,12 @@ image_size = 32 block_depth = 2 batch_size = 8 widths = [32, 64, 128] -key = jax.random.PRNGKey(0) input_shape = (101, image_size, image_size, 3) -images = jax.random.normal(key, input_shape) +images = jax.random.normal(time_rng_key(), input_shape) # Use your own images dataset = ArrayDataset(images) -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # Create diffusion model diffusion_model = 
DiffusionModel(image_size, widths, block_depth) @@ -165,13 +144,11 @@ pred_noises, pred_images = diffusion_model.apply(params, images) print(pred_noises.shape, pred_images.shape) # Training on your data -# Note: saved params are often different from training weights, use the saved params for generation trainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=images.shape, weights_filename='params.pkl', learning_rate=1e-4) trainer.train(dataloader, 10, dataloader) -print(trainer.evaluate(dataloader)) # Generate some samples params = trainer.load_params('params.pkl') @@ -179,7 +156,6 @@ generated_images = diffusion_model.apply({'params': params}, num_images=5, diffusion_steps=5, method=diffusion_model.generate) -print(generated_images.shape) ``` Audio example @@ -187,6 +163,7 @@ Audio example ```py import jax import jax.numpy as jnp +from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import Whisper, WhisperDataParallelTrainer @@ -200,13 +177,8 @@ vocab_size = 1000 dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32) dummy_inputs = jnp.ones((101, max_length, embed_dim)) -dataset = ArrayDataset(dummy_inputs, - dummy_targets) - -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) +dataset = ArrayDataset(dummy_inputs, dummy_targets) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # model parameters hyperparams = { @@ -224,10 +196,8 @@ hyperparams = { # Initialize model model = Whisper(**hyperparams) -rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} +rngs = {'params': time_rng_key(), 'dropout': time_rng_key()} params = model.init(rngs, dummy_inputs, dummy_targets)['params'] -outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs) -print(outputs.shape) # Training on your data trainer = WhisperDataParallelTrainer(model, @@ -239,13 +209,11 @@ trainer.train(dataloader, 2, dataloader) # Sample inference params = trainer.load_params('params.pkl') -# for more than one sample, use model.generate_batch +# for more than one sample, often use model.generate_batch transcripts = model.apply({'params': params}, dummy_inputs[:1], rngs=rngs, method=model.generate) - -print(transcripts) ``` Reward Model example for RLHF @@ -253,6 +221,7 @@ Reward Model example for RLHF ```py import jax import jax.numpy as jnp +from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import Mistral, RewardModel, RewardDataParallelTrainer @@ -266,10 +235,7 @@ dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) # Create dataset and dataloader dataset = ArrayDataset(dummy_chosen, dummy_rejected) -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # model parameters hyperparams = { @@ -298,13 +264,9 @@ trainer.train(dataloader, 5, dataloader) params = trainer.load_params('reward_model_weights.pkl') # Call as you would a regular Flax model -rngs = jax.random.PRNGKey(0) -rngs, dropout_rng = jax.random.split(rngs) rewards = reward_model.apply({'params': params}, dummy_chosen, - rngs={'dropout': dropout_rng}) - -print(rewards.shape) + rngs={'dropout': time_rng_key()}) ``` PCA example @@ -313,13 +275,21 @@ PCA example import jax from nanodl import PCA +# Use actual data data = jax.random.normal(jax.random.key(0), (1000, 10)) + +# Initialise and train PCA model pca = PCA(n_components=2) 
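+# fit() learns the top n_components principal directions from the (1000, 10) data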
 pca.fit(data)
+
+# Get PCA transforms
 transformed_data = pca.transform(data)
+
+# Get reverse transforms
 original_data = pca.inverse_transform(transformed_data)
+
+# Sample from the distribution
 X_sampled = pca.sample(n_samples=1000, key=None)
-print(X_sampled.shape, original_data.shape, transformed_data.shape)
 ```

 NanoDL provides random module which abstracts away Jax's intricacies.

 ```py
 import jax
 import nanodl

 # Jax-like usage
 jax_array = nanodl.uniform(shape=(3, 3))
 jax_array = nanodl.uniform(shape=(3, 3), seed=0)
 ```

-This is the first iteration of this project, roughness is expected, contributions are therefore highly encouraged! Follow the recommended steps:
+This is the first iteration of this project, so some roughness is expected, and contributions are therefore highly encouraged!

-- Raise the issue/discussion to get second opinions
-- Fork the repository
-- Create a branch
 - Make your changes without changing the design patterns
 - Write tests for your changes if necessary
 - Install locally with `pip install -e .`
 - Run tests with `python -m unittest discover -s tests`
-- Then submit a pull request from branch.
+- Then submit a pull request.

 Contributions can be made in various forms:

@@ -371,7 +338,7 @@ Following the success of Phi models, the long-term goal is to build and train nano
 while ensuring they compete with the original models in performance, with total number of parameters not exceeding 1B.
 Trained weights will be made available via this library.
 Any form of sponsorship, funding, grants or contribution will help with training resources.

-You can sponsor via the tag on the user profile, or reach out via ndubuakuhenry@gmail.com.
+You can sponsor via the user profile tag or reach out via ndubuakuhenry@gmail.com.

 ## Citing nanodl

diff --git a/nanodl/__init__.py b/nanodl/__init__.py
index 6977970..d6ac463 100644
--- a/nanodl/__init__.py
+++ b/nanodl/__init__.py
@@ -133,6 +133,12 @@
 RewardDataParallelTrainer
 )

+from nanodl.__src.models.ijepa import (
+    IJEPA,
+    IJEPADataParallelTrainer,
+    IJEPADataSampler
+)
+
 from nanodl.__src.layers.attention import (
     MultiQueryAttention,
     LocalMultiHeadAttention,
@@ -193,6 +199,9 @@
     "GaussianProcess",

     # Models
+    "IJEPA",
+    "IJEPADataParallelTrainer",
+    "IJEPADataSampler",
     "Gemma",
     "GemmaDataParallelTrainer",
     "GemmaDecoder",
diff --git a/nanodl/__src/models/ijepa.py b/nanodl/__src/models/ijepa.py
new file mode 100644
index 0000000..01f5acc
--- /dev/null
+++ b/nanodl/__src/models/ijepa.py
@@ -0,0 +1,704 @@
+import jax
+import flax
+import time
+import optax
+from einops import rearrange
+import jax.numpy as jnp
+import flax.linen as nn
+from flax.training import train_state
+from typing import List, Tuple, Any, Optional, Dict, Iterable
+
+
+class PatchEmbedding(nn.Module):
+    """
+    Implements patch embedding for vision transformers.
+
+    This module utilises a 2D conv layer to project patches from the image to a specified embedding dimension.
+
+    Attributes:
+        image_size (int): Size of the (square) input image.
+        patch_size (int): Size of the (square) patches extracted from the image.
+        embed_dim (int): Dimension of the embeddings for the patches.
+        num_channels (int): Number of image channels.
+
+    Methods:
+        setup(): Calculates `num_patches` and initialises the Conv layer.
+        __call__(x: jnp.ndarray): Passes the image through the Conv layer, which extracts patches and projects them into embedding space.
+    """
+    image_size:int
+    patch_size:int
+    embed_dim:int
+    num_channels:int
+
+    def setup(self):
+        self.num_patches = (self.image_size**2) // (self.patch_size**2)
+
+        # Use the sliding window of a conv layer to avoid "splitting" the image manually.
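+        # With kernel_size == strides == patch_size and VALID padding, this convolution is
+        # equivalent to cutting the image into non-overlapping patch_size x patch_size patches
+        # and projecting each patch with a shared linear layer, giving a
+        # (batch, image_size // patch_size, image_size // patch_size, embed_dim) feature grid.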
+        self.proj = nn.Conv(
+            features=self.embed_dim,
+            kernel_size=(self.patch_size, self.patch_size),
+            strides=self.patch_size,
+            padding="VALID",
+        )
+
+    def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
+        x = self.proj(x)
+        x = jnp.reshape(x, (x.shape[0], -1, self.embed_dim)) # (batch_size, num_patches, embed_dim)
+
+        return x
+
+
+class PositionalEmbedding(nn.Module):
+    """
+    Implements a learnt positional embedding.
+
+    This module adds a learnt vector to the patch embeddings to introduce a notion of temporal / spatial dependence.
+
+    Attributes:
+        embed_dim (int): Patch embedding dimensions.
+        num_patches (int): Number of patches in an image, which depends on the patch size.
+
+    Methods:
+        setup(): Initialises the embedding layer.
+        __call__(x: jnp.ndarray): Embeds a tensor of positions and adds the positional embeddings to the patch embeddings.
+    """
+    embed_dim:int
+    num_patches:int
+
+    def setup(self):
+        self.embedding = nn.Embed(
+            num_embeddings=self.num_patches,
+            features=self.embed_dim
+        )
+
+    def __call__(self, x:jnp.ndarray) -> jnp.ndarray:
+        # assuming x of shape (batch_size, num_tokens, embed_dim)
+        positions = jnp.arange(x.shape[1])[jnp.newaxis, :].repeat(x.shape[0], axis=0)
+        embed = self.embedding(positions)
+
+        x = x + embed
+
+        return x
+
+
+class MultiHeadedAttention(nn.Module):
+    """
+    Implements the multi-head attention mechanism as described in "Attention is All You Need" by Vaswani et al. 2017.
+
+    This module splits the input into multiple heads, applies scaled dot-product attention independently on each head, and then concatenates the results. It allows the model to jointly attend to information from different representation subspaces at different positions.
+
+    Attributes:
+        embed_dim (int): Dimensionality of the input and output features.
+        num_heads (int): Number of attention heads.
+
+    Methods:
+        setup(): Initializes projection matrices for queries, keys, values, and the output projection.
+        __call__(x: jnp.ndarray): Processes the input tensor through the multi-head self-attention mechanism.
+    """
+    embed_dim:int
+    num_heads:int
+
+    def setup(self):
+        self.attn_proj = nn.Dense(3 * self.embed_dim,
+                                  kernel_init=nn.initializers.xavier_uniform(),
+                                  bias_init=nn.initializers.zeros)
+        self.out_proj = nn.Dense(self.embed_dim,
+                                 kernel_init=nn.initializers.xavier_uniform(),
+                                 bias_init=nn.initializers.zeros)
+
+    def __call__(self, x:jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        qkv = self.attn_proj(x)
+        query, key, value = jnp.array_split(qkv, 3, axis=-1)
+
+        query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, -1))
+        key = jnp.reshape(key, (key.shape[0], key.shape[1], self.num_heads, -1))
+        value = jnp.reshape(value, (value.shape[0], value.shape[1], self.num_heads, -1))
+
+        # permute to (batch_size, num_heads, seq_len, head_dim)
+        query = jnp.permute_dims(query, (0, 2, 1, 3))
+        key = jnp.permute_dims(key, (0, 2, 1, 3))
+        value = jnp.permute_dims(value, (0, 2, 1, 3))
+
+        attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / (self.embed_dim **.5)
+        attn_weights = nn.softmax(attn_weights, -1)
+
+        attn = jnp.matmul(attn_weights, value)
+        # move the head axis back next to the token axis before flattening to (batch_size, seq_len, embed_dim)
+        attn = jnp.permute_dims(attn, (0, 2, 1, 3))
+        attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim))
+
+        attn = self.out_proj(attn)
+
+        return attn, attn_weights
+
+
+class TransformerEncoderBlock(nn.Module):
+    """
+    Implements a Transformer Encoder Block.
+
+    The transformer encoder block is composed of an attention block and a feedforward block.
+    Each sub-layer is wrapped in a residual connection, with layer normalisation applied to the sub-layer input (pre-norm) and dropout applied after each residual addition.
+
+    Attributes:
+        embed_dim (int): Dimensionality of the input and output features.
+        num_heads (int): Number of attention heads.
+        feed_forward_dim (int): Dimension of the feed-forward network.
+        dropout_p (float): Dropout rate.
+
+    Methods:
+        setup(): Initializes the attention layer, feed forward layers and norm layers.
+        __call__(x: jnp.ndarray): Processes the input tensor through the transformer encoder block.
+    """
+    embed_dim:int
+    num_heads:int
+    feed_forward_dim:int
+    dropout_p:float
+
+    def setup(self):
+        self.norm1 = nn.LayerNorm()
+        self.norm2 = nn.LayerNorm()
+
+        self.ff = nn.Sequential([
+            nn.Dense(self.feed_forward_dim),
+            lambda x: nn.gelu(x),
+            nn.Dense(self.embed_dim)
+        ])
+
+        self.attn = MultiHeadedAttention(
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,  # use the configured number of heads, not embed_dim
+        )
+
+        self.dropout = nn.Dropout(self.dropout_p)
+
+    def __call__(self, x:jnp.ndarray, training:bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        x_, attn_weights = self.attn(self.norm1(x))
+        x = x + x_
+        x = self.dropout(x, deterministic=not training)
+
+        x = x + self.ff(self.norm2(x))
+        x = self.dropout(x, deterministic=not training)
+
+        return x, attn_weights
+
+
+class TransformerEncoder(nn.Module):
+    """
+    Implements a Transformer Encoder.
+
+    The transformer encoder is a stack of `num_layers` transformer encoder blocks applied in sequence. It returns the final hidden states together with the attention maps from every block.
+
+    Attributes:
+        dropout (float): Dropout probability.
+        num_heads (int): Number of attention heads.
+        embed_dim (int): Dimensionality of inputs and outputs.
+        num_layers (int): Number of encoder blocks.
+        feed_forward_dim (int): Dimension of the feed-forward network.
+
+    Methods:
+        setup(): Initializes the stack of encoder blocks.
+        __call__(x: jnp.ndarray): Processes the input tensor through each encoder block.
+    """
+    dropout:float
+    num_heads:int
+    embed_dim:int
+    num_layers:int
+    feed_forward_dim:int
+
+    def setup(self):
+        self.layers = [
+            TransformerEncoderBlock(
+                embed_dim=self.embed_dim,
+                num_heads=self.num_heads,
+                feed_forward_dim=self.feed_forward_dim,
+                dropout_p=self.dropout
+            ) for _ in range(self.num_layers)
+        ]
+
+    def __call__(self, x:jnp.ndarray, training:bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        attn_maps = []
+
+        for layer in self.layers:
+            x, attn_weights = layer(x, training=training)
+            attn_maps.append(attn_weights)
+
+        return x, jnp.array(attn_maps)
+
+
+class IJEPA(nn.Module):
+    """
+    Implements the IJEPA architecture for non-generative self-supervised learning.
+    Ref: "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" by Mahmoud Assran et al.
+
+    This module consists of three ViTs / transformer encoders: a context encoder, a target encoder and an embedding predictor.
+    The embedding predictor is trained to predict the outputs of the target encoder given the outputs of the context encoder.
+
+    Attributes:
+        image_size (int): Image size; the image is assumed to be square.
+        num_channels (int): Number of image channels.
+        patch_size (int): Patch size for the ViTs; patches are assumed to be square.
+        embed_dim (int): Embedding dimensions for the ViTs.
+        num_heads (int): Number of transformer encoder heads for the context and target encoders.
+        dropout_p (float): Dropout probability.
+        predictor_num_heads (int): Number of transformer encoder heads for the embedding predictor.
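+        num_layers (int): Number of transformer encoder layers in the context and target encoders.
+        predictor_bottleneck (int): Embedding dimension used inside the embedding predictor.
+        predictor_num_layers (int): Number of transformer encoder layers in the embedding predictor.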
+        share_patch_embedding (bool): Whether or not to share the patch embeddings across the context and target encoders.
+
+    Methods:
+        setup(): Initializes the patch and positional embeddings, the context and target encoders, and the embedding predictor.
+        __call__(x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray): Applies the context and target masks to the image to get the context and target blocks, then obtains the predicted representations of the target blocks.
+
+    Example usage:
+        ```py
+        import jax
+        import jax.numpy as jnp
+        from nanodl import ArrayDataset, DataLoader
+        from nanodl import IJEPA, IJEPADataSampler, IJEPADataParallelTrainer
+
+        # Dummy data parameters
+        batch_size = 8
+        embed_dim = 256
+        patch_size = 16
+        image_size = 256
+        M = 4
+
+        num_patches = (image_size * image_size) // (patch_size * patch_size)
+
+        # Generate dummy data; masks hold one entry per patch for each of the M blocks
+        dummy_inputs = jnp.ones((batch_size, image_size, image_size, 3))
+        dummy_context_masks = jnp.ones((batch_size, M, num_patches))
+        dummy_target_masks = jnp.ones((batch_size, M, num_patches))
+
+        # Create dataset and dataloader
+        dataset = ArrayDataset(dummy_inputs)
+        dataloader = DataLoader(dataset,
+                                batch_size=batch_size,
+                                shuffle=True,
+                                drop_last=False)
+
+        data_sampler = IJEPADataSampler(
+            image_size=image_size,
+            patch_size=patch_size
+        )
+
+        # model parameters
+        hyperparams = {
+            "image_size": image_size,
+            "num_channels": 3,
+            "patch_size": patch_size,
+            "embed_dim": embed_dim,
+            "num_heads": 4,
+            "num_layers": 4,
+            "dropout_p": 0.1,
+            "predictor_num_heads": 4,
+            "predictor_bottleneck": 128,
+            "predictor_num_layers": 2
+        }
+
+        # Initialize model
+        model = IJEPA(**hyperparams)
+        rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
+        params = model.init(rngs, dummy_inputs, dummy_context_masks, dummy_target_masks)['params']
+
+        outputs, _ = model.apply(
+            {'params': params},
+            dummy_inputs,
+            dummy_context_masks,
+            dummy_target_masks,
+            rngs=rngs
+        )
+
+        # outputs is a list of M (predicted_embeddings, target) pairs
+        print(outputs[0][0].shape)
+
+        # Training on your data
+        trainer = IJEPADataParallelTrainer(model, dummy_inputs.shape, 'params.pkl', data_sampler=data_sampler)
+        trainer.train(dataloader, 10, dataloader)
+        ```
+    """
+
+    image_size: int
+    num_channels:int
+    patch_size:int
+    embed_dim:int
+    num_heads:int
+    num_layers:int
+    dropout_p:float
+    predictor_num_heads:int
+    predictor_bottleneck:int
+    predictor_num_layers:int
+    share_patch_embedding:bool = True
+
+    def setup(self):
+        self.num_patches = (self.image_size**2) // (self.patch_size**2)
+
+        self.feed_forward_dim = self.embed_dim*4
+        self.predictor_feed_forward_dim = self.predictor_bottleneck*4
+
+        create_patch_embedding = lambda: PatchEmbedding(
+            image_size=self.image_size,
+            patch_size=self.patch_size,
+            embed_dim=self.embed_dim,
+            num_channels=self.num_channels,
+        )
+
+        if self.share_patch_embedding: # The context and target encoders can share the same patch embedding
+            patch_embedding = create_patch_embedding()
+            self.patch_embedding = {
+                "context": patch_embedding,
+                "target": patch_embedding
+            }
+
+        else: # Or learn separate patch embeddings for each
+            self.patch_embedding = {
+                "context": create_patch_embedding(),
+                "target": create_patch_embedding()
+            }
+
+        # A single positional embedding module is reused for both paths.
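+        # The same positional embedding is added to both the context and target patch embeddings
+        # in __call__, so predicted and target representations are aligned per patch position.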
+        self.positional_embedding = PositionalEmbedding(
+            embed_dim=self.embed_dim,
+            num_patches=self.num_patches
+        )
+
+        self.context_encoder = TransformerEncoder(
+            dropout=self.dropout_p,
+            num_heads=self.num_heads,
+            embed_dim=self.embed_dim,
+            num_layers=self.num_layers,
+            feed_forward_dim=self.feed_forward_dim
+        )
+
+        self.target_encoder = TransformerEncoder(
+            dropout=self.dropout_p,
+            num_heads=self.num_heads,
+            embed_dim=self.embed_dim,
+            num_layers=self.num_layers,
+            feed_forward_dim=self.feed_forward_dim
+        )
+
+        self.embedding_predictor = TransformerEncoder(
+            dropout=self.dropout_p,
+            num_heads=self.predictor_num_heads,
+            embed_dim=self.predictor_bottleneck,
+            num_layers=self.predictor_num_layers,
+            feed_forward_dim=self.predictor_feed_forward_dim
+        )
+
+        self.to_predictor_embed = nn.Dense(self.predictor_bottleneck)
+        self.to_encoder_embed = nn.Dense(self.embed_dim)
+
+    def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray, training:bool=False) -> Tuple[List[Tuple[jnp.ndarray, jnp.ndarray]], List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]:
+        # context & target masks of shape (N, M, num_patches)
+
+        x_context = self.patch_embedding["context"](x)
+        x_context = self.positional_embedding(x_context)
+
+        x_target = self.patch_embedding["target"](x)
+        x_target = self.positional_embedding(x_target)
+
+        outputs = []
+        attn_weights = []
+
+        for m in range(context_mask.shape[1]):
+            context, context_attn_weights = self.context_encoder(x_context, training=training)
+            context = context * jnp.expand_dims(context_mask[:, m], -1) # (N, num_patches, E)
+
+            target, target_attn_weights = self.target_encoder(x_target, training=training)
+            target = target * jnp.expand_dims(target_mask[:, m], -1) # (N, num_patches, E)
+
+            predicted_embeddings, embed_attn_weights = self.embedding_predictor(
+                self.to_predictor_embed(context),
+                training=training
+            )
+
+            predicted_embeddings = self.to_encoder_embed(predicted_embeddings)
+            predicted_embeddings = predicted_embeddings * jnp.expand_dims(target_mask[:, m], -1)
+
+            outputs.append((predicted_embeddings, target))
+            attn_weights.append((context_attn_weights, target_attn_weights, embed_attn_weights))
+
+        return (
+            outputs,
+            attn_weights
+        )
+
+
+class IJEPADataSampler:
+    to_scale:Any = lambda self, x, a, b: (b-a) * x + a
+
+    def __init__(
+        self,
+        image_size:int = 256,
+        patch_size:int = 16,
+        M:int = 4,
+        context_scale_range:tuple = (.85, 1),
+        target_scale_range:tuple = (.15, .2),
+        target_aspect_ratio_range:tuple = (.75, 1.5),
+    ):
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.M = M
+        self.context_scale_range = context_scale_range
+        self.target_scale_range = target_scale_range
+        self.target_aspect_ratio_range = target_aspect_ratio_range
+
+        self.h = image_size // patch_size
+        self.w = image_size // patch_size
+
+        self.random_key = jax.random.PRNGKey(0)
+
+    def next_key(self) -> jnp.ndarray:
+        # Split the stored key so that successive random draws differ.
+        self.random_key, subkey = jax.random.split(self.random_key)
+        return subkey
+
+    def sample_target_block_scale(self) -> Tuple[int, int]:
+        scale = self.to_scale(
+            jax.random.uniform(self.next_key()),
+            self.target_scale_range[0],
+            self.target_scale_range[1]
+        )
+
+        context_scale = self.to_scale(
+            jax.random.uniform(self.next_key()),
+            self.context_scale_range[0],
+            self.context_scale_range[1]
+        )
+
+        aspect_ratio = self.to_scale(
+            jax.random.uniform(self.next_key()),
+            self.target_aspect_ratio_range[0],
+            self.target_aspect_ratio_range[1]
+        )
+
+        target_mask_scale = int(self.h * self.w * scale * context_scale)
+
+        target_h = int((target_mask_scale * aspect_ratio)**.5)
+        target_w = int((target_mask_scale / aspect_ratio)**.5)
+
+        # Keep the target block strictly smaller than the patch grid.
+        if target_h >= self.h:
+            target_h = self.h - 1
+        if target_w >= self.w:
+            target_w = self.w - 1
+
+        return target_h, target_w
+
+    def sample_context_target_blocks(self, h:int, w:int) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        context_mask = jnp.ones((self.M, self.image_size, self.image_size))
+        target_mask = jnp.zeros((self.M, self.image_size, self.image_size))
+
+        for m in range(self.M):
+            top = jax.random.randint(self.next_key(), (), 0, self.h - h)
+            left = jax.random.randint(self.next_key(), (), 0, self.w - w)
+
+            context_mask = context_mask.at[m,
+                top*self.patch_size: (top+h)*self.patch_size,
+                left*self.patch_size: (left+w)*self.patch_size].set(0)
+
+            target_mask = target_mask.at[m,
+                top*self.patch_size: (top+h)*self.patch_size,
+                left*self.patch_size: (left+w)*self.patch_size].set(1)
+
+        # Group pixels by patch (row-major patch grid) before reducing to one value per patch.
+        context_mask = rearrange(context_mask, "m (h p1) (w p2) -> m (h w) (p1 p2)", p1=self.patch_size, p2=self.patch_size)
+        target_mask = rearrange(target_mask, "m (h p1) (w p2) -> m (h w) (p1 p2)", p1=self.patch_size, p2=self.patch_size)
+
+        # A patch belongs to the context if any of its pixels lie outside the sampled block,
+        # and to the target if any of its pixels lie inside it.
+        context_mask = jnp.any(context_mask == 1, axis=-1)
+        target_mask = jnp.any(target_mask == 1, axis=-1)
+
+        return context_mask, target_mask
+
+    def __call__(self) -> Tuple[jnp.ndarray, jnp.ndarray]:
+        h, w = self.sample_target_block_scale()
+        context_mask, target_mask = self.sample_context_target_blocks(h, w)
+
+        return context_mask, target_mask
+
+
+class IJEPADataParallelTrainer:
+    def __init__(
+        self,
+        model: Any,
+        input_shape: Tuple[int, ...],
+        weights_filename:str,
+        data_sampler: IJEPADataSampler,
+        learning_rate:float = 1e-4,
+        params_path: Optional[str] = None) -> None:
+
+        self.model = model
+        self.params = None
+        self.params_path = params_path
+        self.num_parameters = None
+        self.best_val_loss = float("inf")
+        self.weights_filename = weights_filename
+        self.data_sampler = data_sampler
+        self.num_devices = jax.local_device_count()
+        self.train_step = jax.pmap(IJEPADataParallelTrainer.train_step, axis_name='devices')
+        self.evaluation_step = jax.pmap(IJEPADataParallelTrainer.evaluation_step, axis_name='devices')
+        self.state = self.create_train_state(learning_rate, input_shape)
+        print(f'Number of accelerators: {self.num_devices}')
+
+    def create_train_state(self,
+                           learning_rate: float,
+                           input_shape: Tuple[int, ...]) -> Any:
+
+        rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
+        context_mask, target_mask = self.data_sampler()
+
+        context_mask = jnp.repeat(context_mask[jnp.newaxis], input_shape[0], axis=0)
+        target_mask = jnp.repeat(target_mask[jnp.newaxis], input_shape[0], axis=0)
+
+        params = self.model.init(rngs, jnp.ones(input_shape), context_mask, target_mask)['params']
+
+        if self.params_path is not None:
+            params = self.load_params(self.params_path)
+
+        self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params))
+        print(f'Number of parameters: {self.num_parameters}')
+        state = train_state.TrainState.create(apply_fn=self.model.apply,
+                                              params=params,
+                                              tx=optax.adam(learning_rate))
+        return jax.device_put_replicated(state, jax.local_devices())
+
+    @staticmethod
+    def train_step(state: Any,
+                   images: jnp.ndarray,
+                   context_mask: jnp.ndarray,
+                   target_mask: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
+
+        def loss_fn(params):
+            outputs, _ = state.apply_fn(
+                {'params': params},
+                images,
+                context_mask=context_mask,
+                target_mask=target_mask,
+                training=True,
+                rngs={'dropout': jax.random.PRNGKey(int(time.time()))}
+            )
+
+            losses =
jnp.array([ + jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) for i in range(len(outputs)) + ]) + + return jnp.mean(losses) + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + def train(self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: + + for epoch in range(num_epochs): + total_loss = 0.0 + count = 0 + for images in train_loader: + images = images[0] if len(images) == 1 else images + + batch_size = images.shape[0] + batch_size_per_device = batch_size // self.num_devices + images = images.reshape((self.num_devices, batch_size_per_device, images.shape[1], images.shape[2], images.shape[3])) + + context_mask, target_mask = self.data_sampler() + + context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0) + target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0) + + context_mask = context_mask.reshape((self.num_devices, batch_size_per_device, context_mask.shape[1], context_mask.shape[2])) + target_mask = target_mask.reshape((self.num_devices, batch_size_per_device, target_mask.shape[1], target_mask.shape[2])) + + self.state, loss = self.train_step(state=self.state, + images=images, + context_mask=context_mask, + target_mask=target_mask + ) + + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + + if val_loader is not None: + val_loss = self.evaluate(val_loader) + print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + print("New best validation score achieved, saving model...") + self.save_params() + return + + @staticmethod + def evaluation_step(state: Any, + images: jnp.ndarray, + context_mask: jnp.ndarray, + target_mask: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + outputs, _ = state.apply_fn( + {'params': state.params}, + images, + context_mask=context_mask, + target_mask=target_mask, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))} + ) + + losses = jnp.array([ + jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) for i in range(len(outputs)) + ]) + + return jnp.mean(losses) + + + def evaluate(self, + test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + + total_loss = 0.0 + count = 0 + for images in test_loader: + images = images[0] if len(images) == 1 else images + + batch_size = images.shape[0] + batch_size_per_device = batch_size // self.num_devices + images = images.reshape((self.num_devices, batch_size_per_device, images.shape[1], images.shape[2], images.shape[3])) + + context_mask, target_mask = self.data_sampler() + + context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0) + target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0) + + context_mask = context_mask.reshape(( + self.num_devices, + batch_size_per_device, + context_mask.shape[1], + context_mask.shape[2] + )) + + target_mask = target_mask.reshape(( + self.num_devices, + batch_size_per_device, + target_mask.shape[1], + target_mask.shape[2] + )) + + loss = self.evaluation_step( + state=self.state, + images=images, + context_mask=context_mask, + target_mask=target_mask + ) + + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + return mean_loss + + def save_params(self) -> None: + self.params = flax.jax_utils.unreplicate(self.state.params) + with open(self.weights_filename, 'wb') as f: 
+ f.write(flax.serialization.to_bytes(self.params)) + + def load_params(self, filename: str): + with open(filename, 'rb') as f: + self.params = flax.serialization.from_bytes(self.params, f.read()) + return self.params \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index a9c66ea..e2b9ff3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -351,5 +351,94 @@ def test_gat_model_initialization_and_processing(self): self.assertEqual(output.shape, (self.num_nodes, self.nclass)) +class TestIJEPAModel(unittest.TestCase): + def setUp(self): + self.image_size = 128 + self.num_channels = 3 + self.patch_size = 16 + self.embed_dim = 32 + self.predictor_bottleneck = 16 + self.num_heads = 4 + self.predictor_num_heads = 4 + self.num_layers = 2 + self.predictor_num_layers = 1 + self.dropout_p = 0 + self.num_patches = (self.image_size ** 2) / (self.patch_size ** 2) + + + self.x = jax.random.normal( + jax.random.PRNGKey(0), + (1, self.image_size, self.image_size, self.num_channels) + ) + + self.model = IJEPA( + image_size=self.image_size, + num_channels=self.num_channels, + patch_size=self.patch_size, + embed_dim=self.embed_dim, + predictor_bottleneck=self.predictor_bottleneck, + num_heads=self.num_heads, + predictor_num_heads=self.predictor_num_heads, + num_layers=self.num_layers, + predictor_num_layers=self.predictor_num_layers, + dropout_p=self.dropout_p, + ) + + self.data_sampler = IJEPADataSampler( + image_size=self.image_size, + M=4, + patch_size=self.patch_size + ) + + def test_ijepa_data_sampling(self): + context_mask, target_mask = self.data_sampler() + self.assertEqual(context_mask.shape, (4, self.num_patches)) + self.assertEqual(target_mask.shape, (4, self.num_patches)) + + def test_ijepa_model_initialization_and_processing(self): + context_mask, target_mask = self.data_sampler() + + params = self.model.init( + jax.random.key(0), + self.x, + context_mask[jnp.newaxis], + target_mask[jnp.newaxis], + training=False + ) + + outputs , _ = self.model.apply( + params, + self.x, + context_mask[jnp.newaxis], + target_mask[jnp.newaxis], + training=False + ) + + self.assertEqual(len(outputs), 4) + self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim)) + self.assertEqual(outputs[0][0].shape, outputs[0][1].shape) + + + def test_ijepa_training(self): + x = jax.random.normal( + jax.random.PRNGKey(0), + (9, self.image_size, self.image_size, self.num_channels) + ) + + dataset = ArrayDataset(x) + + dataloader = DataLoader(dataset, + batch_size=3, + shuffle=True, + drop_last=False) + + data_sampler = IJEPADataSampler( + image_size=self.image_size, + patch_size=self.patch_size + ) + + trainer = IJEPADataParallelTrainer(self.model, x.shape, 'params.pkl', data_sampler=data_sampler) + trainer.train(dataloader, 10, dataloader) + if __name__ == '__main__': unittest.main() \ No newline at end of file