IJEPA implementation #28

Merged 1 commit on Mar 25, 2024
README.md (103 changes: 35 additions & 68 deletions)

We provide various example usages of the nanodl API.
```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with an actual list of tokenised texts
data = jnp.ones((101, max_length), dtype=jnp.int32)

# Shift to create next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
    # ... (remaining hyperparameters elided in the diff view)
    'end_token': 50,
}

# Initialize the GPT4 model
model = GPT4(**hyperparams)

params = model.init({'params': time_rng_key(), 'dropout': time_rng_key()},
                    dummy_inputs)['params']

# Training on data
trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
print(trainer.evaluate(dataloader))

# Generating from a start token
start_tokens = jnp.array([[123, 456]])
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': time_rng_key()},
                      method=model.generate)
print(outputs)
```
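
These examples swap explicit `jax.random.PRNGKey(...)` seeding for nanodl's `time_rng_key`. A minimal sketch of the difference, assuming (as the name suggests) that `time_rng_key` derives a fresh key from the current time rather than a fixed seed:

```py
import jax
from nanodl import time_rng_key

# Fixed seed: identical results on every run
fixed_key = jax.random.PRNGKey(0)

# time_rng_key(): a fresh key with no manual seed bookkeeping
# (assumed to be time-derived, per the name)
fresh_key = time_rng_key()

print(jax.random.normal(fixed_key, (2, 2)))  # reproducible
print(jax.random.normal(fresh_key, (2, 2)))  # varies run to run
```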

Vision example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(time_rng_key(), input_shape)

# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
# ... (parameter initialisation elided in the diff view)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
# Note: saved params are often different from training weights, use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))

# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         diffusion_steps=5,
                                         method=diffusion_model.generate)
print(generated_images.shape)
```

Audio example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer

# Dummy data parameters (values assumed; elided in the diff view)
batch_size = 8
max_length = 50
embed_dim = 256
vocab_size = 1000
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))

dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
    # ... (hyperparameters elided in the diff view)
}

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']

# Training on your data
trainer = WhisperDataParallelTrainer(model,
                                     # ... (remaining arguments elided in the diff view)
                                     )
trainer.train(dataloader, 2, dataloader)
# Sample inference
params = trainer.load_params('params.pkl')

# For more than one sample, use model.generate_batch
transcripts = model.apply({'params': params},
                          dummy_inputs[:1],
                          rngs=rngs,
                          method=model.generate)

print(transcripts)
```

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

# Generate dummy data (values assumed; elided in the diff view)
batch_size = 8
max_length = 10
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
    # ... (hyperparameters elided in the diff view)
}

# ... (Mistral model, RewardModel, and trainer setup elided in the diff view)
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rewards = reward_model.apply({'params': params},
                             dummy_chosen,
                             rngs={'dropout': time_rng_key()})
```

PCA example
import jax
from nanodl import PCA

# Replace with actual data
data = jax.random.normal(jax.random.key(0), (1000, 10))

# Initialise and train PCA model
pca = PCA(n_components=2)
pca.fit(data)

# Get PCA transforms
transformed_data = pca.transform(data)

# Get reverse transforms
original_data = pca.inverse_transform(transformed_data)

# Sample from the distribution
X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```

NanoDL provides a random module that abstracts away JAX's explicit key handling.

```py
import nanodl

# Without a seed (a fresh key is presumably generated internally)
jax_array = nanodl.uniform(shape=(3, 3))

# With an explicit seed, for reproducibility
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```
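
Passing `seed` pins the underlying key so results are reproducible; omitting it presumably falls back to a fresh, time-derived key, consistent with the `time_rng_key` usage in the examples above.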

This is the first iteration of this project; roughness is expected, and contributions are therefore highly encouraged! Follow the recommended steps:

- Raise the issue/discussion to get second opinions
- Fork the repository
- Create a branch
- Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
- Then submit a pull request from your branch.

Contributions can be made in various forms:

[...]

Following the success of the Phi models, the long-term goal is to build and train nano versions of all available models,
while ensuring they compete with the original models in performance, with total
number of parameters not exceeding 1B. Trained weights will be made available via this library.
Any form of sponsorship, funding, grants or contribution will help with training resources.
You can sponsor via the user profile tag or reach out via [email protected].

## Citing nanodl

nanodl/__init__.py (9 changes: 9 additions & 0 deletions)
    RewardDataParallelTrainer
)

from nanodl.__src.models.ijepa import (
    IJEPA,
    IJEPADataParallelTrainer,
    IJEPADataSampler
)

from nanodl.__src.layers.attention import (
    MultiQueryAttention,
    LocalMultiHeadAttention,
    # ... (elided in the diff view; the entries below sit in the package's export list)

    "GaussianProcess",
    # Models
    "IJEPA",
    "IJEPADataParallelTrainer",
    "IJEPADataSampler",
    "Gemma",
    "GemmaDataParallelTrainer",
    "GemmaDecoder",