From 1a4d354582f15db4a48919009400c7fd094a73cf Mon Sep 17 00:00:00 2001
From: Henry Ndubuaku
Date: Tue, 19 Mar 2024 21:32:57 +0000
Subject: [PATCH 1/2] Update README.md

---
 README.md | 96 +++++++++++++++++++------------------------------------
 1 file changed, 33 insertions(+), 63 deletions(-)

diff --git a/README.md b/README.md
index 5ddbaa6..6f7f6dc 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,7 @@ We provide various example usages of the nanodl API.
 ```py
 import jax
 import jax.numpy as jnp
+from nanodl import time_rng_key
 from nanodl import ArrayDataset, DataLoader
 from nanodl import GPT4, GPTDataParallelTrainer
 
@@ -69,25 +70,15 @@ from nanodl import GPT4, GPTDataParallelTrainer
 batch_size = 8
 max_length = 10
 
-# Replace with actual tokenised data
+# Replace with actual list of tokenised texts
 data = jnp.ones((101, max_length), dtype=jnp.int32)
 
 # Shift to create next-token prediction dataset
-dummy_inputs = data[:, :-1]
-dummy_targets = data[:, 1:]
+dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
 
 # Create dataset and dataloader
 dataset = ArrayDataset(dummy_inputs, dummy_targets)
-dataloader = DataLoader(dataset,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        drop_last=False)
-
-# How to loop through dataloader
-for batch in dataloader:
-    x, y = batch
-    print(x.shape, y.shape)
-    break
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
 
 # model parameters
 hyperparams = {
@@ -103,25 +94,17 @@ hyperparams = {
     'end_token': 50,
 }
 
-# Initialize model
+# Initialize inferred GPT4 model
 model = GPT4(**hyperparams)
-rngs = jax.random.PRNGKey(0)
-rngs, dropout_rng = jax.random.split(rngs)
-params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']
-
-# Call as you would a Jax/Flax model
-outputs = model.apply({'params': params},
-                      dummy_inputs,
-                      rngs={'dropout': dropout_rng})
-print(outputs.shape)
+params = model.init(
+    {'params': time_rng_key(),
+     'dropout': time_rng_key()
+    },
+    dummy_inputs)['params']
 
 # Training on data
 trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
-trainer.train(train_loader=dataloader,
-              num_epochs=2,
-              val_loader=dataloader)
-
-print(trainer.evaluate(dataloader))
+trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
 
 # Generating from a start token
 start_tokens = jnp.array([[123, 456]])
@@ -130,9 +113,8 @@ start_tokens = jnp.array([[123, 456]])
 params = trainer.load_params('params.pkl')
 outputs = model.apply({'params': params},
                       start_tokens,
-                      rngs={'dropout': jax.random.PRNGKey(2)},
+                      rngs={'dropout': time_rng_key()},
                       method=model.generate)
-print(outputs)
 ```
 
 Vision example
@@ -140,6 +122,7 @@ Vision example
 ```py
 import jax
 import jax.numpy as jnp
+from nanodl import time_rng_key
 from nanodl import ArrayDataset, DataLoader
 from nanodl import DiffusionModel, DiffusionDataParallelTrainer
 
@@ -147,16 +130,12 @@ from nanodl import DiffusionModel, DiffusionDataParallelTrainer
 image_size = 32
 block_depth = 2
 batch_size = 8
 widths = [32, 64, 128]
-key = jax.random.PRNGKey(0)
 input_shape = (101, image_size, image_size, 3)
-images = jax.random.normal(key, input_shape)
+images = jax.random.normal(time_rng_key(), input_shape)
 # Use your own images
 dataset = ArrayDataset(images)
-dataloader = DataLoader(dataset,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        drop_last=False)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
 
 # Create diffusion model
 diffusion_model = DiffusionModel(image_size, widths, block_depth)
@@ -165,13 +144,11 @@ pred_noises, pred_images = diffusion_model.apply(params, 
                                                  images)
 print(pred_noises.shape, pred_images.shape)
 
 # Training on your data
-# Note: saved params are often different from training weights, use the saved params for generation
 trainer = DiffusionDataParallelTrainer(diffusion_model,
                                        input_shape=images.shape,
                                        weights_filename='params.pkl',
                                        learning_rate=1e-4)
 trainer.train(dataloader, 10, dataloader)
-print(trainer.evaluate(dataloader))
 
 # Generate some samples
@@ -179,7 +156,6 @@ params = trainer.load_params('params.pkl')
 generated_images = diffusion_model.apply({'params': params},
                                          num_images=5,
                                          diffusion_steps=5,
                                          method=diffusion_model.generate)
-print(generated_images.shape)
 ```
 
 Audio example
@@ -187,6 +163,7 @@ Audio example
 ```py
 import jax
 import jax.numpy as jnp
+from nanodl import time_rng_key
 from nanodl import ArrayDataset, DataLoader
 from nanodl import Whisper, WhisperDataParallelTrainer
 
@@ -200,13 +177,8 @@ embed_dim = 256
 vocab_size = 1000
 dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
 dummy_inputs = jnp.ones((101, max_length, embed_dim))
-dataset = ArrayDataset(dummy_inputs, 
-                       dummy_targets)
-
-dataloader = DataLoader(dataset, 
-                        batch_size=batch_size, 
-                        shuffle=True, 
-                        drop_last=False)
+dataset = ArrayDataset(dummy_inputs, dummy_targets)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
 
 # model parameters
 hyperparams = {
@@ -224,10 +196,8 @@ hyperparams = {
 
 # Initialize model
 model = Whisper(**hyperparams)
-rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
+rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
 params = model.init(rngs, dummy_inputs, dummy_targets)['params']
-outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
-print(outputs.shape)
 
 # Training on your data
 trainer = WhisperDataParallelTrainer(model,
@@ -239,13 +209,11 @@ trainer.train(dataloader, 2, dataloader)
 
 # Sample inference
 params = trainer.load_params('params.pkl')
-# for more than one sample, use model.generate_batch
+# for more than one sample, often use model.generate_batch
 transcripts = model.apply({'params': params},
                           dummy_inputs[:1],
                           rngs=rngs,
                           method=model.generate)
-
-print(transcripts)
 ```
 
 Reward Model example for RLHF
@@ -253,6 +221,7 @@ Reward Model example for RLHF
 ```py
 import jax
 import jax.numpy as jnp
+from nanodl import time_rng_key
 from nanodl import ArrayDataset, DataLoader
 from nanodl import Mistral, RewardModel, RewardDataParallelTrainer
 
@@ -266,10 +235,7 @@ dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
 
 # Create dataset and dataloader
 dataset = ArrayDataset(dummy_chosen, dummy_rejected)
-dataloader = DataLoader(dataset,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        drop_last=False)
+dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
 
 # model parameters
 hyperparams = {
@@ -298,13 +264,9 @@ trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
 trainer.train(dataloader, 5, dataloader)
 params = trainer.load_params('reward_model_weights.pkl')
 
 # Call as you would a regular Flax model
-rngs = jax.random.PRNGKey(0)
-rngs, dropout_rng = jax.random.split(rngs)
 rewards = reward_model.apply({'params': params},
                              dummy_chosen,
-                             rngs={'dropout': dropout_rng})
-
-print(rewards.shape)
+                             rngs={'dropout': time_rng_key()})
 ```
 
@@ -313,13 +275,21 @@ PCA example
 ```py
 import jax
 from nanodl import PCA
+# Use actual data
 data = jax.random.normal(jax.random.key(0), (1000, 10))
+
+# Initialise and train PCA model
 pca = PCA(n_components=2)
 pca.fit(data)
+
+# Get PCA transforms
 transformed_data = pca.transform(data)
+
+# Get reverse transforms
 original_data = pca.inverse_transform(transformed_data)
+
+# Sample from the distribution
 X_sampled = pca.sample(n_samples=1000, key=None)
-print(X_sampled.shape, original_data.shape, transformed_data.shape)
 ```
 
 NanoDL provides random module which abstracts away Jax's intricacies.
@@ -371,7 +341,7 @@ Following the success of Phi models, the long-term goal is to build and train nano versions of all available models,
 while ensuring they compete with the original models in performance, with total number of parameters not exceeding 1B.
 Trained weights will be made available via this library.
 Any form of sponsorship, funding, grants or contribution will help with training resources.
-You can sponsor via the tag on the user profile, or reach out via ndubuakuhenry@gmail.com.
+You can sponsor via the user profile tag or reach out via ndubuakuhenry@gmail.com.
 
 ## Citing nanodl

From dc47e23a11f3d6a83fb6f5d97713caf9fd929789 Mon Sep 17 00:00:00 2001
From: Henry Ndubuaku
Date: Tue, 19 Mar 2024 21:55:16 +0000
Subject: [PATCH 2/2] Update README.md

---
 README.md | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 6f7f6dc..23c32f2 100644
--- a/README.md
+++ b/README.md
@@ -309,16 +309,13 @@ jax_array = nanodl.uniform(shape=(3, 3))
 jax_array = nanodl.uniform(shape=(3, 3), seed=0)
 ```
 
-This is the first iteration of this project, roughness is expected, contributions are therefore highly encouraged! Follow the recommended steps:
+This is the first iteration of this project, roughness is expected, and contributions are therefore highly encouraged!
 
-- Raise the issue/discussion to get second opinions
-- Fork the repository
-- Create a branch
 - Make your changes without changing the design patterns
 - Write tests for your changes if necessary
 - Install locally with `pip install -e .`
 - Run tests with `python -m unittest discover -s tests`
-- Then submit a pull request from branch.
+- Then submit a pull request.
 
 Contributions can be made in various forms:
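Both commits above converge on nanodl's `time_rng_key` helper and the seedable `nanodl.uniform` wrapper in place of manual `jax.random.PRNGKey` bookkeeping. As a rough illustration of the convention the patch adopts — a minimal sketch only, assuming the helpers simply derive a PRNG key from the current time; nanodl's actual implementations may differ — such helpers could look like:

```py
import time
import jax

def time_rng_key():
    # Assumption: derive a fresh PRNGKey from the wall clock so callers
    # need not create and split keys manually. Illustrative only; the
    # real nanodl helper may be implemented differently.
    return jax.random.PRNGKey(int(time.time_ns()) % (2**31 - 1))

def uniform(shape, seed=None):
    # Assumption: reproducible when a seed is given, time-seeded otherwise.
    key = jax.random.PRNGKey(seed) if seed is not None else time_rng_key()
    return jax.random.uniform(key, shape)

jax_array = uniform(shape=(3, 3))          # non-deterministic draw
jax_array = uniform(shape=(3, 3), seed=0)  # reproducible draw
```

A time-derived key trades reproducibility for convenience, which is presumably why the seeded form shown in the second commit exists alongside it.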