Skip to content

Commit

Permalink
🕊️ Migration PPOv2 -> PPO (#2174)
Browse files Browse the repository at this point in the history
* delete old ppo

* rename ppov2 files

* PPOv2 -> PPO

* rm old doc

* rename ppo doc file

* rm old test

* rename test

* re-add v2 with deprecation

* style

* start update customization

* Lion

* Finish update customization

* remove ppo_multi_adaptater

* remove ppo example

* update some doc

* rm test no peft

* rm hello world

* processing class

* Update docs/source/detoxifying_a_lm.mdx

Co-authored-by: Edward Beeching <[email protected]>

* Update trl/trainer/ppov2_config.py

Co-authored-by: Edward Beeching <[email protected]>

* Update docs/source/customization.mdx

Co-authored-by: lewtun <[email protected]>

* Update docs/source/detoxifying_a_lm.mdx

Co-authored-by: lewtun <[email protected]>

* po to example overview

* drop lion

* remove "Use 8-bit optimizer"

* Update docs/source/customization.mdx

* Update docs/source/customization.mdx

Co-authored-by: lewtun <[email protected]>

* it applies to all trainers

---------

Co-authored-by: Edward Beeching <[email protected]>
Co-authored-by: lewtun <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2024
1 parent d0aa421 commit 70036bf
Show file tree
Hide file tree
Showing 22 changed files with 853 additions and 4,628 deletions.
2 changes: 0 additions & 2 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
title: ORPO
- local: ppo_trainer
title: PPO
- local: ppov2_trainer
title: PPOv2
- local: reward_trainer
title: Reward
- local: rloo_trainer
Expand Down
225 changes: 86 additions & 139 deletions docs/source/customization.mdx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Training customization

TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.

## Train on multiple GPUs / nodes

Expand Down Expand Up @@ -46,171 +46,118 @@ else:
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.


## Use different optimizers
## Use different optimizers and schedulers

By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
```python
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


# 2. Create optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)


# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
```

For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:

```python
import torch
import bitsandbytes as bnb

from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


# 2. Create optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)

# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
```

### Use LION optimizer
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:

You can use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
```python
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)

...
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)

trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
optimizers=(optimizer, None),
)
trainer.train()
```
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
</div>

### Add a learning rate scheduler

## Add a learning rate scheduler
You can also play with your training by adding learning rate schedulers.

You can also play with your training by adding learning rate schedulers!
```python
import torch
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
config = PPOConfig(**ppo_config)


# 2. Create optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# 3. initialize trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
optimizers=(optimizer, lr_scheduler),
)
trainer.train()
```

## Memory efficient fine-tuning by sharing layers

Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.

```python
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import create_reference_model, DPOConfig, DPOTrainer

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
ref_model = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
```

## Pass 8-bit reference models

<div>

Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.

Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).

</div>
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).

```python
# 0. imports
# pip install bitsandbytes
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')

# 2. initialize trainer
ppo_config = {'batch_size': 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
```

## Use the CUDA cache optimizer

When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
When training large models, you should better handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:

```python
config = PPOConfig(..., optimize_cuda_cache=True)
```



## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
use_score_scaling=True,
use_score_norm=True,
score_clip=0.5,
}
config = PPOConfig(**ppo_config)
```

To run `ppo.py`, you can use the following command:
```
python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
training_args = DPOConfig(..., optimize_cuda_cache=True)
```
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Choosing the right dataset format depends on the task you are working on and the
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOv2Trainer`] | Tokenized language modeling |
| [`PPOTrainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
Expand Down
10 changes: 3 additions & 7 deletions docs/source/detoxifying_a_lm.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,15 @@ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=

and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`.

- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`:
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
</div>

```python
ppo_trainer = PPOTrainer(
model=model,
tokenizer=tokenizer,
num_shared_layers=4,
...
)
ref_policy = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_policy=ref_policy)
```

In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dpo_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The abstract from the paper is the following:

The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.

Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppov2_trainer):
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):

1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/example_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
Expand Down
Loading

0 comments on commit 70036bf

Please sign in to comment.