We provide various example usages of the nanodl API.
```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with an actual list of tokenised texts
data = jnp.ones((101, max_length), dtype=jnp.int32)

# Shift to create next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# How to loop through the dataloader
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    break

# model parameters
hyperparams = {
    # ...
    'end_token': 50,
}

# Initialize model
model = GPT4(**hyperparams)
params = model.init(
    {'params': time_rng_key(),
     'dropout': time_rng_key()},
    dummy_inputs)['params']

# Call as you would a Jax/Flax model
outputs = model.apply({'params': params},
                      dummy_inputs,
                      rngs={'dropout': time_rng_key()})
print(outputs.shape)

# Training on data
trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
print(trainer.evaluate(dataloader))

# Generating from a start token
start_tokens = jnp.array([[123, 456]])
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': time_rng_key()},
                      method=model.generate)
print(outputs)
```
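
Because the returned `params` is a standard Flax parameter pytree, plain JAX tree utilities work on it directly. As a small illustrative sketch (not a nanodl API), you can count the model's trainable parameters like this:

```py
import jax

# Count trainable parameters in the pytree returned by model.init
num_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"Trainable parameters: {num_params:,}")
```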

Vision example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(time_rng_key(), input_shape)

# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(time_rng_key(), images)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
# Note: saved params often differ from the training weights; use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))

# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         diffusion_steps=5,
                                         method=diffusion_model.generate)
print(generated_images.shape)
```
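
The generated samples come back as a float array in `(num_images, height, width, 3)` layout. Below is a minimal, hedged sketch for saving them to disk, assuming the outputs lie roughly in the `[-1, 1]` range (the exact output scale is not documented here):

```py
import numpy as np
from PIL import Image

# Rescale assumed [-1, 1] outputs to [0, 255] and write one PNG per sample
samples = np.clip((np.asarray(generated_images) + 1.0) * 127.5, 0, 255).astype(np.uint8)
for i, sample in enumerate(samples):
    Image.fromarray(sample).save(f"sample_{i}.png")
```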

Audio example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer

# ... (definitions of batch_size, max_length and embed_dim omitted)
vocab_size = 1000
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))

dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
    # ...
}

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']
outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
print(outputs.shape)

# Training on your data
trainer = WhisperDataParallelTrainer(model,
                                     ...)  # remaining constructor arguments omitted
trainer.train(dataloader, 2, dataloader)
# Sample inference
params = trainer.load_params('params.pkl')

# For more than one sample, use model.generate_batch
transcripts = model.apply({'params': params},
                          dummy_inputs[:1],
                          rngs=rngs,
                          method=model.generate)

print(transcripts)
```
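
For batched transcription, the comment above points to `model.generate_batch`; its exact signature is not shown here, but assuming it mirrors `generate`, a hypothetical call might look like this:

```py
# Hypothetical sketch: assumes generate_batch takes the same arguments as generate
batch_transcripts = model.apply({'params': params},
                                dummy_inputs[:4],
                                rngs=rngs,
                                method=model.generate_batch)
```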

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

# Generate dummy data
# ... (batch_size and max_length definitions omitted)
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
    # ...
}

# ... (Mistral backbone, RewardModel wrapper and RewardDataParallelTrainer construction omitted)
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rewards = reward_model.apply({'params': params},
                             dummy_chosen,
                             rngs={'dropout': time_rng_key()})
print(rewards.shape)
```
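
For intuition only (this is not nanodl's internal implementation), RLHF reward models are commonly trained with a pairwise Bradley-Terry style loss that pushes the chosen response's reward above the rejected one's:

```py
import jax.numpy as jnp
from jax import nn

def pairwise_reward_loss(chosen_rewards, rejected_rewards):
    # Maximise the log-probability that the chosen sample outranks the rejected one
    return -jnp.mean(nn.log_sigmoid(chosen_rewards - rejected_rewards))

# Example with dummy reward scores
print(pairwise_reward_loss(jnp.array([1.0, 0.5]), jnp.array([0.2, 0.1])))
```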

PCA example

```py
import jax
from nanodl import PCA

# Use actual data
data = jax.random.normal(jax.random.key(0), (1000, 10))

# Initialise and train PCA model
pca = PCA(n_components=2)
pca.fit(data)

# Get PCA transforms
transformed_data = pca.transform(data)

# Get reverse transforms
original_data = pca.inverse_transform(transformed_data)

# Sample from the distribution
X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```
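
As a quick sanity check on the round trip (plain Jax, not part of the PCA API), the reconstruction error between the original data and its inverse transform can be computed directly:

```py
import jax.numpy as jnp

# Mean squared error of data -> transform -> inverse_transform
mse = jnp.mean((data - original_data) ** 2)
print(mse)
```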

NanoDL provides a random module which abstracts away Jax's intricacies.

```py
import nanodl

# Draw without explicitly creating a PRNG key
jax_array = nanodl.uniform(shape=(3, 3))

# Or pass a seed for reproducibility
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```
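
For comparison, here is a minimal sketch of the same draw in plain Jax, which requires creating and splitting PRNG keys explicitly (only `uniform` is shown above; other nanodl distributions may behave differently):

```py
import jax

# Plain Jax: every random draw needs an explicit key, split before reuse
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
jax_array = jax.random.uniform(subkey, shape=(3, 3))
```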

This is the first iteration of this project, so some roughness is expected, and contributions are therefore highly encouraged! Follow the recommended steps:

- Raise an issue or discussion to get second opinions
- Fork the repository
- Create a branch
- Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
- Then submit a pull request from your branch.

Contributions can be made in various forms:

Following the success of Phi models, the long-term goal is to build and train nano versions of all available models, while ensuring they compete with the original models in performance, with the total number of parameters not exceeding 1B. Trained weights will be made available via this library. Any form of sponsorship, funding, grants, or contributions will help with training resources. You can sponsor via the tag on the user profile, or reach out via [email protected].

## Citing nanodl
