Commit: Merge branch 'main' of https://github.com/danny-1k/nanodl into dev
Showing 1 changed file with 35 additions and 68 deletions.
We provide various example usages of the nanodl API.
```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with an actual list of tokenised texts
data = jnp.ones((101, max_length), dtype=jnp.int32)

# Shift to create a next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Model parameters
hyperparams = {
    ...
    'end_token': 50,
}

# Initialize inferred GPT4 model
model = GPT4(**hyperparams)
params = model.init(
    {'params': time_rng_key(),
     'dropout': time_rng_key()},
    dummy_inputs)['params']

# Training on data
trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)

# Generating from a start token
start_tokens = jnp.array([[123, 456]])

# Load the saved parameters for generation
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': time_rng_key()},
                      method=model.generate)
print(outputs)
```
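
The DataLoader behaves like a standard Python iterable of batch tuples, so inspecting a batch is a short loop. A quick sketch (shapes assume the dummy data above: `batch_size=8`, inputs of width `max_length - 1`):

```py
# Each batch is an (inputs, targets) tuple of jnp arrays
for x, y in dataloader:
    print(x.shape, y.shape)  # e.g. (8, 9) (8, 9)
    break
```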

Vision example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(time_rng_key(), input_shape)

# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
...
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
# Note: saved params are often different from the training weights; use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))

# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         diffusion_steps=5,
                                         method=diffusion_model.generate)
print(generated_images.shape)
```
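
`diffusion_steps` sets how many denoising iterations the sampler runs; more steps generally trade compute for sample quality. A sketch reusing the call above with a larger step count:

```py
# More denoising steps: slower generation, typically higher-fidelity samples
better_images = diffusion_model.apply({'params': params},
                                      num_images=5,
                                      diffusion_steps=50,
                                      method=diffusion_model.generate)
print(better_images.shape)
```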

Audio example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer
...
vocab_size = 1000
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))

dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Model parameters
hyperparams = {
    ...
}

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']
outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
print(outputs.shape)

# Training on your data
trainer = WhisperDataParallelTrainer(model,
                                     ...)
trainer.train(dataloader, 2, dataloader)

# Sample inference
params = trainer.load_params('params.pkl')

# For more than one sample, use model.generate_batch
transcripts = model.apply({'params': params},
                          dummy_inputs[:1],
                          rngs=rngs,
                          method=model.generate)

print(transcripts)
```
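
The comment above points to `model.generate_batch` for decoding several samples at once. Its exact signature is not shown in this file, so the following is only a hypothetical sketch that assumes it mirrors `generate` while accepting a batch of inputs:

```py
# Hypothetical: assumes generate_batch takes the same arguments as
# generate but operates on a batch of inputs
transcripts = model.apply({'params': params},
                          dummy_inputs[:4],
                          rngs=rngs,
                          method=model.generate_batch)
print(transcripts)
```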

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

...
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Model parameters
hyperparams = {
    ...
}

...
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rewards = reward_model.apply({'params': params},
                             dummy_chosen,
                             rngs={'dropout': time_rng_key()})
print(rewards.shape)
```
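
A trained reward model should assign higher scores to preferred sequences. A quick sanity check, reusing the same `apply` signature on both halves of the preference data:

```py
# Compare mean reward for chosen vs. rejected sequences;
# the chosen set should score higher after training
rewards_chosen = reward_model.apply({'params': params},
                                    dummy_chosen,
                                    rngs={'dropout': time_rng_key()})
rewards_rejected = reward_model.apply({'params': params},
                                      dummy_rejected,
                                      rngs={'dropout': time_rng_key()})
print(rewards_chosen.mean(), rewards_rejected.mean())
```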

PCA example

```py
import jax
from nanodl import PCA

# Use actual data
data = jax.random.normal(jax.random.key(0), (1000, 10))

# Initialise and train PCA model
pca = PCA(n_components=2)
pca.fit(data)

# Get PCA transforms
transformed_data = pca.transform(data)

# Get reverse transforms
original_data = pca.inverse_transform(transformed_data)

# Sample from the distribution
X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```
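
Because `inverse_transform` maps the 2-component projection back to the original space, the round-trip reconstruction error is a quick measure of how much structure the components capture. A sketch using only the calls above:

```py
import jax.numpy as jnp

# Mean squared error between the data and its 2-component reconstruction
reconstruction_mse = jnp.mean((data - original_data) ** 2)
print(reconstruction_mse)
```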

NanoDL provides a random module which abstracts away Jax's intricacies.

```py
import nanodl

# Jax requires an explicit PRNG key for random ops;
# nanodl creates one internally (optionally from a seed)
jax_array = nanodl.uniform(shape=(3, 3))
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```
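
The `time_rng_key` helper imported in the examples above serves the same goal from the other direction: it returns a fresh, ready-to-use PRNG key, so it can be passed anywhere Jax expects one:

```py
import jax
from nanodl import time_rng_key

# A fresh key without PRNGKey/split bookkeeping,
# as used for the image noise in the vision example
noise = jax.random.normal(time_rng_key(), (3, 3))
print(noise.shape)
```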

This is the first iteration of this project; roughness is expected, and contributions are therefore highly encouraged. Follow the recommended steps:

- Raise an issue/discussion to get second opinions
- Fork the repository
- Create a branch
- Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
- Then submit a pull request.

Contributions can be made in various forms:

...
Following the success of the Phi models, the long-term goal is to build and train nano versions of all available models, while ensuring they compete with the original models in performance, with the total number of parameters not exceeding 1B. Trained weights will be made available via this library. Any form of sponsorship, funding, grants or contribution will help with training resources. You can sponsor via the user profile tag or reach out via [email protected].

## Citing nanodl