diff --git a/.gitignore b/.gitignore index 4949977..29a7f68 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,9 @@ __pycache__/ # Ignore configuration files with sensitive information config.ini secrets.yaml +params.pkl +base_params.pkl +reward_params.pkl # Ignore user-specific files /userdata/ diff --git a/README.md b/README.md index 0f9c3e7..c48514a 100644 --- a/README.md +++ b/README.md @@ -20,17 +20,26 @@ Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/) Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. -- An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. -- Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. +- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. +- Data-parallel distributed trainers includding RLHF so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective. - Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. - GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU. - Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models. +- True random number generators in Jax which do not need the verbose code. - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc. - Each model is contained in a single file with no external dependencies, so the source code can also be easily used. Feedback on any of our discussion, issue and pull request threads are welcomed! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at ndubuakuhenry@gmail.com. +## What's New in version 1.2.0.dev1 + +- Google's Gemma architecture. +- Reward model wrapper and data-parallel distributed reward trainer. +- True random number generators in Jax which do not need the verbose code (examples shown in next sections). + +There are experimental features (like MAMBA architecture and RLHF) in the repo which is not available via the package, pending tests. + ## Quick install You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md) @@ -52,7 +61,7 @@ pip install nanodl ## What does nanodl look like? -We provide various examples using the nanodl API: language, vision and audio, starting with an LLM. +We provide various example usages of the nanodl API. ```py import jax @@ -129,7 +138,9 @@ outputs = model.apply({'params': params}, method=model.generate) print(outputs) ``` + Vision example + ```py import jax import jax.numpy as jnp @@ -176,6 +187,7 @@ print(generated_images.shape) ``` Audio example + ```py import jax import jax.numpy as jnp @@ -200,12 +212,6 @@ dataloader = DataLoader(dataset, shuffle=True, drop_last=False) -# How to loop through dataloader -for batch in dataloader: - x, y = batch - print(x.shape, y.shape) - break - # model parameters hyperparams = { 'num_layers': 1, @@ -246,7 +252,67 @@ transcripts = model.apply({'params': params}, print(transcripts) ``` +Reward Model example for RLHF + +```py +import jax +import jax.numpy as jnp +from nanodl import ArrayDataset, DataLoader +from nanodl import Mistral, RewardModel, RewardDataParallelTrainer + +# Generate dummy data +batch_size = 8 +max_length = 10 + +# Replace with actual tokenised data +dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) +dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) + +# Create dataset and dataloader +dataset = ArrayDataset(dummy_chosen, dummy_rejected) +dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # model parameters +hyperparams = { + 'num_layers': 1, + 'hidden_dim': 256, + 'num_heads': 2, + 'feedforward_dim': 256, + 'dropout': 0.1, + 'vocab_size': 1000, + 'embed_dim': 256, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + 'num_groups': 2, + 'window_size': 5, + 'shift_size': 2 +} + +# Initialize reward model from Mistral +model = Mistral(**hyperparams) +reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) + +# Train the reward model +trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') +trainer.train(dataloader, 5, dataloader) +params = trainer.load_params('reward_model_weights.pkl') + +# Call as you would a regular Flax model +rngs = jax.random.PRNGKey(0) +rngs, dropout_rng = jax.random.split(rngs) +rewards = reward_model.apply({'params': params}, + dummy_chosen, + rngs={'dropout': dropout_rng}) + +print(rewards.shape) +``` + PCA example + ```py import jax from nanodl import PCA @@ -260,14 +326,27 @@ X_sampled = pca.sample(n_samples=1000, key=None) print(X_sampled.shape, original_data.shape, transformed_data.shape) ``` -# Contribution +NanoDL provides random module which abstracts away Jax's intricacies. +It generates truly random variables by using the current timestamp as seed. + +```py +# Jax example +key = random.PRNGKey(0) +jax_array = random.uniform(key, shape=(3, 3)) + +# NanoDL example +jax_array = nanodl.uniform(shape=(3, 3)) + +# For reproducability, use seed +jax_array = nanodl.uniform(shape=(3, 3), seed=0) +``` This is the first iteration of this project, roughness is expected, contributions are therefore highly encouraged! Follow the recommended steps: - Raise the issue/discussion to get second opinions - Fork the repository - Create a branch -- Make your changes without ruining the design patterns +- Make your changes without changing the design patterns - Write tests for your changes if necessary - Install locally with `pip install -e .` - Run tests with `python -m unittest discover -s tests` @@ -279,16 +358,11 @@ Contributions can be made in various forms: - Fixing bugs. - Implementing papers. - Writing high-coverage tests. -- OPtimizing existing codes. +- Optimizing existing codes. - Experimenting and submitting real-world examples to the examples section. - Reporting bugs. - Responding to reported issues. -Coming features include: -- Reinforcement Learning With Human Feedback (RLHF). -- Tokenizers. -- Code optimisations. - To follow up or share thoughts, follow [here](https://forms.gle/vwveb9SKdPYywHx9A) ## Sponsorships diff --git a/nanodl/__init__.py b/nanodl/__init__.py index 8eb87d7..bb55018 100644 --- a/nanodl/__init__.py +++ b/nanodl/__init__.py @@ -1,8 +1,9 @@ -__version__ = "1.0.0.dev1" +__version__ = "1.2.0.dev1" from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA from nanodl.__src.sklearn_gpu.clustering import KMeans, GaussianMixtureModel +from nanodl.__src.utils.random import * from nanodl.__src.sklearn_gpu.regression import ( LinearRegression, @@ -106,6 +107,7 @@ UNetResidualBlock ) + from nanodl.__src.models.transformer import ( Transformer, TransformerDataParallelTrainer, @@ -118,6 +120,26 @@ AddNorm ) +from nanodl.__src.models.gemma import ( + Gemma, + GemmaDataParallelTrainer, + GemmaDecoder, + GemmaDecoderBlock +) + +from nanodl.__src.models.reward import ( + RewardModel, + RewardDataParallelTrainer +) + +from nanodl.__src.layers.attention import ( + MultiQueryAttention, + LocalMultiHeadAttention, + HierarchicalMultiHeadAttention, + GatedMultiHeadAttention, + RotaryMultiHeadAttention +) + from nanodl.__src.utils.data import ( Dataset, ArrayDataset, @@ -170,6 +192,10 @@ "GaussianProcess", # Models + "Gemma", + "GemmaDataParallelTrainer", + "GemmaDecoder", + "GemmaDecoderBlock", "GAT", "GraphAttentionLayer", "T5", @@ -223,6 +249,8 @@ "WhisperDataParallelTrainer", "WhisperSpeechEncoder", "WhisperSpeechEncoderBlock", + "RewardModel", + "RewardDataParallelTrainer", "DiffusionModel", "DiffusionDataParallelTrainer", "UNet", @@ -267,7 +295,32 @@ "normalize_images", "random_crop", "random_flip_image", - "sobel_edge_detection" + "sobel_edge_detection", + "MultiQueryAttention", + "LocalMultiHeadAttention", + "HierarchicalMultiHeadAttention", + "GatedMultiHeadAttention", + "RotaryMultiHeadAttention", + + # Random + "time_rng_key", + "uniform", + "normal", + "bernoulli", + "categorical", + "randint", + "permutation", + "gumbel", + "choice", + "binomial", + "bits", + "exponential", + "triangular", + "truncated_normal", + "poisson", + "geometric", + "gamma", + "chisquare", ] import importlib @@ -289,11 +342,15 @@ def test_jax(jax): def test_optax(optax): optimizer = optax.sgd(learning_rate=0.1) +def test_einops(einops): + arr = einops.rearrange([1, 2, 3], 'a b c -> b a c') + def main(): try: flax = check_library_installed('flax') jax = check_library_installed('jax') optax = check_library_installed('optax') + einops = check_library_installed('einops') test_flax(flax) test_jax(jax) diff --git a/nanodl/__src/layers/__init__.py b/nanodl/__src/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nanodl/__src/layers/attention.py b/nanodl/__src/layers/attention.py new file mode 100644 index 0000000..156d1e3 --- /dev/null +++ b/nanodl/__src/layers/attention.py @@ -0,0 +1,495 @@ +import jax +import jax.numpy as jnp +import flax.linen as nn + +class MultiQueryAttention(nn.Module): + """Multi-Query Attention module. + + This module implements the Multi-Query Attention mechanism proposed in the + paper "Reformer: The Efficient Transformer" (https://arxiv.org/abs/1911.02150) + by Noah Shazeer. + + The Multi-Query Attention mechanism can be used for both self-attention and + cross-attention. It uses one set of query and key heads with multiple query + heads, reducing the number of projection parameters and making it more + efficient compared to the standard attention mechanism. + + Args: + hidden_dim (int): The output dimension of the attention module. + num_heads (int): The number of parallel attention heads. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + + def setup(self): + # To ensure dimensions are compatible + assert self.hidden_dim % self.num_heads <= 0 + + self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.key_projection = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.value_projection = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.output = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros) + + + def __call__(self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None) -> tuple: + + query = self.query_projection(inputs) + key = self.key_projection(context) + value = self.value_projection(context) + key = jnp.repeat(key, self.num_heads, axis=-1) + value = jnp.repeat(value, self.num_heads, axis=-1) + context_vectors, attention = self.attention_function(query,key, value, mask=mask) + outputs = self.output(context_vectors) + return outputs, attention + + def attention_function(self, query, key, value, mask=None): + input_length = value.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + # Split queries, keys, and values into heads + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights + + + +class RotaryPositionalEncoding(): + def __init__(self, dim_model: int): + super().__init__() + self.dim_model = dim_model + + inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + self.inv_freq = inv_freq + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=1): + seq_len = x.shape[seq_dimension] + + if seq_len != self._seq_len_cached: + self._seq_len_cached = seq_len + t = jnp.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = jnp.outer(t, self.inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + self._cos_cached = jnp.cos(emb)[None, None, :, :] + self._sin_cached = jnp.sin(emb)[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def rotate_half(self, x): + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-x2, x1), axis=-1) + + def apply_rotary_pos_emb(self, x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (self.rotate_half(x) * sin) + + def __call__(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + return ( + self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], + self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], + ) + + +class RotaryMultiHeadAttention(nn.Module): + """Rotary Multi-Head Attention module. + + This module implements the Rotary Multi-Head Attention mechanism, which + incorporates the Rotary Positional Encoding (RoPE) proposed in the paper + "RoBERTa: A Robustly Optimized BERT Pretraining Approach" + (https://arxiv.org/abs/1907.11692) by Yinhan Liu et al. + + The Rotary Multi-Head Attention mechanism is an extension of the standard + Multi-Head Attention mechanism, where the queries, keys, and values are + rotated by distinct frequency bands based on their relative positions. + This approach helps the attention mechanism better capture positional + information and improve performance on tasks involving long sequences. + + Args: + hidden_dim (int): The output dimension of the attention module. + num_heads (int): The number of parallel attention heads. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + + def setup(self): + # Because the Query is determined from a context, project separately + self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.rope = RotaryPositionalEncoding(self.hidden_dim*self.num_heads) + self.output = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros) + + + def __call__(self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None) -> tuple: + + query = self.query_projection(inputs) + key = self.key_projection(context) + value = self.value_projection(context) + query, key = self.rope(query, key) # Encode query and key with RoPE + context_vectors, attention = self.attention_function(query,key, value, mask=mask) + outputs = self.output(context_vectors) + return outputs, attention + + def attention_function(self, query, key, value, mask=None): + input_length = value.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + # Split queries, keys, and values into heads + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights + + +class GatedMultiHeadAttention(nn.Module): + """Gated Multi-Head Attention module. + + This module implements the Gated Multi-Head Attention mechanism proposed in + the paper "Gated Attention Networks for Learning on Large and Spatiotemporal + Graphs" (https://arxiv.org/abs/1912.00349) by Lingxue Zhu et al. + + The Gated Multi-Head Attention mechanism involves transforming the input by + weighting features based on their importance relative to a context. This + approach aims to capture the most relevant information and improve the + model's performance. + + Note: The discrete nature of the gate creates a differentiability challenge + during backpropagation. The paper suggests using the Gumbel-Softmax + approximation to mitigate this issue before training. + + Args: + hidden_dim (int): The output dimension of the attention module. + num_heads (int): The number of parallel attention heads. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + + def setup(self): + # Because the Query is determined from a context, project separately + self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.output = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.gate = nn.Dense(features=1) + + + def __call__(self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None) -> tuple: + + query = self.query_projection(inputs) + key = self.key_projection(context) + value = self.value_projection(context) + context_vectors, attention = self.attention_function(query,key,value,mask=mask) + outputs = self.output(context_vectors) + return outputs, attention + + def attention_function(self, query, key, value,mask=None): + input_length = value.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + # Split queries, keys, and values into heads + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + probabilities = jax.nn.sigmoid(self.gate(value_heads)) + booleans = jax.random.bernoulli(jax.random.PRNGKey(0), probabilities) + gate = jnp.where(booleans, 1.0, 0.0) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores * gate + + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights + + +class HierarchicalMultiHeadAttention(nn.Module): + """Hierarchical Multi-Head Attention module. + + This module implements the Hierarchical Attention Network proposed in the + paper "Hierarchical Attention Networks for Document Classification" + (https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf) + by Zichao Yang et al. + + The Hierarchical Attention Network consists of two main parts: a word + attention layer and a sentence attention layer. The word attention layer + learns to attend to the most important words in a sentence, while the + sentence attention layer learns to attend to the most important sentences + in a document. + + Note: This module can be computationally intensive. Many works have + proposed techniques to alleviate this issue. One such method involves + projecting the inputs to lower dimensions. A Jax implementation of PCA + for dimensionality reduction can be found in `core.ml.PCA()`. One could + project the inputs in each batch before passing them to this module. + + Args: + hidden_dim (int): The output dimension of the attention module. + num_heads (int): The number of parallel attention heads. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + + def setup(self): + # Because the Query is determined from a context, project separately + self.word_query_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.word_key_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.word_value_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.word_output = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.sentence_query_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.sentence_key_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.sentence_value_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.sentence_output = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros) + + + def __call__(self, + word_inputs: jnp.ndarray, + word_context: jnp.ndarray, + sentence_inputs: jnp.ndarray, + sentence_context: jnp.ndarray, + word_mask: jnp.ndarray = None, + sentence_mask: jnp.ndarray = None) -> tuple: + """Computes the hierarchical multi-head attention. + + Args: + word_inputs (jnp.ndarray): Input word representations. + word_context (jnp.ndarray): Context word representations. + sentence_inputs (jnp.ndarray): Input sentence representations. + sentence_context (jnp.ndarray): Context sentence representations. + word_mask (jnp.ndarray, optional): Mask for word attention. + sentence_mask (jnp.ndarray, optional): Mask for sentence attention. + + Returns: + tuple: A tuple containing: + - word_outputs (jnp.ndarray): Output word representations. + - sentence_outputs (jnp.ndarray): Output sentence representations. + - word_attention (jnp.ndarray): Word attention weights. + - sentence_attention (jnp.ndarray): Sentence attention weights. + """ + + word_queries = self.word_query_projection(word_inputs) + word_keys = self.word_key_projection(word_context) + word_values = self.word_value_projection(word_context) + word_attention, word_context_vectors = self.attention_function(word_queries, + word_keys, + word_values, + mask=word_mask) + + sentence_queries = self.sentence_query_projection(sentence_inputs) + sentence_keys = self.sentence_key_projection(sentence_context) + sentence_values = self.sentence_value_projection(sentence_context) + sentence_attention, sentence_context_vectors = self.attention_function(sentence_queries, + sentence_keys, + sentence_values, + mask=sentence_mask) + word_outputs = self.word_output(word_context_vectors) + sentence_outputs = self.sentence_output(sentence_context_vectors) + return word_outputs, sentence_outputs, word_attention, sentence_attention + + def attention_function(self, query, key, value, mask=None): + input_length = value.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + # Split queries, keys, and values into heads + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights + + + +class LocalMultiHeadAttention(nn.Module): + """Local Multi-Head Attention module. + + This module implements the Local Multi-Head Attention mechanism proposed in + the paper "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) + by Ashish Vaswani et al. + + The Local Multi-Head Attention mechanism involves transforming the input + by weighting features based on their importance relative to a local + context, which is determined by a sliding window of a fixed size. This + approach reduces the computational complexity of the attention mechanism + and allows for efficient processing of long sequences. + + Args: + hidden_dim (int): The output dimension of the attention module. + num_heads (int): The number of parallel attention heads. + window_size (int, optional): The size of the local attention window. + Default is 3. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + window_size : int = 3 + + def setup(self): + # Because the Query is determined from a context, project separately + self.query_projection = nn.Dense(self.hidden_dim*self.num_headsm, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.output = nn.Dense(self.hidden_dim*self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros) + + + def __call__(self, + inputs: jnp.ndarray, + context: jnp.ndarray) -> tuple: + + query = self.query_projection(inputs) + key = self.key_projection(context) + value = self.value_projection(context) + + local_mask = self.create_local_attention_mask(query.shape[1], key.shape[1]) + + context_vectors, attention = self.attention_function(query,key,value,mask=local_mask) + outputs = self.output(context_vectors) + return outputs, attention + + def create_local_attention_mask(self, input_length, context_length): + # Create a matrix with shape (input_length, context_length) + mask = jnp.ones((input_length, context_length)) + + # Fill the mask with zeros outside the local window for each position + for i in range(input_length): + start = max(0, i - self.window_size // 2) + end = min(context_length, start + self.window_size) + mask = mask.at[i, :start].set(0) + mask = mask.at[i, end:].set(0) + return mask + + def attention_function(self, query, key, value, mask=None): + input_length = value.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + # Split queries, keys, and values into heads + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights \ No newline at end of file diff --git a/nanodl/__src/layers/general.py b/nanodl/__src/layers/general.py new file mode 100644 index 0000000..94411be --- /dev/null +++ b/nanodl/__src/layers/general.py @@ -0,0 +1,34 @@ +import jax +import time +import jax.numpy as jnp +from jax import random + +def dropout(x: jnp.ndarray, + rate: float, + training: bool = False) -> jnp.ndarray: + """Apply dropout to input tensor. + + Args: + x (jnp.ndarray): Input tensor. + rate (float): Dropout rate, must be between 0 and 1. + training (bool, optional): Whether to apply dropout. + If False, returns input tensor unchanged. Defaults to False. + + Raises: + ValueError: If dropout rate is not in [0, 1). + + Returns: + jnp.ndarray: Tensor after applying dropout. + """ + if not training: + return x + + if not 0 <= rate < 1: + raise ValueError("Dropout rate must be in the range [0, 1).") + + if rate == 0: + return x + + keep_prob = 1 - rate + mask = jax.random.bernoulli(random.PRNGKey(int(time.time())), keep_prob, x.shape) + return jax.lax.select(mask, x / keep_prob, jnp.zeros_like(x)) \ No newline at end of file diff --git a/nanodl/__src/models/diffusion.py b/nanodl/__src/models/diffusion.py index c6325ca..63b3abc 100644 --- a/nanodl/__src/models/diffusion.py +++ b/nanodl/__src/models/diffusion.py @@ -484,14 +484,9 @@ def evaluate(self, return mean_loss def get_ema_weights(self, params, ema=0.999): - - new_params = {} - for key, value in params.items(): - if isinstance(value, dict): - new_params[key] = self.get_ema_weights(value, ema) - else: - new_params[key] = ema * value + (1 - ema) * value - return new_params + def func(x): + return x * ema + (1 - ema) * x + return jax.tree_util.tree_map(func, params) def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) diff --git a/nanodl/__src/models/gemma.py b/nanodl/__src/models/gemma.py new file mode 100644 index 0000000..125b5f2 --- /dev/null +++ b/nanodl/__src/models/gemma.py @@ -0,0 +1,651 @@ +import jax +import flax +import time +import optax +import jax.numpy as jnp +import flax.linen as nn +from flax.training import train_state +from typing import Tuple, Any, Optional, Iterable + +class RotaryPositionalEncoding(): + """ + Implements rotary positional encoding (RoPE) for transformers, enhancing their ability to capture sequence order. + + Rotary positional encoding applies a rotation to the embedding of each token based on its position in the sequence. This method helps preserve the relative positional information between tokens in a more effective manner compared to traditional positional encodings. + + Attributes: + dim_model (int): The dimensionality of the model embeddings. + + Methods: + _update_cos_sin_tables(x, seq_dimension): Updates cosine and sine tables based on the sequence length. + rotate_half(x): Rotates the last half of the dimensions of x by swapping them and changing signs to simulate a 90-degree rotation. + apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. + __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. + """ + def __init__(self, dim_model: int): + super().__init__() + self.dim_model = dim_model + inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + self.inv_freq = inv_freq + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=1): + seq_len = x.shape[seq_dimension] + + if seq_len != self._seq_len_cached: + self._seq_len_cached = seq_len + t = jnp.arange(seq_len, dtype=self.inv_freq.dtype) + freqs = jnp.outer(t, self.inv_freq) + emb = jnp.concatenate((freqs, freqs), axis=-1) + self._cos_cached = jnp.cos(emb)[None, None, :, :] + self._sin_cached = jnp.sin(emb)[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def rotate_half(self, x): + x1, x2 = jnp.split(x, 2, axis=-1) + return jnp.concatenate((-x2, x1), axis=-1) + + def apply_rotary_pos_emb(self, x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (self.rotate_half(x) * sin) + + def __call__(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + return ( + self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], + self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], + ) + + +class GroupedRotaryMultiHeadAttention(nn.Module): + """ + Implements multi-head self-attention with grouped rotary positional embeddings. + + This module extends the concept of multi-head attention by applying rotary positional embeddings to groups of attention heads. This approach allows for a more nuanced representation of positional information and potentially improves the model's understanding of sequence order and context. + + Attributes: + hidden_dim (int): Dimensionality of the input and output features. + num_heads (int): Number of attention heads. + num_groups (int): Number of groups to split the heads into for applying rotary positional embeddings separately. + + Methods: + setup(): Initializes the projections for query, key, value, and output, along with the rotary positional encoder. + __call__(inputs, context, mask): Processes the input and context tensors through the grouped rotary multi-head attention mechanism. + process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention. + attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors. + """ + hidden_dim : int # Output dimension + num_heads : int # Number of parallel heads + num_groups : int # Number of groups to split the heads into + + def setup(self): + self.query_projection = nn.Dense(self.hidden_dim // self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.value_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros + ) + self.rope = RotaryPositionalEncoding(self.hidden_dim // self.num_groups) + self.output = nn.Dense(self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros) + + def __call__(self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None) -> tuple: + + query = self.query_projection(inputs) + key = self.key_projection(context) + value = self.value_projection(context) + + # Break query into groups and transpose to (num_groups, batch_size, seq_len, dims) + # This will allow vmapping over the groups for parallelization + grouped_query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_groups, -1)) + grouped_query = jnp.repeat(grouped_query, self.num_heads, axis=-1) + grouped_query = jnp.transpose(grouped_query, (2, 0, 1, 3)) + + # Repeat the key and values + key = jnp.repeat(key, self.num_heads, axis=-1) + value = jnp.repeat(value, self.num_heads, axis=-1) + vectorized_process_group = jax.vmap(self.process_group, in_axes=(0, None, None, None)) + results = vectorized_process_group(grouped_query, key, value, mask) + + # Merge the groups back together + context_vectors = jnp.concatenate(results[0], axis=-1) + return self.output(context_vectors), results[1] + + def process_group(self, query, key, value, mask): + query, key = self.rope(query, key) + return self.attention_function(query, key, value, mask=mask) + + def attention_function(self, query, key, value, mask=None): + input_length = query.shape[1] + context_length = key.shape[1] + head_dim = query.shape[-1] // self.num_heads + dim_key = key.shape[-1] + + query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) + key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) + value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + + attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + if mask is not None: + attention_scores = attention_scores * mask + + attention_weights = jax.nn.softmax(attention_scores, axis=-1) + attended_values = jnp.matmul(attention_weights, value_heads) + attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + return attended_values, attention_weights + + +class GemmaMLP(nn.Module): + """ + GemmaMLP is processes inputs through a sequence of dense layers + and applies a gating operation to enhance the representational capacity of the model. + + Attributes: + hidden_size (int): The size of the output dimension of the MLP, also the dimensionality + of the input features. + intermediate_size (int): The size of the intermediate layer, where the input is projected + to before the gating mechanism and the subsequent projection back + to the original dimensionality. + """ + hidden_size: int + intermediate_size: int + + def setup(self): + self.gate_proj = nn.Dense(self.intermediate_size) + self.up_proj = nn.Dense(self.intermediate_size) + self.down_proj = nn.Dense(self.hidden_size) + + def __call__(self, x): + gate = jax.nn.gelu(self.gate_proj(x)) + up = self.up_proj(x) + fuse = gate * up + outputs = self.down_proj(fuse) + return outputs + + +class GemmaDecoderBlock(nn.Module): + """ + Implements a decoder block for the Gemma model, incorporating grouped rotary positional embeddings. + + This block is designed to enhance the model's ability to understand and generate text by using grouped rotary positional embeddings for more nuanced positional encoding, alongside traditional transformer mechanisms like self-attention and feed-forward layers. + + Attributes: + hidden_dim (int): Dimensionality of the input and output features. + num_heads (int): Number of attention heads in the multi-head self-attention mechanism. + feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. + dropout (float): Dropout rate for regularization. + num_groups (int): Number of groups for the grouped rotary positional embeddings. + + Methods: + setup(): Initializes the components of the Gemma decoder block. + causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. + __call__(x, training): Processes the input tensor through the Gemma decoder block. + """ + hidden_dim: int + num_heads: int + feedforward_dim: int + dropout: float + num_groups: int + + def setup(self): + self.attention = GroupedRotaryMultiHeadAttention(hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups) + self.feed_forward = GemmaMLP(self.feedforward_dim, self.hidden_dim) + self.norm1 = nn.RMSNorm(self.dropout) + self.norm2 = nn.RMSNorm(self.dropout) + self.dropout1 = nn.Dropout(self.dropout) + self.dropout2 = nn.Dropout(self.dropout) + + def causal_mask(self, + batch_size: int, + destination_dim: int, + source_dim: int) -> jnp.ndarray: + + # Create index tensors for the source and destination dimensions + idx_source = jnp.arange(destination_dim)[:, None] + idx_destination = jnp.arange(source_dim) + mask = idx_source >= idx_destination - source_dim + destination_dim + mask = mask.astype(jnp.int32) + + # Expand dimensions to match the required output shape + mask = mask[None, None, :, :] + return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + + def __call__(self, + x: jnp.ndarray, + training: bool = False) -> tuple: + + mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) + + x = self.norm1(x) + attended_x, attention = self.attention(x, x, mask=mask) + x = self.dropout1(x, deterministic=not training) + x += attended_x + + x = self.norm2(x) + output = self.feed_forward(x) + x = self.dropout2(x, deterministic=not training) + x += output + + return x, jnp.array(attention) + + +class GemmaDecoder(nn.Module): + """ + Implements the decoder component of the LLaMA2 model. + + The decoder is composed of multiple LLaMA2DecoderBlocks, processing sequences of tokens to generate text. It includes an embedding layer to convert tokens into vectors and an output layer to predict the next token in the sequence. + + Attributes: + num_layers (int): Number of LLaMA2DecoderBlocks in the decoder. + hidden_dim (int): Dimensionality of the input and output features for the blocks. + num_heads (int): Number of attention heads in each block. + num_groups (int): Number of groups for the grouped rotary positional embeddings in each block. + feedforward_dim (int): Dimensionality of the inner layer of the feed-forward networks in the blocks. + dropout (float): Dropout rate used for regularization. + vocab_size (float): Size of the vocabulary. + embed_dim (float): Dimensionality of the token embeddings. + + Methods: + setup(): Initializes the components of the LLaMA2 decoder. + __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 decoder. + """ + num_layers: int + hidden_dim: int + num_heads: int + num_groups: int + feedforward_dim: int + dropout: float + vocab_size: float + embed_dim: float + + def setup(self): + self.embedding = nn.Embed(num_embeddings=self.vocab_size, + features=self.embed_dim) + + self.layers = [GemmaDecoderBlock(self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_groups) for _ in range(self.num_layers)] + + self.outputs = nn.Dense(self.vocab_size) + + + def __call__(self, + x: jnp.ndarray, + training: bool = False, + drop_last_layer: bool = False) -> tuple: + + attention_maps = [] + x = self.embedding(x) + for layer in self.layers: + x, attention = layer(x, training=training) + attention_maps.append(attention) + + if not drop_last_layer: + x = self.outputs(x) + + return x, jnp.array(attention_maps) + + + +class Gemma(nn.Module): + """ + Implements the Gemma model for text generation, featuring GQA + RMSNorm + RoPE. + + Attributes: + num_layers (int): Number of layers (blocks) in the LLaMA2 model. + num_heads (int): Number of attention heads in each block. + num_groups (int): Number of groups for the grouped rotary positional embeddings in each block. + hidden_dim (int): Dimensionality of the input and output features for the blocks. + feedforward_dim (int): Dimensionality of the inner layer of the feed-forward networks in the blocks. + dropout (float): Dropout rate used for regularization. + vocab_size (float): Size of the vocabulary. + embed_dim (float): Dimensionality of the token embeddings. + max_length (int): Maximum length of the generated sequences. + start_token (int): Token used to start the generation process. + end_token (int): Token that indicates the end of a generated sequence. + + Methods: + setup(): Initializes the LLaMA2 model including the decoder component. + __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 model. + generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. + generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. + + LlaMA is built upon the transformer architecture, incorporating enhancements inspired by recent advancements in the field of large language models. + These improvements are drawn from various sources, such as GPT-3, PaLM, and GPT-Neo. Notable modifications include the adoption of pre-normalization for enhanced training stability, + employing the RMSNorm normalization function. Additionally, the ReLU non-linearity is replaced with the SwiGLU activation function, which is a variant of the GLU activation function. + Absolute positional embeddings are replaced with rotary positional embeddings (RoPE), implemented at each layer of the network. For specific hyper-parameter details, refer to Table 2 in the document. + + Example usage: + ``` + import jax + import jax.numpy as jnp + from nanodl import ArrayDataset, DataLoader + from nanodl import Gemma, GemmaDataParallelTrainer + + # Generate dummy data + batch_size = 8 + max_length = 10 + + # Replace with actual tokenised data + data = jnp.ones((101, max_length), dtype=jnp.int32) + + # Shift to create next-token prediction dataset + dummy_inputs = data[:, :-1] + dummy_targets = data[:, 1:] + + # Create dataset and dataloader + dataset = ArrayDataset(dummy_inputs, dummy_targets) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # How to loop through dataloader + for batch in dataloader: + x, y = batch + print(x.shape, y.shape) + break + + # model parameters + hyperparams = { + 'num_layers': 1, + 'hidden_dim': 256, + 'num_heads': 2, + 'feedforward_dim': 256, + 'dropout': 0.1, + 'vocab_size': 1000, + 'embed_dim': 256, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + 'num_groups': 2, + } + + # Initialize model + model = Gemma(**hyperparams) + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] + + # Call as you would a Jax/Flax model + outputs = model.apply({'params': params}, + dummy_inputs, + rngs={'dropout': dropout_rng}) + print(outputs.shape) + + # Training on data + trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') + trainer.train(train_loader=dataloader, + num_epochs=2, + val_loader=dataloader) + + print(trainer.evaluate(dataloader)) + + # Generating from a start token + start_tokens = jnp.array([[123, 456]]) + + # Remember to load the trained parameters + params = trainer.load_params('params.pkl') + outputs = model.apply({'params': params}, + start_tokens, + rngs={'dropout': jax.random.PRNGKey(2)}, + method=model.generate) + print(outputs) + ``` + """ + num_layers: int + num_heads: int + num_groups: int + hidden_dim: int + feedforward_dim: int + dropout: float + vocab_size: float + embed_dim: float + max_length: int + start_token: int + end_token: int + + def setup(self): + + self.decoder = GemmaDecoder(self.num_layers, + self.hidden_dim, + self.num_heads, + self.num_groups, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim) + + def __call__(self, + x: jnp.ndarray, + training: bool = False, + drop_last_layer: bool = False) -> jnp.ndarray: + + + return self.decoder(x=x, + training=training, + drop_last_layer=drop_last_layer)[0] + + + def generate(self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False) -> Tuple[jnp.ndarray]: + + if x is not None: + assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" + + decoder_input = x if x is not None else jnp.array([[self.start_token]]) + output_sequence = [] + + # Autoregressive decoding loop + for _ in range(self.max_length): + decoder_output = self.decoder(decoder_input, training=False)[0] + last_token_logits = decoder_output[:, -1, :] + scaled_logits = last_token_logits / temperature + next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) + + if deterministic: + next_token = jnp.argmax(next_token_probabilities, axis=-1) + else: + next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + + next_token = next_token[0] + output_sequence.append(next_token.item()) + decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + + if next_token.item() == self.end_token: + break + + return jnp.array(output_sequence) + + + def generate_batch(self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False) -> jnp.ndarray: + + batch_size = x.shape[0] if x is not None else 1 + decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) + + for i in range(self.max_length): + decoder_output = self.decoder(decoder_input, training=False)[0] + last_token_logits = decoder_output[:, -1, :] + scaled_logits = last_token_logits / temperature + next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) + + if deterministic: + next_token = jnp.argmax(next_token_probabilities, axis=-1) + else: + key = jax.random.PRNGKey(int(time.time())) + next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + + output_sequences = output_sequences.at[:, i].set(next_token) + decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + + if jnp.all(next_token == self.end_token): + break + + return output_sequences + + + +class GemmaDataParallelTrainer: + """ + Trainer class using data parallelism with JAX. + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + It handles the model training loop, including gradient computation, parameter updates, and evaluation. + + Attributes: + model (Any): The model to be trained. + input_shape (Tuple[int, ...]): The shape of the input tensor. + weights_filename (str): Filename where the trained model weights will be saved. + learning_rate (float): Learning rate for the optimizer. + params_path (Optional[str]): Path to pre-trained model parameters for initializing the model, if available. + + Methods: + create_train_state(learning_rate, text_input_shape, image_input_shape): Initializes the training state, including parameters and optimizer. + train_step(state, texts, images): Performs a single training step, including forward pass, loss computation, and gradients update. + train(train_loader, num_epochs, val_loader): Runs the training loop over the specified number of epochs, using the provided data loaders for training and validation. + evaluation_step(state, texts, images): Performs an evaluation step, computing forward pass and loss without updating model parameters. + evaluate(test_loader): Evaluates the model performance on a test dataset. + save_params(): Saves the model parameters to a file. + load_params(filename): Loads model parameters from a file. + """ + def __init__(self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None) -> None: + self.model = model + self.params = None + self.params_path = params_path + self.num_parameters = None + self.best_val_loss = float("inf") + self.weights_filename = weights_filename + self.num_devices = jax.local_device_count() + self.train_step = jax.pmap(GemmaDataParallelTrainer.train_step, axis_name='devices') + self.evaluation_step = jax.pmap(GemmaDataParallelTrainer.evaluation_step, axis_name='devices') + self.state = self.create_train_state(learning_rate, input_shape) + print(f'Number of accelerators: {self.num_devices}') + + + def create_train_state(self, + learning_rate: float, + input_shape: Tuple[int, ...]) -> Any: + + rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + params = self.model.init(rngs, + jnp.ones(input_shape, dtype=jnp.int32))['params'] + + if self.params_path is not None: + params = self.load_params(self.params_path) + + self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) + print(f'Number of parameters: {self.num_parameters}') + state = train_state.TrainState.create(apply_fn=self.model.apply, + params=params, + tx=optax.adam(learning_rate)) + return jax.device_put_replicated(state, jax.local_devices()) + + @staticmethod + def train_step(state: Any, + inputs: jnp.ndarray, + targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + def loss_fn(params): + logits = state.apply_fn({'params': params}, + inputs, + training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + def train(self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: + + for epoch in range(num_epochs): + total_loss = 0.0 + count = 0 + for inputs, targets in train_loader: + batch_size = inputs.shape[0] + batch_size_per_device = batch_size // self.num_devices + inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) + targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) + self.state, loss = self.train_step(state=self.state, + inputs=inputs, + targets=targets) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + + if val_loader is not None: + val_loss = self.evaluate(val_loader) + print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + print("New best validation score achieved, saving model...") + self.save_params() + return + + @staticmethod + def evaluation_step(state: Any, + inputs: jnp.ndarray, + targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() + + def evaluate(self, + test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + + total_loss = 0.0 + count = 0 + for inputs, targets in test_loader: + batch_size = inputs.shape[0] + batch_size_per_device = batch_size // self.num_devices + inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) + targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) + loss = self.evaluation_step(self.state, inputs, targets) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + return mean_loss + + def save_params(self) -> None: + self.params = flax.jax_utils.unreplicate(self.state.params) + with open(self.weights_filename, 'wb') as f: + f.write(flax.serialization.to_bytes(self.params)) + + def load_params(self, filename: str): + with open(filename, 'rb') as f: + self.params = flax.serialization.from_bytes(self.params, f.read()) + return self.params \ No newline at end of file diff --git a/nanodl/__src/models/mamba_experimental.py b/nanodl/__src/models/mamba_experimental.py new file mode 100644 index 0000000..fdcf320 --- /dev/null +++ b/nanodl/__src/models/mamba_experimental.py @@ -0,0 +1,554 @@ +import jax +import flax +import time +import math +import optax +import jax.numpy as jnp +import flax.linen as nn +from einops import einsum +from flax.training import train_state +from typing import Tuple, Any, Optional, Iterable + +########## EXPERIMENMTAL ############ + +class MambaBlock(nn.Module): + """ + MambaBlock is a custom neural network block that incorporates normalization, + convolution, and dense layers to process input sequences. This block is designed + for sequence modeling tasks and includes specialized components like selective + scan for dynamic computation. + + Attributes: + d_inner (int): Dimensionality of the inner dense layer. + d_conv (int): Size of the convolution kernel. + dt_rank (int): Rank for delta transformations in the selective scan. + d_state (int): Dimensionality of the state vector in the selective scan. + d_model (int): Dimensionality of the input and output of the block. + seq_len (int): Length of the input sequences. + bias (bool): Flag indicating whether to use bias in dense layers. + conv_bias (bool): Flag indicating whether to use bias in the convolution layer. + """ + d_inner: int + d_conv: int + dt_rank: int + d_state: int + d_model: int + seq_len: int + bias: bool + conv_bias: bool + + def setup(self): + self.norm = nn.RMSNorm(self.d_model) + self.in_proj = nn.Dense(features=self.d_inner * 2, use_bias=self.bias) + + self.conv1d = nn.Conv(features=self.seq_len, + kernel_size=(self.d_conv,), + strides=(1,), + padding='SAME', + use_bias=self.conv_bias, + feature_group_count=self.d_inner) + + self.x_proj = nn.Dense(features=self.dt_rank + self.d_state * 2, use_bias=False) + self.dt_proj = nn.Dense(features=self.d_inner, use_bias=True) + self.out_proj = nn.Dense(features=self.d_model, use_bias=self.bias) + + # Parameter initialization + A = jnp.tile(jnp.arange(1, self.d_state + 1), (self.d_inner, 1)) + self.A_log = self.variable('params', 'A_log', lambda: jnp.log(A)) + self.D = self.variable('params', 'D', lambda: jnp.ones((self.d_inner,))) + + def __call__(self, inputs: jnp.ndarray): + u = self.norm(inputs) + A = -jnp.exp(self.A_log.value) + D = self.D.value + x_and_res = self.in_proj(u) + x, res = jnp.split(x_and_res, 2, axis=-1) + x = jnp.transpose(x, (0, 2, 1)) + x = self.conv1d(x)[:, :, :u.shape[1]] + x = jnp.transpose(x, (0, 2, 1)) + x = nn.silu(x) + + x_dbl = self.x_proj(u) + delta, B, C = jnp.split(x_dbl, indices_or_sections=[self.dt_rank, + self.dt_rank + self.d_state], + axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + y = self.selective_scan(x, delta, A, B, C, D) + y = y * nn.silu(res) + return self.out_proj(y) + inputs + + def selective_scan(self, + u: jnp.ndarray, + delta: jnp.ndarray, + A: jnp.ndarray, + B: jnp.ndarray, + C: jnp.ndarray, + D: jnp.ndarray) -> jnp.ndarray: + + b, l, d_in = u.shape + n = A.shape[1] + + deltaA = jnp.exp(einsum( + delta, A, + 'b l d_in, d_in n -> b l d_in n')) + + deltaB_u = einsum( + delta, B, u, + 'b l d_in, b l n, b l d_in -> b l d_in n') + + x = jnp.zeros((b, d_in, n)) + ys = [] + + for i in range(l): + x = deltaA[:, i] * x + deltaB_u[:, i] + y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + + return jnp.stack(ys, axis=1) + u * D + + +class Mamba(nn.Module): + """ + MAMBA is an advanced ML model renowned for its exceptional linear-time processing efficiency, + which notably enhances its inference speed to outperform traditional Transformer models by up to five times in throughput. + Unlike conventional models that struggle with long sequence lengths, MAMBA demonstrates a linear scalability with sequence length, + maintaining or even improving its performance with sequences that extend up to a million elements. + This attribute makes MAMBA a highly versatile and efficient backbone for a variety of sequence modeling tasks across different domains, + including but not limited to language processing, audio analysis, and genomic studies. + + Attributes: + vocab_size (int): The size of the vocabulary. + n_layer (int): The number of MambaBlock layers. + d_conv (int): The convolution kernel size used within each MambaBlock. + d_state (int): The dimensionality of the state vector in each MambaBlock's selective scan. + d_model (int): The dimensionality of the embeddings and the input/output size of each layer. + max_length (int): The maximum length of the input sequences. + expand (int): Factor to determine the inner dimension size based on `d_model`. + start_token (int): The token used to indicate the start of a sequence. + end_token (int): The token used to indicate the end of a sequence. + dropout (float): Dropout rate used in the dropout layer. + bias (bool): Indicates whether to use bias in the Dense layers of MambaBlock. Defaults to True. + conv_bias (bool): Indicates whether to use bias in the Conv layer of MambaBlock. Defaults to True. + dt_rank (int or 'auto'): The rank for delta transformations in each MambaBlock's selective scan. If 'auto', + it is calculated based on `d_model`. + + Example: + ```python + import jax + import jax.numpy as jnp + from nanodl import ArrayDataset, DataLoader + #from nanodl import Mamba, MambaDataParallelTrainer + + # Generate dummy data + batch_size = 8 + max_length = 128 + + # Replace with actual tokenised data + data = jnp.ones((101, max_length+1), dtype=jnp.int16) + + # Shift to create next-token prediction dataset + dummy_inputs = data[:, :-1] + dummy_targets = data[:, 1:] + + # Create dataset and dataloader + dataset = ArrayDataset(dummy_inputs, dummy_targets) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # How to loop through dataloader + for batch in dataloader: + x, y = batch + print(x.shape, y.shape) + break + + # model parameters + hyperparams = { + 'vocab_size': 100, + 'expand': 2, + 'n_layer': 2, + 'd_conv': 3, + 'dt_rank': 16, + 'd_state': 8, + 'd_model': 64, + 'dropout': 0.2, + 'bias':True, + 'conv_bias': True, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + } + + # Initialize model + model = Mamba(**hyperparams) + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + params = model.init({'params': rngs, 'dropout': dropout_rng}, + dummy_inputs)['params'] + + # Call as you would a Jax/Flax model + outputs = model.apply({'params': params}, + dummy_inputs, + rngs={'dropout': dropout_rng}) + + print(outputs.shape) + + # Training on data + trainer = MambaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') + trainer.train(train_loader=dataloader, + num_epochs=2, + val_loader=dataloader) + + print(trainer.evaluate(dataloader)) + + # Generating from a start token + start_tokens = jnp.array([[123, 456]]) + + # Remember to load the trained parameters + params = trainer.load_params('params.pkl') + outputs = model.apply({'params': params}, + start_tokens, + rngs={'dropout': jax.random.PRNGKey(2)}, + method=model.generate) + print(outputs) + ``` + """ + vocab_size: int + n_layer: int + d_conv: int + d_state: int + d_model: int + max_length: int + expand: int + max_length: int + start_token: int + end_token: int + dropout: float + bias: bool = True + conv_bias: bool = True + dt_rank: int = 'auto' + + def setup(self): + self.d_inner = int(self.expand * self.d_model) + + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) + + self.embedding = nn.Embed(self.vocab_size, self.d_model) + + self.layers = [MambaBlock(d_inner=self.d_inner, + d_conv=self.d_conv, + dt_rank=self.dt_rank, + d_state=self.d_state, + d_model=self.d_model, + seq_len=self.max_length, + bias=self.bias, + conv_bias=self.conv_bias) for _ in range(self.n_layer)] + + self.norm_f = nn.RMSNorm(self.d_model) + self.dropout1 = nn.Dropout(self.dropout) + self.lm_head = nn.Dense(features=self.vocab_size, use_bias=False) + # Note: Flax doesn't support parameter sharing like PyTorch's weight tying directly. + # You might need to implement a custom method for weight tying or handle it outside the model definition. + + def __call__(self, + input_ids: jnp.ndarray, + training: bool = False) -> jnp.ndarray: + + x = self.embedding(input_ids) + for layer in self.layers: + x = self.dropout1(layer(x), deterministic=not training) + + x = self.norm_f(x) + logits = self.lm_head(x) + return logits + + + def zero_pad(self, arr, max_length): + current_length = arr.shape[1] + num_zeros = max_length - current_length + + if num_zeros > 0: + zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype) + padded_array = jnp.concatenate([arr, zeros], axis=1) + else: + padded_array = arr + + return padded_array + + + def generate(self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False) -> Tuple[jnp.ndarray]: + + if x is not None: + assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" + + decoder_input = x if x is not None else jnp.array([[self.start_token]]) + output_sequence = [] + + # Autoregressive decoding loop + print(self.zero_pad(decoder_input, self.max_length).shape) + for _ in range(self.max_length-1): + decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0] + print(decoder_output.shape) + last_token_logits = decoder_output[:, -1, :] + scaled_logits = last_token_logits / temperature + next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) + + if deterministic: + next_token = jnp.argmax(next_token_probabilities, axis=-1) + else: + next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + + next_token = next_token[0] + output_sequence.append(next_token.item()) + decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + + if next_token.item() == self.end_token or len(output_sequence) == self.max_length: + break + + return jnp.array(output_sequence) + + + def generate_batch(self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False) -> jnp.ndarray: + + batch_size = x.shape[0] if x is not None else 1 + decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) + + for i in range(self.max_length-1): + decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0] + last_token_logits = decoder_output[:, -1, :] + scaled_logits = last_token_logits / temperature + next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) + + if deterministic: + next_token = jnp.argmax(next_token_probabilities, axis=-1) + else: + key = jax.random.PRNGKey(int(time.time())) + next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + + output_sequences = output_sequences.at[:, i].set(next_token) + decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + + if jnp.all(next_token == self.end_token) or len(output_sequences) == self.max_length: + break + + return output_sequences + + +class MambaDataParallelTrainer: + """ + Trainer class using data parallelism with JAX. + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + It handles the model training loop, including gradient computation, parameter updates, and evaluation. + + Attributes: + model (Any): The model to be trained. + input_shape (Tuple[int, ...]): The shape of the input tensor. + weights_filename (str): Filename where the trained model weights will be saved. + learning_rate (float): Learning rate for the optimizer. + params_path (Optional[str]): Path to pre-trained model parameters for initializing the model, if available. + + Methods: + create_train_state(learning_rate, text_input_shape, image_input_shape): Initializes the training state, including parameters and optimizer. + train_step(state, texts, images): Performs a single training step, including forward pass, loss computation, and gradients update. + train(train_loader, num_epochs, val_loader): Runs the training loop over the specified number of epochs, using the provided data loaders for training and validation. + evaluation_step(state, texts, images): Performs an evaluation step, computing forward pass and loss without updating model parameters. + evaluate(test_loader): Evaluates the model performance on a test dataset. + save_params(): Saves the model parameters to a file. + load_params(filename): Loads model parameters from a file. + """ + def __init__(self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None) -> None: + + self.model = model + self.params = None + self.params_path = params_path + self.num_parameters = None + self.best_val_loss = float("inf") + self.weights_filename = weights_filename + self.num_devices = jax.local_device_count() + self.train_step = jax.pmap(MambaDataParallelTrainer.train_step, axis_name='devices') + self.evaluation_step = jax.pmap(MambaDataParallelTrainer.evaluation_step, axis_name='devices') + self.state = self.create_train_state(learning_rate, input_shape) + print(f'Number of accelerators: {self.num_devices}') + + + def create_train_state(self, + learning_rate: float, + input_shape: Tuple[int, ...]) -> Any: + + rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + params = self.model.init(rngs, + jnp.ones(input_shape, dtype=jnp.int32))['params'] + + if self.params_path is not None: + params = self.load_params(self.params_path) + + self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) + print(f'Number of parameters: {self.num_parameters}') + state = train_state.TrainState.create(apply_fn=self.model.apply, + params=params, + tx=optax.adam(learning_rate)) + return jax.device_put_replicated(state, jax.local_devices()) + + @staticmethod + def train_step(state: Any, + inputs: jnp.ndarray, + targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + def loss_fn(params): + logits = state.apply_fn({'params': params}, + inputs, + training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + def train(self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: + + for epoch in range(num_epochs): + total_loss = 0.0 + count = 0 + for inputs, targets in train_loader: + batch_size = inputs.shape[0] + batch_size_per_device = batch_size // self.num_devices + inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) + targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) + self.state, loss = self.train_step(state=self.state, + inputs=inputs, + targets=targets) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + + if val_loader is not None: + val_loss = self.evaluate(val_loader) + print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + print("New best validation score achieved, saving model...") + self.save_params() + return + + @staticmethod + def evaluation_step(state: Any, + inputs: jnp.ndarray, + targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() + + def evaluate(self, + test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + + total_loss = 0.0 + count = 0 + for inputs, targets in test_loader: + batch_size = inputs.shape[0] + batch_size_per_device = batch_size // self.num_devices + inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) + targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) + loss = self.evaluation_step(self.state, inputs, targets) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + return mean_loss + + def save_params(self) -> None: + self.params = flax.jax_utils.unreplicate(self.state.params) + with open(self.weights_filename, 'wb') as f: + f.write(flax.serialization.to_bytes(self.params)) + + def load_params(self, filename: str): + with open(filename, 'rb') as f: + self.params = flax.serialization.from_bytes(self.params, f.read()) + return self.params + + +import jax +import jax.numpy as jnp +from nanodl import ArrayDataset, DataLoader +#from nanodl import Mamba, MambaDataParallelTrainer + +# Generate dummy data +batch_size = 8 +max_length = 128 + +# Replace with actual tokenised data +data = jnp.ones((101, max_length+1), dtype=jnp.int16) + +# Shift to create next-token prediction dataset +dummy_inputs = data[:, :-1] +dummy_targets = data[:, 1:] + +# Create dataset and dataloader +dataset = ArrayDataset(dummy_inputs, dummy_targets) +dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + +# How to loop through dataloader +for batch in dataloader: + x, y = batch + print(x.shape, y.shape) + break + +# model parameters +hyperparams = { + 'vocab_size': 100, + 'expand': 2, + 'n_layer': 2, + 'd_conv': 3, + 'dt_rank': 16, + 'd_state': 8, + 'd_model': 64, + 'dropout': 0.2, + 'bias':True, + 'conv_bias': True, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, +} + +# Initialize model +model = Mamba(**hyperparams) +rngs = jax.random.PRNGKey(0) +rngs, dropout_rng = jax.random.split(rngs) +params = model.init({'params': rngs, 'dropout': dropout_rng}, + dummy_inputs)['params'] + +# Call as you would a Jax/Flax model +outputs = model.apply({'params': params}, + dummy_inputs, + rngs={'dropout': dropout_rng}) + +print(outputs.shape) + +start_tokens = jnp.array([[123, 456]]) +outputs = model.apply({'params': params}, + start_tokens, + rngs={'dropout': jax.random.PRNGKey(2)}, + method=model.generate) +print(outputs) \ No newline at end of file diff --git a/nanodl/__src/models/reward.py b/nanodl/__src/models/reward.py new file mode 100644 index 0000000..32401ef --- /dev/null +++ b/nanodl/__src/models/reward.py @@ -0,0 +1,247 @@ +import jax +import flax +import time +import optax +import jax.numpy as jnp +import flax.linen as nn +from flax.training import train_state +from typing import Tuple, Any, Optional, Iterable + + +class RewardModel(nn.Module): + """ + The RewardModel estimates the reward or value of a given input sequence, + typically used in reinforcement learning frameworks for natural language processing tasks. + It uses the last hidden state of a transformer-based model to generate a scalar reward prediction, + guiding the agent's behavior by evaluating the desirability or utility of its generated outputs. + + Example: + ```python + from nanodl import ArrayDataset, DataLoader + from nanodl import Gemma, RewardModel, RewardDataParallelTrainer + + # Generate dummy data + batch_size = 8 + max_length = 10 + + # Replace with actual tokenised data + dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) + dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) + + # Create dataset and dataloader + dataset = ArrayDataset(dummy_chosen, dummy_rejected) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # model parameters + hyperparams = { + 'num_layers': 1, + 'hidden_dim': 256, + 'num_heads': 2, + 'feedforward_dim': 256, + 'dropout': 0.1, + 'vocab_size': 1000, + 'embed_dim': 256, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + 'num_groups': 2, + } + + # Initialize reward model from Gemma + model = Gemma(**hyperparams) + reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) + + # Train the reward model + trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') + trainer.train(dataloader, 5, dataloader) + params = trainer.load_params('reward_model_weights.pkl') + + # Call as you would a regular Flax model + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + rewards = reward_model.apply({'params': params}, + dummy_chosen, + rngs={'dropout': dropout_rng}) + + print(rewards.shape) + ``` + """ + model: nn.Module + dim: int + dropout: float + + @nn.compact + def __call__(self, + x: jnp.ndarray, + training: bool = False): + + x = self.model(x, training=training, drop_last_layer=True) + x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) + x = nn.Dense(1)(x) + return nn.sigmoid(x)[:, -1, 0] + + +class RewardDataParallelTrainer: + """ + Trainer class using data parallelism with JAX. + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + It handles the model training loop, including gradient computation, parameter updates, and evaluation. + + Attributes: + model (Any): The model to be trained. + input_shape (Tuple[int, ...]): The shape of the input tensor. + weights_filename (str): Filename where the trained model weights will be saved. + learning_rate (float): Learning rate for the optimizer. + params_path (Optional[str]): Path to pre-trained reward model parameters for initializing the REWARD model, if available. + model_params_path (Optional[str]): Path to pre-trained backbone model parameters for initializing the BACKBONE model, if available. + + Methods: + create_train_state(learning_rate, text_input_shape, image_input_shape): Initializes the training state, including parameters and optimizer. + train_step(state, texts, images): Performs a single training step, including forward pass, loss computation, and gradients update. + train(train_loader, num_epochs, val_loader): Runs the training loop over the specified number of epochs, using the provided data loaders for training and validation. + evaluation_step(state, texts, images): Performs an evaluation step, computing forward pass and loss without updating model parameters. + evaluate(test_loader): Evaluates the model performance on a test dataset. + save_params(): Saves the model parameters to a file. + load_params(filename): Loads model parameters from a file. + """ + def __init__(self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + model_params_path: Optional[str] = None) -> None: + + self.model = model + self.params = None + self.params_path = params_path + self.model_params_path = model_params_path + self.num_parameters = None + self.best_val_loss = float("inf") + self.weights_filename = weights_filename + self.num_devices = jax.local_device_count() + self.train_step = jax.pmap(RewardDataParallelTrainer.train_step, axis_name='devices') + self.evaluation_step = jax.pmap(RewardDataParallelTrainer.evaluation_step, axis_name='devices') + self.state = self.create_train_state(learning_rate, input_shape) + print(f'Number of accelerators: {self.num_devices}') + + + def create_train_state(self, + learning_rate: float, + input_shape: Tuple[int, ...]) -> Any: + + rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + params = self.model.init(rngs, + jnp.ones(input_shape, dtype=jnp.int32))['params'] + + if self.params_path is not None: + params = self.load_params(self.params_path) + + if self.model_params_path is not None: + model_params = self.load_params(self.model_params_path) + params = self.merge_params(model_params, params) + + self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) + print(f'Number of parameters: {self.num_parameters}') + state = train_state.TrainState.create(apply_fn=self.model.apply, + params=params, + tx=optax.adam(learning_rate)) + return jax.device_put_replicated(state, jax.local_devices()) + + @staticmethod + def train_step(state: Any, + chosen: jnp.ndarray, + rejected: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + def loss_fn(params): + chosen_rewards = state.apply_fn({'params': params}, + chosen, + training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + + rejected_rewards = state.apply_fn({'params': params}, + rejected, + training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + + return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + def train(self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: + + for epoch in range(num_epochs): + total_loss = 0.0 + count = 0 + for chosen, rejected in train_loader: + batch_size = chosen.shape[0] + batch_size_per_device = batch_size // self.num_devices + chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1)) + rejected = rejected.reshape((self.num_devices, batch_size_per_device, -1)) + self.state, loss = self.train_step(state=self.state, + chosen=chosen, + rejected=rejected) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + + if val_loader is not None: + val_loss = self.evaluate(val_loader) + print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + print("New best validation score achieved, saving model...") + self.save_params() + return + + @staticmethod + def evaluation_step(state: Any, + chosen: jnp.ndarray, + rejected: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + chosen_rewards = state.apply_fn({'params': state.params}, chosen, rngs={'dropout': jax.random.PRNGKey(2)}) + rejected_rewards = state.apply_fn({'params': state.params}, rejected, rngs={'dropout': jax.random.PRNGKey(2)}) + return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() + + def evaluate(self, + test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + + total_loss = 0.0 + count = 0 + for chosen, rejected in test_loader: + batch_size = chosen.shape[0] + batch_size_per_device = batch_size // self.num_devices + chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1)) + rejected = rejected.reshape((self.num_devices, batch_size_per_device, -1)) + loss = self.evaluation_step(self.state, chosen, rejected) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + return mean_loss + + def merge_params(untrained_params, trained_params): + updated_untrained_params = jax.tree_map( + lambda untrained, trained: trained if untrained.shape == trained.shape else untrained, + untrained_params, + trained_params) + return updated_untrained_params + + def save_params(self) -> None: + self.params = flax.jax_utils.unreplicate(self.state.params) + with open(self.weights_filename, 'wb') as f: + f.write(flax.serialization.to_bytes(self.params)) + + def load_params(self, filename: str): + with open(filename, 'rb') as f: + self.params = flax.serialization.from_bytes(self.params, f.read()) + return self.params \ No newline at end of file diff --git a/nanodl/__src/models/rlhf.py b/nanodl/__src/models/rlhf.py new file mode 100644 index 0000000..b24d2c0 --- /dev/null +++ b/nanodl/__src/models/rlhf.py @@ -0,0 +1,315 @@ +import jax +import flax +import time +import copy +import optax +import jax.numpy as jnp +import flax.linen as nn +from flax.training import train_state +from typing import Tuple, Any, Optional, Iterable + + +class RLHF(nn.Module): + policy_network: Any + reference: bool = False + + def setup(self) -> None: + self.dense1 = nn.Dense(256) + self.dense2 = nn.Dense(256) + self.dense3 = nn.Dense(1) + + def __call__(self, + x: jnp.ndarray, + training: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]: + + logits = self.policy_network(x, training=training) + log_probs = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + probs = jnp.exp(log_probs) + rng = jax.random.PRNGKey(int(time.time())) + action = jax.random.categorical(rng, log_probs, axis=-1) + entropy = -jnp.sum(probs * log_probs, axis=-1) + action_log_probs = jnp.take_along_axis(log_probs, action[:, None], axis=-1) + value = self.get_value(x) if not self.reference else None + return action, action_log_probs, entropy, value + + def get_value(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: + hidden = self.policy_network(x, training=training, drop_last_layer=True) + hidden = nn.relu(self.dense1(hidden)) + hidden = nn.relu(self.dense2(hidden)) + value = nn.tanh(self.dense3(hidden)) + return value + + def generate(self, x: jnp.ndarray) -> jnp.ndarray: + return self.policy_network.generate(x) + + def generate_batch(self, x: jnp.ndarray) -> jnp.ndarray: + return self.policy_network.generate_batch(x) + + +class PPODataParallelTrainer: + def __init__(self, + rlhf_main: Any, + rlhf_ref: Any, + reward_model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + gamma: float = 0.99, + beta: float = 0.2, + lam: float = 0.95, + ent_coef: float = 0.01, + vf_coef: float = 0.5, + learning_rate: float = 1e-4, + params_path: Optional[str] = None, + sft_params_path: Optional[str] = None, + reward_params_path: Optional[str] = None, + ) -> None: + + self.rlhf_main = rlhf_main + self.reward_model = reward_model + self.rlhf_ref = rlhf_ref + + self.gamma = gamma + self.lam = lam + self.beta = beta + self.epsilon = 1.0e-8 + self.ent_coef = ent_coef + self.vf_coef = vf_coef + + self.params = None + self.ref_params = None + self.params_path = params_path + self.sft_params = self.load_params(sft_params_path) + + rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + reward_params = self.reward_model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] + self.reward_params = self.load_params(reward_params_path, params=reward_params) + + self.num_parameters = None + self.best_val_loss = float("inf") + self.weights_filename = weights_filename + self.num_devices = jax.local_device_count() + self.train_step = jax.pmap(PPODataParallelTrainer.train_step, axis_name='devices') + self.state = self.create_train_state(learning_rate, input_shape) + print(f'Number of accelerators: {self.num_devices}') + + + def create_train_state(self, + learning_rate: float, + input_shape: Tuple[int, ...]) -> Any: + + rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + params = self.rlhf_main.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] + params['policy_network']['decoder'] = self.sft_params['decoder'] + self.ref_params = copy.deepcopy(params) + + if self.params_path is not None: + params = self.load_params(self.params_path) + + self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) + print(f'Number of parameters: {self.num_parameters}') + state = train_state.TrainState.create(apply_fn=self.rlhf_main.apply, + params=params, + tx=optax.adam(learning_rate)) + + return jax.device_put_replicated(state, jax.local_devices()) + + + def compute_agent_objective(self, model_logits, sft_logits, reward_score, gamma, beta): + ratio = nn.log_softmax(model_logits, axis=-1) - nn.log_softmax(sft_logits, axis=-1) + left = jnp.mean(reward_score - beta * ratio.mean(axis=-1)) + right = gamma * nn.log_softmax(model_logits, axis=-1).mean(axis=-1) + return left + right + + def advantage_and_return(self, rewards, values): + rewards = jnp.expand_dims(rewards, axis=0) + values = jnp.expand_dims(values, axis=0) + + gen_len = rewards.shape[1] + lastgaelam = 0 + advantages_reversed = [] + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.gamma * self.lam * lastgaelam + advantages_reversed.append(lastgaelam) + + # Reversing and stacking to create the correct shape for advantages + advantages = jnp.vstack(advantages_reversed[::-1]).T + returns = advantages + values + advantages = jnp.squeeze(advantages, axis=0) + returns = jnp.squeeze(returns, axis=0) + return advantages, returns + + def calculate_loss(self, logprobs, values, entropies, ref_logprobs, rewards): + ratio = jnp.exp(logprobs - ref_logprobs) + clipped_ratio = jnp.clip(ratio, 1 - self.epsilon, 1 + self.epsilon) + advantages, returns = self.advantage_and_return(rewards, values) + value_loss = jnp.square(values - returns).mean() + pg_loss_1 = advantages * ratio + pg_loss_2 = advantages * clipped_ratio + pg_loss = jnp.minimum(pg_loss_1, pg_loss_2).mean() + loss = pg_loss - self.ent_coef * entropies.mean() + self.vf_coef * value_loss + return loss + + def get_ref_log_probs(self, inputs: jnp.ndarray) -> jnp.ndarray: + return self.rlhf_ref.apply({'params': self.ref_params}, + inputs, training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + + def get_rewards(self, inputs: jnp.ndarray) -> jnp.ndarray: + responses = self.rlhf_main.apply({'params': self.params}, + inputs, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}, + method=self.rlhf_main.generate_batch) + return self.reward_model.apply({'params': self.reward_params}, + responses, + training=False, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + + def train_step(self, + state: Any, + inputs: jnp.ndarray, + ref_log_probs: jnp.ndarray, + rewards: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + + def loss_fn(params): + _, action_log_probs, entropy, value = state.apply_fn({'params': params}, + inputs, + training=True, + rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) + + + + return self.calculate_loss(action_log_probs, value, entropy, ref_log_probs, rewards) + + loss, grads = jax.value_and_grad(loss_fn)(state.params) + state = state.apply_gradients(grads=grads) + return state, loss + + def train(self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: + + for epoch in range(num_epochs): + total_loss = 0.0 + count = 0 + for inputs in train_loader: + inputs = inputs[0] + ref_log_probs = self.get_ref_log_probs(inputs) + rewards = self.get_rewards(inputs) + batch_size = inputs.shape[0] + batch_size_per_device = batch_size // self.num_devices + inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) + ref_log_probs = ref_log_probs.reshape((self.num_devices, batch_size_per_device, -1)) + rewards = rewards.reshape((self.num_devices, batch_size_per_device, -1)) + self.state, loss = self.train_step(state=self.state, + inputs=inputs, + ref_log_probs=ref_log_probs, + rewards=rewards) + total_loss += jnp.mean(loss) + count += 1 + + mean_loss = total_loss / count + print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + + if val_loader is not None: + val_loss = self.evaluate(val_loader) + print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + if val_loss < self.best_val_loss: + self.best_val_loss = val_loss + print("New best validation score achieved, saving model...") + self.save_params() + return + + def merge_params(self, untrained_params, trained_params): + updated_untrained_params = jax.tree_map( + lambda untrained, trained: trained if untrained.shape == trained.shape else untrained, + untrained_params, + trained_params) + return updated_untrained_params + + def save_params(self) -> None: + self.params = flax.jax_utils.unreplicate(self.state.params) + with open(self.weights_filename, 'wb') as f: + f.write(flax.serialization.to_bytes(self.params)) + + def load_params(self, filename: str, params=None): + with open(filename, 'rb') as f: + params = self.params if params is None else params + self.params = flax.serialization.from_bytes(params, f.read()) + return self.params + + + + +from nanodl import ArrayDataset, DataLoader +from nanodl import Gemma, GemmaDataParallelTrainer +from nanodl import RewardModel, RewardDataParallelTrainer +# from nanodl import RLHF, PPODataParallelTrainer + +batch_size = 8 +max_length = 10 +model_params_path = 'base_params.pkl' +rlhf_params_path = 'rlhf_params.pkl' +reward_params_path = 'reward_params.pkl' + +# model parameters +hyperparams = { + 'num_layers': 1, + 'hidden_dim': 128, + 'num_heads': 2, + 'feedforward_dim': 128, + 'dropout': 0.1, + 'vocab_size': 200, + 'embed_dim': 128, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + 'num_groups': 2, +} + +print('Step 1: Pretraining') +# Replace with actual tokenised data +data = jnp.ones((101, max_length), dtype=jnp.int32) +dummy_inputs = data[:, :-1] +dummy_targets = data[:, 1:] +dataset = ArrayDataset(dummy_inputs, dummy_targets) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) +model = Gemma(**hyperparams) +# trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, model_params_path) +# trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader) + +print('\nStep 2: Superfised Fine-Tuning') +# Replace with actual tokenised data +dummy_prompt = jnp.ones((101, max_length), dtype=jnp.int32) +dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) +dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) +# dataset = ArrayDataset(dummy_prompt, dummy_chosen) +# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) +# model = Gemma(**hyperparams) +# trainer = GemmaDataParallelTrainer(model, dummy_prompt.shape, model_params_path) +# trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader) + +print('\nStep 3: Train a reward model') +dataset = ArrayDataset(dummy_chosen, dummy_rejected) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) +reward_model = RewardModel(Gemma(**hyperparams), dim=hyperparams['hidden_dim'], dropout=0.1) +# trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, reward_params_path) +# trainer.train(dataloader, 2, dataloader) + +print('\nStep 4: Train the RLHF model via PPO, using a reference model and the reward model.') +rlhf_model = RLHF(model) +rlhf_ref = RLHF(model, reference=True) +dataset = ArrayDataset(dummy_chosen) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) +trainer = PPODataParallelTrainer(rlhf_model, + rlhf_ref, + reward_model, + dummy_inputs.shape, + rlhf_params_path, + sft_params_path=model_params_path, + reward_params_path=reward_params_path) + +trainer.train(dataloader, 2) \ No newline at end of file diff --git a/nanodl/__src/utils/random.py b/nanodl/__src/utils/random.py new file mode 100644 index 0000000..f184841 --- /dev/null +++ b/nanodl/__src/utils/random.py @@ -0,0 +1,331 @@ +import jax +import time +import jax.numpy as jnp +from jax import random +from typing import Any, Union, Tuple + +def time_rng_key(seed=None) -> jnp.ndarray: + """Generate a JAX random key based on the current UNIX timestamp. + + Returns: + jnp.ndarray: A JAX random key. + """ + key = int(time.time()) if seed is None else seed + return random.PRNGKey(seed) + +def uniform(shape: Tuple[int, ...], + dtype: Any = jnp.float32, + minval: float = 0.0, + maxval: float = 1.0, + seed=None) -> jnp.ndarray: + """Generate a tensor of uniform random values. + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + minval (float, optional): The lower bound of the uniform distribution. Defaults to 0.0. + maxval (float, optional): The upper bound of the uniform distribution. Defaults to 1.0. + + Returns: + jnp.ndarray: A tensor of uniform random values. + """ + return random.uniform(time_rng_key(seed), + shape, + dtype=dtype, + minval=minval, + maxval=maxval) + +def normal(shape: Tuple[int, ...], + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Generate a tensor of normal random values. + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of normal random values. + """ + return random.normal(time_rng_key(seed), + shape, dtype=dtype) + +def bernoulli(p: float, + shape: Tuple[int, ...] = (), + seed=None) -> jnp.ndarray: + """Generate random boolean values with a given probability. + + Args: + p (float): Probability of sampling a True value. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + + Returns: + jnp.ndarray: A tensor of boolean values. + """ + return random.bernoulli(time_rng_key(seed), p, shape) + +def categorical(logits: jnp.ndarray, + axis: int = -1, + shape: Tuple[int, ...] = (), + seed=None) -> jnp.ndarray: + """Draw samples from a categorical distribution. + + Args: + logits (jnp.ndarray): The unnormalized log probabilities of the categories. + axis (int, optional): The axis along which the categorical distribution is applied. Defaults to -1. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + + Returns: + jnp.ndarray: The sampled indices with the specified shape. + """ + return random.categorical(time_rng_key(seed), + logits, + axis=axis, + shape=shape) + +def randint(shape: Tuple[int, ...], + minval: int, + maxval: int, + dtype: str = 'int32', + seed=None) -> jnp.ndarray: + """Generate random integers between minval (inclusive) and maxval (exclusive). + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + minval (int): The lower bound of the random integers, inclusive. + maxval (int): The upper bound of the random integers, exclusive. + dtype (str, optional): The data type of the output tensor. Defaults to 'int32'. + + Returns: + jnp.ndarray: A tensor of random integers. + """ + return random.randint(time_rng_key(seed), + shape, + minval, + maxval, + dtype=dtype) + +def permutation(x: Union[int, jnp.ndarray], + axis: int = 0, + seed=None) -> jnp.ndarray: + """Randomly permute a sequence, or return a permuted range. + + Args: + x (Union[int, jnp.ndarray]): If x is an integer, permute range(x). If x is an array, permute its elements. + axis (int, optional): The axis along which to permute if x is an array. Defaults to 0. + + Returns: + jnp.ndarray: The permuted sequence or array. + """ + if isinstance(x, int): + arr = jax.numpy.arange(x) + return random.permutation(time_rng_key(seed), arr, axis=axis) + else: + return random.permutation(time_rng_key(seed), x, axis=axis) + +def gumbel(shape: Tuple[int, ...], + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from a Gumbel distribution. + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of samples from a Gumbel distribution. + """ + return random.gumbel(time_rng_key(seed), shape, dtype=dtype) + +def choice(a: Union[int, jnp.ndarray], + shape: Tuple[int, ...] = (), + replace: bool = True, + p: Union[None, jnp.ndarray] = None, + axis: int = 0, + seed=None) -> jnp.ndarray: + """Randomly choose elements from a given 1-D array. + + Args: + a (Union[int, jnp.ndarray]): If an int, the random sample is generated as if a were jnp.arange(a). + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + replace (bool, optional): Whether the sample is with or without replacement. Defaults to True. + p (Union[None, jnp.ndarray], optional): The probabilities associated with each entry in a. Defaults to None. + axis (int, optional): The axis along which to choose if a is an array. Defaults to 0. + + Returns: + jnp.ndarray: The randomly chosen elements. + """ + if isinstance(a, int): + a = jnp.arange(a) + return random.choice(time_rng_key(seed), + a, + shape=shape, + replace=replace, + p=p, + axis=axis) + +def binomial(n: int, + p: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from a binomial distribution. + + Args: + n (int): The number of trials. + p (float): The probability of success of an individual trial. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32. + + Returns: + jnp.ndarray: A tensor of samples from a binomial distribution. + """ + return random.binomial(time_rng_key(seed), + n, + p, + shape=shape, + dtype=dtype) + +def bits(shape: Tuple[int, ...], + dtype: Any = jnp.uint32, + seed=None) -> jnp.ndarray: + """Generate random bits. + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + dtype (Any, optional): The data type of the output tensor, typically an unsigned integer type. Defaults to jnp.uint32. + + Returns: + jnp.ndarray: A tensor of random bits. + """ + return random.bits(time_rng_key(seed), shape, dtype=dtype) + +def exponential(shape: Tuple[int, ...], + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from an exponential distribution. + + Args: + shape (Tuple[int, ...]): The shape of the output tensor. + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of samples from an exponential distribution. + """ + return random.exponential(time_rng_key(seed), shape, dtype=dtype) + +def triangular(left: float, + right: float, + mode: float, + shape: Tuple[int, ...] = (), + seed=None) -> jnp.ndarray: + """Draw samples from a triangular distribution. + + Args: + left (float): The lower limit of the distribution. + right (float): The upper limit of the distribution. + mode (float): The mode (peak) of the distribution. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + + Returns: + jnp.ndarray: A tensor of samples from a triangular distribution. + """ + return random.triangular(time_rng_key(seed), left, right, mode, shape) + +def truncated_normal(lower: float, + upper: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from a truncated normal distribution. + + Args: + lower (float): The lower bound of the distribution. + upper (float): The upper bound of the distribution. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of samples from a truncated normal distribution. + """ + return random.truncated_normal(time_rng_key(seed), + lower, + upper, + shape, + dtype) + +def poisson(lam: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.int32, + seed=None) -> jnp.ndarray: + """Draw samples from a Poisson distribution. + + Args: + lam (float): The expectation of interval (lambda parameter). + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32. + + Returns: + jnp.ndarray: A tensor of samples from a Poisson distribution. + """ + return random.poisson(time_rng_key(seed), + lam, + shape=shape, + dtype=dtype) + +def geometric(p: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.int32, + seed=None) -> jnp.ndarray: + """Draw samples from a geometric distribution. + + Args: + p (float): The probability of success of an individual trial. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.int32. + + Returns: + jnp.ndarray: A tensor of samples from a geometric distribution. + """ + return random.geometric(time_rng_key(seed), + p, + shape=shape, + dtype=dtype) + +def gamma(a: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from a gamma distribution. + + Args: + a (float): The shape parameter of the gamma distribution. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of samples from a gamma distribution. + """ + return random.gamma(time_rng_key(seed), + a, + shape=shape, + dtype=dtype) + +def chisquare(df: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.float32, + seed=None) -> jnp.ndarray: + """Draw samples from a chi-square distribution. + + Args: + df (float): The degrees of freedom. + shape (Tuple[int, ...], optional): The shape of the output tensor. Defaults to (). + dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. + + Returns: + jnp.ndarray: A tensor of samples from a chi-square distribution. + """ + return random.chisquare(time_rng_key(seed), + df, + shape=shape, + dtype=dtype) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8d19c22..93deb62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ jax jaxlib flax -optax \ No newline at end of file +optax +einops \ No newline at end of file diff --git a/setup.py b/setup.py index 2497952..6874462 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='nanodl', - version='1.0.1.dev1', + version='1.2.0.dev1', author='Henry Ndubuaku', author_email='ndubuakuhenry@gmail.com', description='A Jax-based library for designing and training transformer models from scratch.', diff --git a/tests/test_models.py b/tests/test_models.py index e1b5a63..a9c66ea 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -71,6 +71,11 @@ def test_llama_model(self): num_groups=2) self._test_decoder_only_model(model) + def test_gemma_model(self): + model = Gemma(**self.hyperparams, + num_groups=2) + self._test_decoder_only_model(model) + def _test_encoder_decoder_model(self, model): rngs = { 'params': jax.random.key(0), @@ -114,6 +119,19 @@ def _test_decoder_only_model(self, model): outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size) ) + + def test_reward_model(self): + model = RewardModel(Mixtral(**self.hyperparams, + num_groups=2, + window_size=5, + shift_size=2), dim=self.hyperparams['hidden_dim'], dropout=0.1) + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + params = model.init({'params': rngs, 'dropout': dropout_rng}, self.dummy_inputs)['params'] + rewards = model.apply({'params': params}, + self.dummy_inputs, + rngs={'dropout': dropout_rng}) + assert rewards.shape == (self.batch_size,) class TestVisionBasedModels(unittest.TestCase): diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..1e5ed3b --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,102 @@ +import unittest +import jax.numpy as jnp +from nanodl import ( + time_rng_key, uniform, normal, bernoulli, categorical, randint, + permutation, gumbel, choice, binomial, bits, exponential, + triangular, truncated_normal, poisson, geometric, gamma, + chisquare +) + +class TestRandomFunctions(unittest.TestCase): + + def test_time_rng_key(self): + key1 = time_rng_key(seed=42) + key2 = time_rng_key(seed=42) + self.assertTrue(jnp.array_equal(key1, key2), "Keys should be equal for the same seed") + + def test_uniform(self): + result = uniform((2, 3), seed=42) + self.assertEqual(result.shape, (2, 3)) + self.assertEqual(result.dtype, jnp.float32) + + def test_normal(self): + result = normal((4, 5), seed=42) + self.assertEqual(result.shape, (4, 5)) + self.assertEqual(result.dtype, jnp.float32) + + def test_bernoulli(self): + result = bernoulli(0.5, (10,), seed=42) + self.assertEqual(result.shape, (10,)) + self.assertEqual(result.dtype, jnp.bool_) + + def test_categorical(self): + logits = jnp.array([0.1, 0.2, 0.7]) + result = categorical(logits, shape=(5,), seed=42) + self.assertEqual(result.shape, (5,)) + + def test_randint(self): + result = randint((3, 3), 0, 10, seed=42) + self.assertEqual(result.shape, (3, 3)) + self.assertEqual(result.dtype, jnp.int32) + + def test_permutation(self): + arr = jnp.arange(10) + result = permutation(arr, seed=42) + self.assertEqual(result.shape, arr.shape) + self.assertNotEqual(jnp.all(result == arr), True) + + def test_gumbel(self): + result = gumbel((2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + + def test_choice(self): + result = choice(5, shape=(3,), seed=42) + self.assertEqual(result.shape, (3,)) + + def test_binomial(self): + result = binomial(10, 0.5, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + + def test_bits(self): + result = bits((2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.uint32) + + def test_exponential(self): + result = exponential((2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + + def test_triangular(self): + result = triangular(0, 1, 0.5, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + + def test_truncated_normal(self): + result = truncated_normal(0, 1, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + + def test_poisson(self): + result = poisson(3, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.int32) + + def test_geometric(self): + result = geometric(0.5, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.int32) + + def test_gamma(self): + result = gamma(2, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + + def test_chisquare(self): + result = chisquare(2, (2, 2), seed=42) + self.assertEqual(result.shape, (2, 2)) + self.assertEqual(result.dtype, jnp.float32) + +if __name__ == '__main__': + unittest.main()