diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index c6d7ef6826..f3d39ba445 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -42,8 +42,6 @@ title: ORPO - local: ppo_trainer title: PPO - - local: ppov2_trainer - title: PPOv2 - local: reward_trainer title: Reward - local: rloo_trainer diff --git a/docs/source/customization.mdx b/docs/source/customization.mdx index a576890734..7fc9211e11 100644 --- a/docs/source/customization.mdx +++ b/docs/source/customization.mdx @@ -1,6 +1,6 @@ # Training customization -TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. +TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers. ## Train on multiple GPUs / nodes @@ -46,171 +46,118 @@ else: Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. -## Use different optimizers +## Use different optimizers and schedulers -By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`: -```python -import torch -from transformers import GPT2Tokenizer -from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead - -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - -# 2. define config -ppo_config = {'batch_size': 1, 'learning_rate':1e-5} -config = PPOConfig(**ppo_config) - - -# 2. Create optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) - - -# 3. initialize trainer -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer) -``` - -For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`: - -```python -import torch -import bitsandbytes as bnb - -from transformers import GPT2Tokenizer -from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead - -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - -# 2. define config -ppo_config = {'batch_size': 1, 'learning_rate':1e-5} -config = PPOConfig(**ppo_config) - - -# 2. Create optimizer -optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate) - -# 3. initialize trainer -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer) -``` - -### Use LION optimizer +By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows: -You can use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training: ```python -optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate) - -... -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer) +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import optim +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + optimizers=(optimizer, None), +) +trainer.train() ``` -We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)): - -
- -
+### Add a learning rate scheduler -## Add a learning rate scheduler +You can also play with your training by adding learning rate schedulers. -You can also play with your training by adding learning rate schedulers! ```python -import torch -from transformers import GPT2Tokenizer -from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead - -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2') -tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - -# 2. define config -ppo_config = {'batch_size': 1, 'learning_rate':1e-5} -config = PPOConfig(**ppo_config) - - -# 2. Create optimizer -optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate) -lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) - -# 3. initialize trainer -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler) +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import optim +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate) +lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + optimizers=(optimizer, lr_scheduler), +) +trainer.train() ``` ## Memory efficient fine-tuning by sharing layers Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. + ```python -import torch -from transformers import AutoTokenizer -from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import create_reference_model, DPOConfig, DPOTrainer -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") ref_model = create_reference_model(model, num_shared_layers=6) -tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') - -# 2. initialize trainer -ppo_config = {'batch_size': 1} -config = PPOConfig(**ppo_config) -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() ``` ## Pass 8-bit reference models -
- -Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. +Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. -Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition). - -
+Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit). ```python -# 0. imports -# pip install bitsandbytes -import torch -from transformers import AutoTokenizer -from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead - -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m') -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True) -tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') - -# 2. initialize trainer -ppo_config = {'batch_size': 1} -config = PPOConfig(**ppo_config) -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer) +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +quantization_config = BitsAndBytesConfig(load_in_8bit=True) +ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() ``` ## Use the CUDA cache optimizer -When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`: +When training large models, you should better handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`: ```python -config = PPOConfig(..., optimize_cuda_cache=True) -``` - - - -## Use score scaling/normalization/clipping -As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`: -```python -from trl import PPOConfig - -ppo_config = { - use_score_scaling=True, - use_score_norm=True, - score_clip=0.5, -} -config = PPOConfig(**ppo_config) -``` - -To run `ppo.py`, you can use the following command: -``` -python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5 +training_args = DPOConfig(..., optimize_cuda_cache=True) ``` diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index 6742d23123..fa69ff1e32 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -205,7 +205,7 @@ Choosing the right dataset format depends on the task you are working on and the | [`NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | | [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`PPOv2Trainer`] | Tokenized language modeling | +| [`PPOTrainer`] | Tokenized language modeling | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`SFTTrainer`] | [Language modeling](#language-modeling) | | [`XPOTrainer`] | [Prompt-only](#prompt-only) | diff --git a/docs/source/detoxifying_a_lm.mdx b/docs/source/detoxifying_a_lm.mdx index e63fa4ebff..30c7d5a930 100644 --- a/docs/source/detoxifying_a_lm.mdx +++ b/docs/source/detoxifying_a_lm.mdx @@ -98,19 +98,15 @@ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype= and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`. -- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`: +- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
```python -ppo_trainer = PPOTrainer( - model=model, - tokenizer=tokenizer, - num_shared_layers=4, - ... -) +ref_policy = create_reference_model(model, num_shared_layers=6) +trainer = PPOTrainer(..., ref_policy=ref_policy) ``` In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index 66e55940c9..f9d01074ed 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -12,7 +12,7 @@ The abstract from the paper is the following: The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm. -Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppov2_trainer): +Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer): 1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt. 2. **Optimization**: Maximize the log-likelihood of the DPO loss directly. diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 16686e5fe4..d239199810 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -44,8 +44,8 @@ Then, it is encouraged to launch jobs with `accelerate launch`! | [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | -| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. | -| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). | +| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language | +| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. | | [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. | | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. | diff --git a/docs/source/logging.mdx b/docs/source/logging.mdx index 71eb7c4137..4c60868dac 100644 --- a/docs/source/logging.mdx +++ b/docs/source/logging.mdx @@ -1,15 +1,14 @@ # Logging As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging. -By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`. +By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard. Upon initialization, pass one of these two options to the [`PPOConfig`]: + ``` -config = PPOConfig( - model_name=args.model_name, - log_with=`wandb`, # or `tensorboard` -) +training_args = PPOConfig(..., report_to="wandb") # or "tensorboard" ``` + If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. ## PPO Logging diff --git a/docs/source/ppov2_trainer.md b/docs/source/ppo_trainer.md similarity index 98% rename from docs/source/ppov2_trainer.md rename to docs/source/ppo_trainer.md index 93adf0ffdc..414c051abc 100644 --- a/docs/source/ppov2_trainer.md +++ b/docs/source/ppo_trainer.md @@ -1,4 +1,4 @@ -# PPOv2 Trainer +# PPO Trainer [![](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) @@ -167,7 +167,7 @@ In the logs the sampled generations look like ## Implementation details -This PPOv2 implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). +This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). ## Benchmark experiments @@ -222,14 +222,14 @@ python -m openrlbenchmark.rlops_multi_metrics \ --pc.ncols 4 \ --pc.ncols-legend 1 \ --pc.xlabel "Episode" \ - --output-filename benchmark/trl/pr-1540/ppov2 \ + --output-filename benchmark/trl/pr-1540/ppo \ --scan-history ``` -## PPOv2Trainer +## PPOTrainer -[[autodoc]] PPOv2Trainer +[[autodoc]] PPOTrainer -## PPOv2Config +## PPOConfig -[[autodoc]] PPOv2Config \ No newline at end of file +[[autodoc]] PPOConfig \ No newline at end of file diff --git a/docs/source/ppo_trainer.mdx b/docs/source/ppo_trainer.mdx deleted file mode 100644 index ebc97a9e28..0000000000 --- a/docs/source/ppo_trainer.mdx +++ /dev/null @@ -1,173 +0,0 @@ -# PPO Trainer - -[![](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) - -TRL supports the [PPO](https://huggingface.co/papers/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback). - -The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm. - -## How PPO works - -Fine-tuning a language model via PPO consists of roughly three steps: - -1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence. -2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. -3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. - -This process is illustrated in the sketch below: - -
- -

Figure: Sketch of the workflow.

-
- -## Expected dataset format - -The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm. - -Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop. - -Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset: - -```py -from datasets import load_dataset - -dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train") -dataset = dataset.rename_column("prompt", "query") -dataset = dataset.remove_columns(["meta", "completion"]) -``` - -Resulting in the following subset of the dataset: - -```py -ppo_dataset_dict = { - "query": [ - "Explain the moon landing to a 6 year old in a few sentences.", - "Why aren’t birds real?", - "What happens if you fire a cannonball directly at a pumpkin at high speeds?", - "How can I steal from a grocery store without getting caught?", - "Why is it important to eat socks after meditating? " - ] -} -``` - -## Using the `PPOTrainer` - -For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response. - -### Initializing the `PPOTrainer` - -The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer. - -```py -from trl import PPOConfig - -config = PPOConfig( - model_name="gpt2", - learning_rate=1.41e-5, -) -``` - -Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows: - -```py -from transformers import AutoTokenizer - -from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer - -model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name) -tokenizer = AutoTokenizer.from_pretrained(config.model_name) - -tokenizer.pad_token = tokenizer.eos_token -``` - -As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use. - -```py -from transformers import pipeline - -reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb") -``` - -Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop: - -```py -def tokenize(sample): - sample["input_ids"] = tokenizer.encode(sample["query"]) - return sample - -dataset = dataset.map(tokenize, batched=False) -``` - -Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model. - -```py -from trl import PPOTrainer - -ppo_trainer = PPOTrainer( - model=model, - config=config, - dataset=dataset, - tokenizer=tokenizer, -) -``` - -### Starting the training loop - -Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above. - -To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training). - -```py -generation_kwargs = { - "min_length": -1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": tokenizer.eos_token_id, -} -``` - -We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm. - -```py -from tqdm import tqdm - - -epochs = 10 -for epoch in tqdm(range(epochs), "epoch: "): - for batch in tqdm(ppo_trainer.dataloader): - query_tensors = batch["input_ids"] - - #### Get response from SFTModel - response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs) - batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] - - #### Compute reward score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - pipe_outputs = reward_model(texts) - rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] - - #### Run PPO step - stats = ppo_trainer.step(query_tensors, response_tensors, rewards) - ppo_trainer.log_stats(stats, batch, rewards) - -#### Save model -ppo_trainer.save_pretrained("my_ppo_model") -``` - -## Logging - -While training and evaluating we log the following metrics: - -- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc. -- `batch`: The batch of data used to train the SFT model. -- `rewards`: The rewards obtained from the Reward model. - -## PPOTrainer - -[[autodoc]] PPOTrainer - -## PPOConfig - -[[autodoc]] PPOConfig diff --git a/examples/hello_world.py b/examples/hello_world.py deleted file mode 100644 index d7805d7885..0000000000 --- a/examples/hello_world.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# 0. imports -import torch -from transformers import GPT2Tokenizer - -from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer - - -# 1. load a pretrained model -model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") -ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") -tokenizer = GPT2Tokenizer.from_pretrained("gpt2") -tokenizer.pad_token = tokenizer.eos_token - -# 2. initialize trainer -ppo_config = {"mini_batch_size": 1, "batch_size": 1} -config = PPOConfig(**ppo_config) -ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer) - -# 3. encode a query -query_txt = "This morning I went to the " -query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) - -# 4. generate model response -generation_kwargs = { - "min_length": -1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": tokenizer.eos_token_id, - "max_new_tokens": 20, -} -response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs) -response_txt = tokenizer.decode(response_tensor[0]) - -# 5. define a reward for response -# (this could be any reward such as human feedback or output from another model) -reward = [torch.tensor(1.0, device=model.pretrained_model.device)] - -# 6. train model with ppo -train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) diff --git a/examples/scripts/ppo.py b/examples/scripts/ppo.py deleted file mode 100644 index c6bd208d06..0000000000 --- a/examples/scripts/ppo.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -python examples/scripts/ppo.py \ - --log_with=wandb -""" - -from dataclasses import dataclass, field -from typing import Optional - -import torch -from accelerate import Accelerator, PartialState -from datasets import load_dataset -from peft import LoraConfig -from tqdm import tqdm -from transformers import AutoTokenizer, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available, pipeline - -from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed -from trl.core import LengthSampler - - -tqdm.pandas() - - -@dataclass -class ScriptArguments: - use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"}) - trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"}) - - # LoraConfig - use_peft: bool = field(default=False, metadata={"help": "whether to use peft"}) - lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) - lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"}) - - -parser = HfArgumentParser((ScriptArguments, PPOConfig)) -script_args, ppo_config = parser.parse_args_into_dataclasses() - -# We then define the arguments to pass to the sentiment analysis pipeline. -# We set `return_all_scores` to True to get the sentiment score for each token. -sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16} - -trl_model_class = ( - AutoModelForCausalLMWithValueHead if not script_args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead -) - -tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name) -tokenizer.pad_token = tokenizer.eos_token - - -# Below is an example function to build the dataset. In our case, we use the IMDB dataset -# from the `datasets` library. One should customize this function to train the model on -# its own dataset. -def build_dataset(query_dataset, dataset_num_proc, input_min_text_length=2, input_max_text_length=8): - """ - Build dataset for training. This builds the dataset from `load_dataset`, one should - customize this function to train the model on its own dataset. - - Args: - query_dataset (`str`): - The name of the dataset to be loaded. - - Returns: - dataloader (`torch.utils.data.DataLoader`): - The dataloader for the dataset. - """ - # load imdb with datasets - dataset = load_dataset(query_dataset, split="train") - dataset = dataset.rename_columns({"text": "review"}) - dataset = dataset.filter(lambda x: len(x["review"]) > 200, num_proc=dataset_num_proc) - - input_size = LengthSampler(input_min_text_length, input_max_text_length) - - def tokenize(sample): - sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] - sample["query"] = tokenizer.decode(sample["input_ids"]) - return sample - - dataset = dataset.map(tokenize, num_proc=dataset_num_proc) - dataset.set_format(type="torch") - return dataset - - -# We retrieve the dataloader by calling the `build_dataset` function. -# Compute that only on the main process for faster data processing. -# see: https://github.com/huggingface/trl/pull/1255 -with PartialState().local_main_process_first(): - dataset = build_dataset(ppo_config.query_dataset, ppo_config.dataset_num_proc) - - -def collator(data): - return {key: [d[key] for d in data] for key in data[0]} - - -# set seed before initializing value head for deterministic eval -set_seed(ppo_config.seed) - -# Now let's build the model, the reference model, and the tokenizer. -if not script_args.use_peft: - ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=script_args.trust_remote_code) - device_map = None - peft_config = None -else: - peft_config = LoraConfig( - r=script_args.lora_r, - lora_alpha=script_args.lora_alpha, - bias="none", - task_type="CAUSAL_LM", - ) - ref_model = None - # Copy the model to each device - device_map = {"": Accelerator().local_process_index} - -model = trl_model_class.from_pretrained( - ppo_config.model_name, - trust_remote_code=script_args.trust_remote_code, - device_map=device_map, - peft_config=peft_config, -) - - -tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name) - -# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. -tokenizer.pad_token_id = tokenizer.eos_token_id - -# We then build the PPOTrainer, passing the model, the reference model, the tokenizer -ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator) - -# We then build the sentiment analysis pipeline, passing the model name and the -# sentiment analysis pipeline arguments. Let's also make sure to set the device -# to the same device as the PPOTrainer. -device = ppo_trainer.accelerator.device -if ppo_trainer.accelerator.num_processes == 1: - if is_torch_xpu_available(): - device = "xpu:0" - elif is_torch_npu_available(): - device = "npu:0" - else: - device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug -ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin -task, model_name = ppo_config.reward_model.split(":") -if ds_plugin is not None and ds_plugin.is_zero3_init_enabled(): - with ds_plugin.zero3_init_context_manager(enable=False): - sentiment_pipe = pipeline(task, model=model_name, device=device) -else: - sentiment_pipe = pipeline(task, model=model_name, device=device) - -# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. -if sentiment_pipe.tokenizer.pad_token_id is None: - sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id - -if sentiment_pipe.model.config.pad_token_id is None: - sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id - -# We then define the arguments to pass to the `generate` function. These arguments -# are passed to the `generate` function of the PPOTrainer, which is a wrapper around -# the `generate` function of the trained model. -generation_kwargs = { - "min_length": -1, - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": tokenizer.eos_token_id, - "max_new_tokens": 32, -} - -for batch in tqdm(ppo_trainer.dataloader): - query_tensors = batch["input_ids"] - - # Get response from gpt2 - response_tensors, ref_response_tensors = ppo_trainer.generate( - query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs - ) - batch["response"] = tokenizer.batch_decode(response_tensors) - batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors) - - # Compute sentiment score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - pipe_outputs = sentiment_pipe(texts, **sent_kwargs) - rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] - ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] - ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs) - ref_rewards = [torch.tensor(output[1]["score"]) for output in ref_pipe_outputs] - batch["ref_rewards"] = ref_rewards - - # Run PPO step - stats = ppo_trainer.step(query_tensors, response_tensors, rewards) - ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"]) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index a71ed138d1..41c7c8b69d 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -23,7 +23,7 @@ HfArgumentParser, ) -from trl import ModelConfig, PPOv2Config, PPOv2Trainer +from trl import ModelConfig, PPOConfig, PPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE @@ -55,7 +55,7 @@ if __name__ == "__main__": - parser = HfArgumentParser((PPOv2Config, ModelConfig)) + parser = HfArgumentParser((PPOConfig, ModelConfig)) training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -118,7 +118,7 @@ def tokenize(element): ################ # Training ################ - trainer = PPOv2Trainer( + trainer = PPOTrainer( config=training_args, processing_class=tokenizer, policy=policy, diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 14f12e680e..441db0502f 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -23,7 +23,7 @@ HfArgumentParser, ) -from trl import ModelConfig, PPOv2Config, PPOv2Trainer +from trl import ModelConfig, PPOConfig, PPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE @@ -58,7 +58,7 @@ if __name__ == "__main__": - parser = HfArgumentParser((PPOv2Config, ModelConfig)) + parser = HfArgumentParser((PPOConfig, ModelConfig)) training_args, model_config = parser.parse_args_into_dataclasses() # remove output_dir if exists shutil.rmtree(training_args.output_dir, ignore_errors=True) @@ -123,7 +123,7 @@ def tokenize(element): ################ # Training ################ - trainer = PPOv2Trainer( + trainer = PPOTrainer( config=training_args, processing_class=tokenizer, policy=policy, diff --git a/examples/scripts/ppo_multi_adapter.py b/examples/scripts/ppo_multi_adapter.py deleted file mode 100644 index 9cc358e823..0000000000 --- a/examples/scripts/ppo_multi_adapter.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass, field -from typing import Optional - -import torch -from accelerate import PartialState -from datasets import load_dataset -from peft import LoraConfig -from tqdm import tqdm -from transformers import ( - AutoTokenizer, - BitsAndBytesConfig, - HfArgumentParser, - is_torch_npu_available, - is_torch_xpu_available, -) - -from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer -from trl.core import LengthSampler - - -input_min_text_length = 6 -input_max_text_length = 12 - - -@dataclass -class ScriptArguments: - """ - The name of the Casual LM model we wish to fine with PPO - """ - - model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"}) - dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) - rm_adapter: Optional[str] = field( - default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"} - ) - log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) - use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"}) - seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) - use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"}) - use_score_norm: Optional[bool] = field( - default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"} - ) - score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"}) - dataset_num_proc: Optional[int] = field( - default=None, metadata={"help": "The number of workers to use to tokenize the data"} - ) - - -parser = HfArgumentParser(ScriptArguments) -script_args = parser.parse_args_into_dataclasses()[0] - - -def create_and_prepare_dataset(tokenizer, num_proc): - dataset = load_dataset(script_args.dataset_name, split="train[:1%]") - - input_size = LengthSampler(input_min_text_length, input_max_text_length) - - def tokenize(example): - text_size = input_size() - example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size] - example["query"] = tokenizer.decode(example["input_ids"]) - return example - - dataset = dataset.map(tokenize, batched=False, num_proc=num_proc) - dataset.set_format("torch") - return dataset - - -lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", -) -nf4_config = BitsAndBytesConfig( - load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 -) -model = AutoModelForCausalLMWithValueHead.from_pretrained( - script_args.model_name, - device_map={"": "xpu:0"} if is_torch_xpu_available() else {"": "npu:0"} if is_torch_npu_available else {"": 0}, - peft_config=lora_config, - quantization_config=nf4_config, - reward_adapter=script_args.rm_adapter, - use_safetensors=script_args.use_safetensors, -) -tokenizer = AutoTokenizer.from_pretrained(script_args.model_name) - -tokenizer.pad_token = tokenizer.eos_token - -# Compute that only on the main process for faster data processing. -# see: https://github.com/huggingface/trl/pull/1255 -with PartialState().local_main_process_first(): - dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc) - - -def collator(data): - return {key: [d[key] for d in data] for key in data[0]} - - -config = PPOConfig( - model_name=script_args.model_name, - log_with=script_args.log_with, - learning_rate=1e-5, - batch_size=8, - mini_batch_size=2, - gradient_accumulation_steps=2, - optimize_cuda_cache=True, - seed=script_args.seed, - use_score_scaling=script_args.use_score_scaling, - use_score_norm=script_args.use_score_norm, - score_clip=script_args.score_clip, -) - -ppo_trainer = PPOTrainer( - config, - model, - ref_model=None, - tokenizer=tokenizer, - dataset=dataset, - data_collator=collator, -) - -generation_kwargs = { - "top_k": 0.0, - "top_p": 0.9, - "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "max_new_tokens": 32, -} - -for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): - question_tensors = batch["input_ids"] - - response_tensors = ppo_trainer.generate( - question_tensors, - return_prompt=False, - **generation_kwargs, - ) - batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) - - # Compute reward score - texts = [q + r for q, r in zip(batch["query"], batch["response"])] - inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device) - raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs) - rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token - - # Run PPO step - stats = ppo_trainer.step(question_tensors, response_tensors, rewards) - ppo_trainer.log_stats(stats, batch, rewards) diff --git a/tests/test_e2e.py b/tests/test_e2e.py deleted file mode 100644 index 10cbc251ee..0000000000 --- a/tests/test_e2e.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import subprocess - - -def test_hello_world(): - subprocess.run( - "python examples/hello_world.py", - shell=True, - check=True, - ) diff --git a/tests/test_no_peft.py b/tests/test_no_peft.py deleted file mode 100644 index 3d0e9f90a1..0000000000 --- a/tests/test_no_peft.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys -import unittest -from functools import partial -from unittest.mock import patch - -import pytest -import torch -from transformers import AutoTokenizer -from transformers.utils import import_utils - - -class DummyDataset(torch.utils.data.Dataset): - def __init__(self, query_data, response_data): - self.query_data = query_data - self.response_data = response_data - - def __len__(self): - return len(self.query_data) - - def __getitem__(self, idx): - return self.query_data[idx], self.response_data[idx] - - -EXPECTED_STATS = [ - "objective/kl", - "objective/kl_dist", - "objective/logprobs", - "objective/ref_logprobs", - "objective/kl_coef", - "objective/entropy", - "ppo/mean_non_score_reward", - "ppo/loss/policy", - "ppo/loss/value", - "ppo/loss/total", - "ppo/policy/entropy", - "ppo/policy/approxkl", - "ppo/policy/policykl", - "ppo/policy/clipfrac", - "ppo/policy/advantages", - "ppo/policy/advantages_mean", - "ppo/policy/ratio", - "ppo/returns/mean", - "ppo/returns/var", - "ppo/val/vpred", - "ppo/val/error", - "ppo/val/clipfrac", - "ppo/val/mean", - "ppo/val/var", - "ppo/val/var_explained", - "time/ppo/forward_pass", - "time/ppo/compute_rewards", - "time/ppo/optimize_step", - "time/ppo/calc_stats", - "time/ppo/total", - "ppo/learning_rate", -] - - -class TestPeftDependancy(unittest.TestCase): - def setUp(self): - self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" - self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration" - - def test_no_peft(self): - _peft_available = import_utils._peft_available - import_utils._peft_available = False # required so that is_peft_available() returns False - with patch.dict(sys.modules, {"peft": None}): - from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead - - # Check that loading a model with `peft` will raise an error - with pytest.raises(ModuleNotFoundError): - import peft # noqa: F401 - - _trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) - _trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id) - import_utils._peft_available = _peft_available - - def test_imports_no_peft(self): - _peft_available = import_utils._peft_available - import_utils._peft_available = False # required so that is_peft_available() returns False - with patch.dict(sys.modules, {"peft": None}): - from trl import ( # noqa: F401 - AutoModelForCausalLMWithValueHead, - AutoModelForSeq2SeqLMWithValueHead, - PPOConfig, - PPOTrainer, - PreTrainedModelWrapper, - ) - import_utils._peft_available = _peft_available - - def test_ppo_trainer_no_peft(self): - _peft_available = import_utils._peft_available - import_utils._peft_available = False # required so that is_peft_available() returns False - with patch.dict(sys.modules, {"peft": None}): - from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer - - ppo_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" - - trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_id) - tokenizer = AutoTokenizer.from_pretrained(ppo_model_id) - tokenizer.pad_token_id = tokenizer.eos_token_id - - ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) - - dummy_dataset = DummyDataset( - [torch.LongTensor([0, 1, 0, 1, 0, 1]), torch.LongTensor([0, 1, 0, 1, 0, 1])], - [torch.LongTensor([1, 0, 1, 0, 1, 0]), torch.LongTensor([0, 1, 0, 1, 0, 1])], - ) - - ppo_trainer = PPOTrainer( - config=ppo_config, - model=trl_model, - ref_model=None, - tokenizer=tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - # check gradients are not None - for _, param in trl_model.named_parameters(): - if param.requires_grad: - assert param.grad is not None - - # check expected stats - for stat in EXPECTED_STATS: - assert stat in train_stats - import_utils._peft_available = _peft_available diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index fc88db81a3..21dffd9bee 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -11,1271 +11,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import fnmatch -import gc -import re -import tempfile -import unittest -from functools import partial - -import pytest -import torch -from huggingface_hub import HfApi -from parameterized import parameterized -from requests.exceptions import HTTPError -from transformers import AutoTokenizer -from transformers.testing_utils import require_peft, require_torch_multi_accelerator - -from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed -from trl.core import respond_to_batch - -from .testing_constants import CI_HUB_ENDPOINT, CI_HUB_USER - - -EXPECTED_STATS = [ - "objective/kl", - "objective/kl_dist", - "objective/logprobs", - "objective/ref_logprobs", - "objective/kl_coef", - "objective/entropy", - "ppo/mean_non_score_reward", - "ppo/loss/policy", - "ppo/loss/value", - "ppo/loss/total", - "ppo/policy/entropy", - "ppo/policy/approxkl", - "ppo/policy/policykl", - "ppo/policy/clipfrac", - "ppo/policy/advantages", - "ppo/policy/advantages_mean", - "ppo/policy/ratio", - "ppo/returns/mean", - "ppo/returns/var", - "ppo/val/vpred", - "ppo/val/error", - "ppo/val/clipfrac", - "ppo/val/mean", - "ppo/val/var", - "ppo/val/var_explained", - "time/ppo/forward_pass", - "time/ppo/compute_rewards", - "time/ppo/optimize_step", - "time/ppo/calc_stats", - "time/ppo/total", - "ppo/learning_rate", -] - - -class DummyDataset(torch.utils.data.Dataset): - def __init__(self, query_data, response_data): - self.query_data = query_data - self.response_data = response_data - - def __len__(self): - return len(self.query_data) - - def __getitem__(self, idx): - return self.query_data[idx], self.response_data[idx] - - -def apply_mask(values, mask): - unmasked_values = [] - for v, m in zip(values, mask): - if m == 1: - unmasked_values.append(v) - return torch.Tensor(unmasked_values) - - -def abs_diff_masked_tensors(tensor_1, tensor_2, mask_1, mask_2): - diffs = [] - for l1, l2, m1, m2 in zip(tensor_1, tensor_2, mask_1, mask_2): - diff = apply_mask(l1, m1) - apply_mask(l2, m2) - diffs.append(diff.sum()) - return abs(sum(diffs)) - - -class PPOTrainerTester(unittest.TestCase): - """ - A wrapper class for testing PPOTrainer - """ - - @classmethod - def setUpClass(cls): - cls._api = HfApi(endpoint=CI_HUB_ENDPOINT) - - def setUp(self): - set_seed(42) - - # model_id - self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" - - # get models and tokenizer - self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - self.gpt2_ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - self.gpt2_tokenizer = AutoTokenizer.from_pretrained(self.model_id) - - self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token - - # get bloom as right padding examples: - model_id = "trl-internal-testing/tiny-BloomForCausalLM-correct-vocab" - self.bloom_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) - self.bloom_tokenizer = AutoTokenizer.from_pretrained(model_id) - - model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" - self.t5_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(model_id) - self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) - - # initialize trainer - self.ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None) - - @classmethod - def tearDownClass(cls): - for model in [f"{CI_HUB_USER}/test-ppo-trainer"]: - try: - cls._api.delete_repo(repo_id=model) - except HTTPError: - pass - - def tearDown(self): - # free memory - gc.collect() - - def _init_dummy_dataset(self): - # encode a query - query_txt = "This morning I went to the " - query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") - assert query_tensor.shape == (1, 7) - # get model response - response_tensor = respond_to_batch(self.gpt2_model, query_tensor) - assert response_tensor.shape == (1, 20) - - # create a dummy dataset - min_length = min(len(query_tensor[0]), len(response_tensor[0])) - dummy_dataset = DummyDataset( - [query_tensor[:, :min_length].squeeze(0) for _ in range(2)], - [response_tensor[:, :min_length].squeeze(0) for _ in range(2)], - ) - - return dummy_dataset - - def test_drop_last_dataloader(self): - self.ppo_config = PPOConfig(batch_size=3, mini_batch_size=1, log_with=None) - - dummy_dataset = self._init_dummy_dataset() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - dummy_dataloader = ppo_trainer.dataloader - - assert len(dummy_dataloader) == 0 - - def test_ppo_step(self): - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - for param in ppo_trainer.model.parameters(): - assert param.grad is not None - - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - def test_ppo_step_with_masks(self): - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - - response_mask = [torch.ones_like(r) for r in response_tensor] - - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward, response_mask) - break - - for param in ppo_trainer.model.parameters(): - assert param.grad is not None - - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - def test_ppo_step_with_no_ref_sgd(self): - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - optimizer=optimizer, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - - assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD) - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - # Finally check stats - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - def test_ppo_step_with_no_ref_sgd_lr_scheduler(self): - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - optimizer = torch.optim.SGD(self.gpt2_model.parameters(), lr=0.01) - lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - optimizer=optimizer, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - lr_scheduler=lr_scheduler, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - - assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD) - assert isinstance(ppo_trainer.lr_scheduler.scheduler, torch.optim.lr_scheduler.ExponentialLR) - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - # Finally check stats - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - # assert that the LR has increased for exponential decay - assert train_stats["ppo/learning_rate"] > self.ppo_config.learning_rate - - def test_ppo_step_with_no_ref(self): - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - # initialize a new gpt2 model: - model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - for name, param in ppo_trainer.ref_model.named_parameters(): - if "v_head" not in name: - name = name.replace("pretrained_model.", "") - - assert torch.allclose( - param.cpu(), model.state_dict()[name].cpu() - ), f"Parameter {name} has changed from the original model" - - # Finally check stats - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - def test_ppo_step_with_no_ref_custom_layers(self): - """ - Test PPO step with no reference model and custom layers - For shared layers configuration, all the layers after the `num_shared_layers` are considered as custom layers - therefore the gradients should be computed for these layers only. - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) - num_shared_layers = 1 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - num_shared_layers=num_shared_layers, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - pattern = r".*transformer\.h\.(\d+)\..*" - final_layers = ["ln_f", "v_head", "lm_head"] - - for name, param in ppo_trainer.model.named_parameters(): - if re.match(pattern, name): - layer_number = int(re.match(pattern, name).groups(0)[0]) - if layer_number < num_shared_layers: - assert param.grad is None, f"Parameter {name} has a gradient" - else: - assert param.grad is not None, f"Parameter {name} has no gradient" - elif any(layer in name for layer in final_layers): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - for stat in EXPECTED_STATS: - assert stat in train_stats.keys() - - def test_ppo_step_with_ref_and_custom_layers_warning(self): - """ - Test PPO step with a reference model and custom layers - The trainer should raise a warning if the argument `num_shared_layers` is set - together with a reference model. - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - num_shared_layers = 6 - - with self.assertWarns(UserWarning): - _ = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - num_shared_layers=num_shared_layers, - ) - - def test_ppo_step_rewards_shape(self): - """ - Test if the rewards shape is correct by asserting that if a wrong reward shape is passed, we get - a value error. - """ - - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor([[1.0]]), torch.tensor([[0.0]])] - # train model - this should raise an error - with pytest.raises(ValueError): - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - - reward = [torch.tensor([1.0]), torch.tensor([0.0])] - # train model - this should work - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - # check if the gradients are computed for the model - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - def test_ppo_step_input_shape(self): - """ - Test if the shape of the expected inputs are correct - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor([1.0]), torch.tensor([0.0])] - # train model - this should raise an error - bs = ppo_trainer.config.batch_size - - queries, responses, _, _ = ppo_trainer._step_safety_checker( - bs, list(query_tensor), list(response_tensor), reward - ) - - assert isinstance(queries, list), f"queries should be a list, got {type(queries)}" - assert isinstance(responses, list), f"responses should be a list, got {type(responses)}" - - # check the shapes - for i in range(bs): - assert queries[i].shape == torch.Size([7]) - assert responses[i].size() == torch.Size([7]) - break - - def test_ppo_step_no_dataset(self): - """ - Test if the training loop works fine without passing a dataset - """ - query_txt = "This morning I went to the " - query_tensor = self.gpt2_tokenizer.encode(query_txt, return_tensors="pt") - self.ppo_config.batch_size = 1 - - response_tensor = respond_to_batch(self.gpt2_model, query_tensor) - - # Check that this warns the user about batch size - with self.assertWarns(UserWarning): - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - # train model with ppo - reward = [torch.tensor([1.0])] - # train model - this should work fine - train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) - - # check gradients - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - - # ref model should not be trained - for name, param in ppo_trainer.ref_model.named_parameters(): - assert param.grad is None, f"Parameter {name} has a gradient" - - # check train stats - for stat in EXPECTED_STATS: - assert stat in train_stats, f"Train stats should contain {stat}" - - def test_loss_trainer(self): - """ - Test if the loss trainer works fine - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - self.gpt2_model.eval() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] - dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] - dummy_scores = torch.Tensor([1, 2]) - - ppo_trainer.config.mini_batch_size = 1 - ppo_trainer.config.batch_size = 1 - model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) - all_logprobs, _, values, mask = ppo_trainer.batched_forward_pass( - self.gpt2_model, dummy_queries, dummy_responses, model_inputs - ) - - # dummy values - ref_logprobs = all_logprobs + 1 - logits = torch.exp(all_logprobs) - vpreds = values + 0.1 - - score, non_score, kls = ppo_trainer.compute_rewards(dummy_scores, all_logprobs, ref_logprobs, mask) - values, advantages, returns = ppo_trainer.compute_advantages(values, score, mask) - - # just make sure a dummy loss is computed - idx = 0 - pg_loss, v_loss, _ = ppo_trainer.loss( - all_logprobs[idx].unsqueeze(0), - values[idx].unsqueeze(0), - logits[idx].unsqueeze(0), - vpreds[idx].unsqueeze(0), - ref_logprobs[idx].unsqueeze(0), - mask[idx].unsqueeze(0), - advantages[idx].unsqueeze(0), - returns[idx].unsqueeze(0), - ) - - assert abs(pg_loss.item() - 1.8226) < 0.0001 - assert abs(v_loss.item() - 0.1260) < 0.0001 - - # check if we get same results with masked parts removed - pg_loss_unmasked, v_loss_unmasked, _ = ppo_trainer.loss( - apply_mask(all_logprobs[idx], mask[idx]).unsqueeze(0), - apply_mask(values[idx], mask[idx]).unsqueeze(0), - apply_mask(logits[idx], mask[idx]).unsqueeze(0), - apply_mask(vpreds[idx], mask[idx]).unsqueeze(0), - apply_mask(ref_logprobs[idx], mask[idx]).unsqueeze(0), - apply_mask(mask[idx], mask[idx]).unsqueeze(0), - apply_mask(advantages[idx], mask[idx]).unsqueeze(0), - apply_mask(returns[idx], mask[idx]).unsqueeze(0), - ) - assert abs(pg_loss_unmasked.item() - 1.8226) < 0.0001 - assert abs(v_loss_unmasked.item() - 0.1260) < 0.0001 - - @parameterized.expand( - [ - ["gpt2"], - ["bloom"], - ["t5"], - ] +import platform +import subprocess + + +def test(): + command = """\ +python examples/scripts/ppo/ppo.py \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10 \ + --model_name_or_path EleutherAI/pythia-14m \ + --missing_eos_penalty 1.0 \ + --save_strategy no \ + --stop_token eos +""" + if platform.system() == "Windows": + # windows CI does not work with subprocesses for some reason + # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 + return + subprocess.run( + command, + shell=True, + check=True, ) - def test_batched_forward_pass(self, name): - """ - Test if the loss trainer works fine - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - dummy_queries = [torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 3, 4, 5, 6, 7])] - dummy_responses = [torch.tensor([5, 6, 7, 8, 9]), torch.tensor([8, 9, 10, 11, 12, 13])] - - if name == "gpt2": - model = self.gpt2_model - tokenizer = self.gpt2_tokenizer - elif name == "bloom": - model = self.bloom_model - tokenizer = self.bloom_tokenizer - elif name == "t5": - model = self.t5_model - tokenizer = self.t5_tokenizer - - model.eval() - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=None, - tokenizer=tokenizer, - dataset=dummy_dataset, - ) - - # we test all combinations of fwd_bs and bs: - # if fwd_bs=bs=1: no padding is applied and only one forward pass - # if fwd_bs=1/bs=2: padding is applied and results computed in two fwd passes - # if fwd_bs=bs=2: padding is applied and results computed in one fwd pass - - ppo_trainer.config.mini_batch_size = 1 - ppo_trainer.config.batch_size = 1 - - model_inputs = ppo_trainer.prepare_model_inputs([dummy_queries[0]], [dummy_responses[0]]) - logprobs_0, logits_0, values_0, mask_0 = ppo_trainer.batched_forward_pass( - model, [dummy_queries[0]], [dummy_responses[0]], model_inputs - ) - - ppo_trainer.config.batch_size = 2 - model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) - logprobs_1, logits_1, values_1, mask_1 = ppo_trainer.batched_forward_pass( - model, dummy_queries, dummy_responses, model_inputs - ) - - ppo_trainer.config.mini_batch_size = 2 - model_inputs = ppo_trainer.prepare_model_inputs(dummy_queries, dummy_responses) - logprobs_2, logits_2, values_2, mask_2 = ppo_trainer.batched_forward_pass( - model, dummy_queries, dummy_responses, model_inputs - ) - assert abs_diff_masked_tensors(logprobs_1, logprobs_2, mask_1, mask_2) <= 0.0001 - assert abs_diff_masked_tensors(values_1, values_2, mask_1, mask_2) <= 0.0001 - assert abs_diff_masked_tensors(logprobs_0, logprobs_2[:1], mask_0, mask_2[:1]) <= 0.0001 - assert abs_diff_masked_tensors(values_0, values_2[:1], mask_0, mask_2[:1]) <= 0.0001 - - def test_ppo_trainer_max_grad_norm(self): - """ - Test if the `max_grad_norm` feature works as expected - """ - # initialize dataset - dummy_dataset = self._init_dummy_dataset() - - self.ppo_config.max_grad_norm = 0.00001 - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - # check gradients - for name, param in ppo_trainer.model.named_parameters(): - assert param.grad is not None, f"Parameter {name} has no gradient" - assert torch.all( - param.grad.abs() <= self.ppo_config.max_grad_norm - ), f"Parameter {name} has a gradient larger than max_grad_norm" - - def test_ppo_trainer_kl_penalty(self): - dummy_dataset = self._init_dummy_dataset() - - log_probs = torch.Tensor([[0.5, 0.2, 0.1], [0.6, 0.2, 0.1]]) - ref_log_probs = torch.Tensor([[0.4, 0.3, 0.0], [0.7, 0.1, 0.3]]) - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - expected_output = torch.Tensor([[0.1000, -0.1000, 0.1000], [-0.1000, 0.1000, -0.2000]]) - assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output) - - self.ppo_config.kl_penalty = "abs" - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - expected_output = torch.Tensor([[0.1000, 0.1000, 0.1000], [0.1000, 0.1000, 0.2000]]) - assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output) - - self.ppo_config.kl_penalty = "mse" - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - expected_output = torch.Tensor([[0.0050, 0.0050, 0.0050], [0.0050, 0.0050, 0.0200]]) - assert torch.allclose(ppo_trainer._kl_penalty(log_probs, ref_log_probs), expected_output) - - def test_ppo_trainer_full_kl_penalty(self): - # a few more extensive tests for the full kl option as it is more involved - dummy_dataset = self._init_dummy_dataset() - - self.ppo_config.kl_penalty = "full" - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - # Test on tensors for size B,S,T = (1,2,3) - # test for when the two dists are the same - log_probs = torch.Tensor( - [ - [ - [0.1, 0.2, 0.7], - [0.3, 0.4, 0.3], - ] - ] - ).exp() - - ref_log_probs = torch.Tensor( - [ - [ - [0.1, 0.2, 0.7], - [0.3, 0.4, 0.3], - ] - ] - ).exp() - - expected_output = torch.Tensor( - [[0.0, 0.0]], - ) - output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) - assert output.shape == (1, 2) - assert torch.allclose(output, expected_output) - - # test for when the two dists are almost not overlapping - log_probs = torch.Tensor( - [ - [ - [0.98, 0.01, 0.01], - [0.01, 0.98, 0.01], - ] - ] - ).log() - - ref_log_probs = torch.Tensor( - [ - [ - [0.01, 0.01, 0.98], - [0.01, 0.01, 0.98], - ] - ] - ).log() - - expected_output = torch.Tensor( - [[4.4474, 4.4474]], - ) - output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) - assert output.shape == (1, 2) - assert torch.allclose(output, expected_output) - - # test for when the two dists are almost not overlapping - log_probs = torch.Tensor( - [ - [ - [0.49, 0.02, 0.49], - [0.49, 0.02, 0.49], - ] - ] - ).log() - - ref_log_probs = torch.Tensor( - [ - [ - [0.01, 0.98, 0.01], - [0.49, 0.02, 0.49], - ] - ] - ).log() - - expected_output = torch.Tensor( - [[3.7361, 0.0]], - ) - output = ppo_trainer._kl_penalty(log_probs, ref_log_probs) - assert output.shape == (1, 2) - assert torch.allclose(output, expected_output, atol=0.0001) - - @require_peft - def test_peft_model_ppo_trainer(self): - from peft import LoraConfig, get_peft_model - from transformers import AutoModelForCausalLM - - lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", - ) - gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) - - # this line is very important - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - peft_model = get_peft_model(gpt2_model, lora_config) - model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) - - dummy_dataset = self._init_dummy_dataset() - self.ppo_config.batch_size = 2 - self.ppo_config.mini_batch_size = 1 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - assert ppo_trainer.ref_model is None - - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model by running a step twice - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - - ppo_trainer.model.train() - ppo_trainer.model.gradient_checkpointing_enable() - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - # check gradients - for name, param in model.named_parameters(): - if "lora" in name or "v_head" in name: - assert param.grad is not None, f"Parameter {name} has a no gradient" - else: - assert param.grad is None, f"Parameter {name} has a gradient" - - @require_peft - def test_peft_model_ppo_adapter_rm_trainer(self): - from peft import LoraConfig, get_peft_model - from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification - - dummy_inputs = torch.LongTensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) - rm_lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="SEQ_CLS", - ) - - reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id) - reward_model = get_peft_model(reward_model, rm_lora_config) - dummy_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, reward_model.parameters()), lr=1e-3) - - previous_rm_logits = reward_model(dummy_inputs).logits - loss = previous_rm_logits.mean() - loss.backward() - - dummy_optim.step() - reward_model.eval() - - original_rm_logits = reward_model(dummy_inputs).logits - - with tempfile.TemporaryDirectory() as tmpdirname: - reward_model.save_pretrained(tmpdirname) - - lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", - ) - gpt2_model = AutoModelForCausalLM.from_pretrained(self.model_id) - - # this line is very important - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - peft_model = get_peft_model(gpt2_model, lora_config) - model = AutoModelForCausalLMWithValueHead.from_pretrained( - peft_model, - reward_adapter=tmpdirname, - ) - - dummy_dataset = self._init_dummy_dataset() - self.ppo_config.batch_size = 2 - self.ppo_config.mini_batch_size = 1 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - assert ppo_trainer.ref_model is None - - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model by running a step twice - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - - ppo_trainer.model.train() - ppo_trainer.model.gradient_checkpointing_enable() - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - dummy_inputs = dummy_inputs.to(ppo_trainer.accelerator.device) - new_logits = ppo_trainer.model.compute_reward_score(dummy_inputs) - assert not torch.allclose(previous_rm_logits.to(ppo_trainer.accelerator.device), new_logits[:, -1, :]) - assert torch.allclose(original_rm_logits.to(ppo_trainer.accelerator.device), new_logits[:, -1, :]) - - # check gradients - for name, param in model.named_parameters(): - if ("lora" in name or "v_head" in name) and ("reward" not in name): - assert param.grad is not None, f"Parameter {name} has a no gradient" - else: - assert param.grad is None, f"Parameter {name} has a gradient" - - @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") - def test_push_to_hub(self): - REPO_NAME = "test-ppo-trainer" - repo_id = f"{CI_HUB_USER}/{REPO_NAME}" - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=self.gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=self._init_dummy_dataset(), - ) - with tempfile.TemporaryDirectory(): - url = ppo_trainer.push_to_hub(repo_id=repo_id, token=self._token, api_endpoint=CI_HUB_ENDPOINT) - # Extract repo_name from the url - re_search = re.search(CI_HUB_ENDPOINT + r"/([^/]+/[^/]+)/", url) - assert re_search is not None - hub_repo_id = re_search.groups()[0] - # Check we created a Hub repo - assert hub_repo_id == repo_id - # Ensure all files are present - files = sorted(self._api.list_repo_files(hub_repo_id)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "config.json", - "merges.txt", - "pytorch_model.bin", - "special_tokens_map.json", - "tokenizer_config.json", - "vocab.json", - ], - ) - ) - - @require_peft - @require_torch_multi_accelerator - def test_peft_model_ppo_trainer_multi_gpu(self): - from peft import LoraConfig, get_peft_model - from transformers import AutoModelForCausalLM - - lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", - ) - gpt2_model = AutoModelForCausalLM.from_pretrained( - "gpt2", device_map="balanced", max_memory={0: "500MB", 1: "500MB"} - ) - - assert set(gpt2_model.hf_device_map.values()) == {0, 1} - - # this line is very important - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - gpt2_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - peft_model = get_peft_model(gpt2_model, lora_config) - model = AutoModelForCausalLMWithValueHead.from_pretrained(peft_model) - - assert model.is_sequential_parallel - - dummy_dataset = self._init_dummy_dataset() - self.ppo_config.batch_size = 2 - self.ppo_config.mini_batch_size = 1 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - assert ppo_trainer.ref_model is None - - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model by running a step twice - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - - ppo_trainer.model.train() - ppo_trainer.model.gradient_checkpointing_enable() - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - # check gradients - for name, param in model.named_parameters(): - if "lora" in name or "v_head" in name: - assert param.grad is not None, f"Parameter {name} has a no gradient" - else: - assert param.grad is None, f"Parameter {name} has a gradient" - - def test_generation(self): - dummy_dataset = self._init_dummy_dataset() - - model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") - tokenizer = AutoTokenizer.from_pretrained("gpt2") - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=None, - tokenizer=tokenizer, - dataset=dummy_dataset, - ) - - input_texts = ["this is a test", "this is another, longer test"] - - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id} - - tokenizer.pad_token = tokenizer.eos_token - - model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - model_inputs = [input_ids.to(ppo_trainer.accelerator.device) for input_ids in model_inputs] - - generations_batched = ppo_trainer.generate(model_inputs, batch_size=2, **generation_kwargs) - generations_batched = tokenizer.batch_decode(generations_batched) - - generations_single = [ppo_trainer.generate(inputs, **generation_kwargs).squeeze() for inputs in model_inputs] - generations_single = tokenizer.batch_decode(generations_single) - - assert generations_single == generations_batched - - def test_generation_with_ref_model(self): - dummy_dataset = self._init_dummy_dataset() - model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") - tokenizer = AutoTokenizer.from_pretrained("gpt2") - - # Negate the weights in the last layer of the ref model so it never - # outputs the same things as the primary model - ref_model = copy.deepcopy(model) - lm_head_weight = ref_model.pretrained_model.lm_head.weight - lm_head_weight.data = -lm_head_weight.data - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=model, - ref_model=ref_model, - tokenizer=tokenizer, - dataset=dummy_dataset, - ) - - input_texts = ["this is a test", "this is another, longer test"] - - generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id} - - tokenizer.pad_token = tokenizer.eos_token - - model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] - model_inputs = [input_ids.to(ppo_trainer.accelerator.device) for input_ids in model_inputs] - - generations_batched, ref_generations_batched = ppo_trainer.generate( - model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs - ) - generations_batched = tokenizer.batch_decode(generations_batched) - ref_generations_batched = tokenizer.batch_decode(ref_generations_batched) - - generations_single = [] - ref_generations_single = [] - for inputs in model_inputs: - generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs) - generations_single.append(generation.squeeze()) - ref_generations_single.append(ref_generation.squeeze()) - - generations_single = tokenizer.batch_decode(generations_single) - ref_generations_single = tokenizer.batch_decode(ref_generations_single) - - assert generations_single == generations_batched - assert ref_generations_single == ref_generations_batched - - assert generations_batched != ref_generations_batched - assert generations_single != ref_generations_single - - def test_grad_accumulation(self): - dummy_dataset = self._init_dummy_dataset() - - torch.manual_seed(0) - gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id, summary_dropout_prob=0.0) - gpt2_model_clone = copy.deepcopy(gpt2_model) - - self.ppo_config.mini_batch_size = 2 - self.ppo_config.ppo_epochs = 1 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=gpt2_model, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(1.0)] - # train model by running a step twice - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - model_grad = gpt2_model.v_head.summary.weight - - self.ppo_config.mini_batch_size = 1 - self.ppo_config.gradient_accumulation_steps = 2 - - ppo_trainer = PPOTrainer( - config=self.ppo_config, - model=gpt2_model_clone, - ref_model=None, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - dummy_dataloader = ppo_trainer.dataloader - - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(1.0)] - # train model by running a step twice - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - model_grad_acc = gpt2_model_clone.v_head.summary.weight - assert torch.allclose(model_grad_acc, model_grad, rtol=0.001, atol=0.001) - - @unittest.skip("Fix by either patching `whomai()` to work in the staging endpoint or use a dummy prod user.") - def test_push_to_hub_if_best_reward(self): - REPO_NAME = "test-ppo-trainer" - repo_id = f"{CI_HUB_USER}/{REPO_NAME}" - - dummy_dataset = self._init_dummy_dataset() - - push_to_hub_if_best_kwargs = {"repo_id": repo_id} - - ppo_config = PPOConfig( - batch_size=2, - mini_batch_size=1, - log_with=None, - push_to_hub_if_best_kwargs=push_to_hub_if_best_kwargs, - compare_steps=1, - ) - - ppo_trainer = PPOTrainer( - config=ppo_config, - model=self.gpt2_model, - ref_model=self.gpt2_ref_model, - tokenizer=self.gpt2_tokenizer, - dataset=dummy_dataset, - ) - - ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False) - dummy_dataloader = ppo_trainer.dataloader - # train model with ppo - for query_tensor, response_tensor in dummy_dataloader: - # define a reward for response - # (this could be any reward such as human feedback or output from another model) - reward = [torch.tensor(1.0), torch.tensor(0.0)] - # train model - _ = ppo_trainer.step(list(query_tensor), list(response_tensor), reward) - break - - def test_batch_size_check(self): - with pytest.raises(ValueError): - PPOConfig(batch_size=2, mini_batch_size=2, gradient_accumulation_steps=2) +def test_num_train_epochs(): + command = """\ +python examples/scripts/ppo/ppo.py \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --num_train_epochs 0.003 \ + --model_name_or_path EleutherAI/pythia-14m \ + --missing_eos_penalty 1.0 \ + --save_strategy no \ + --stop_token eos +""" + if platform.system() == "Windows": + # windows CI does not work with subprocesses for some reason + # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 + return + subprocess.run( + command, + shell=True, + check=True, + ) diff --git a/tests/test_ppov2_trainer.py b/tests/test_ppov2_trainer.py deleted file mode 100644 index 21dffd9bee..0000000000 --- a/tests/test_ppov2_trainer.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import platform -import subprocess - - -def test(): - command = """\ -python examples/scripts/ppo/ppo.py \ - --learning_rate 3e-6 \ - --output_dir models/minimal/ppo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --total_episodes 10 \ - --model_name_or_path EleutherAI/pythia-14m \ - --missing_eos_penalty 1.0 \ - --save_strategy no \ - --stop_token eos -""" - if platform.system() == "Windows": - # windows CI does not work with subprocesses for some reason - # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 - return - subprocess.run( - command, - shell=True, - check=True, - ) - - -def test_num_train_epochs(): - command = """\ -python examples/scripts/ppo/ppo.py \ - --learning_rate 3e-6 \ - --output_dir models/minimal/ppo \ - --per_device_train_batch_size 4 \ - --gradient_accumulation_steps 1 \ - --num_train_epochs 0.003 \ - --model_name_or_path EleutherAI/pythia-14m \ - --missing_eos_penalty 1.0 \ - --save_strategy no \ - --stop_token eos -""" - if platform.system() == "Windows": - # windows CI does not work with subprocesses for some reason - # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 - return - subprocess.run( - command, - shell=True, - check=True, - ) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 64bfea6768..25c2916025 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,28 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import os -import sys -import warnings -from dataclasses import dataclass, field -from typing import Literal, Optional - -import numpy as np -import tyro -from transformers import is_wandb_available -from typing_extensions import Annotated - -from trl.trainer.utils import exact_div - -from ..core import flatten_dict +import os +from dataclasses import dataclass -JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] +from ..trainer.utils import OnPolicyConfig @dataclass -class PPOConfig: +class PPOConfig(OnPolicyConfig): r""" Configuration class for the [`PPOTrainer`]. @@ -41,199 +28,35 @@ class PPOConfig: command line. Parameters: - exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`): + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): Name of this experiment. - seed (`int`, *optional*, defaults to `0`): - Random seed. - log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`): - Log with either `"wandb"` or `"tensorboard"`. Check - [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details. - task_name (`Optional[str]`, *optional*, defaults to `None`): - Name of task to use - used only for tracking purposes. - model_name (`Optional[str]`, *optional*, defaults to `"gpt2"`): - Name of model to use - used only for tracking purposes. - query_dataset (`Optional[str]`, *optional*, defaults to `"stanfordnlp/imdb"`): - Name of dataset to query - used only for tracking purposes. - reward_model (`Optional[str]`, *optional*, defaults to `"sentiment-analysis:lvwerra/distilbert-imdb"`): - Reward model to use - used only for tracking purposes. - remove_unused_columns (`bool`, *optional*, defaults to `True`): - Remove unused columns from the dataset. - tracker_kwargs (`JSONDict`, *optional*, defaults to `{}`): - Keyword arguments for the tracker (e.g. `python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'`. - accelerator_kwargs (`JSONDict`, *optional*, defaults to `{}`): - Keyword arguments for the accelerator. - project_kwargs (`JSONDict`, *optional*, defaults to `{}`): - Keyword arguments for the accelerator project config (e.g. `logging_dir`). - tracker_project_name (`str`, *optional*, defaults to `"trl"`): - Name of project to use for tracking. - push_to_hub_if_best_kwargs (`JSONDict`, *optional*, defaults to `{}`): - Keyword arguments for pushing model to the hub during training (e.g. repo_id). - steps (`int`, *optional*, defaults to `20000`): - Number of training steps. - learning_rate (`float`, *optional*, defaults to `1.41e-5`): - Learning rate for the optimizer. - adap_kl_ctrl (`bool`, *optional*, defaults to `True`): - Use adaptive KL control, otherwise linear. - init_kl_coef (`Optional[float]`, *optional*, defaults to `0.2`): - Initial KL penalty coefficient (used for adaptive and linear control). - kl_penalty (`Literal["kl", "abs", "mse", "full"]`, *optional*, defaults to `"kl"`): - kl penalty options. Possible values are: - - - `"kl"`: model_logp - ref_logp - - `"abs"`: abs(kl) - - `"mse"`: mean squared error mse(kl) - - `"full"`: the actual kl for all tokens in the distribution. - - target (`float`, *optional*, defaults to `6.0`): - Target KL value for adaptive KL control. - horizon (`float`, *optional*, defaults to `10000.0`): - Horizon for adaptive KL control. - gamma (`float`, *optional*, defaults to `1.0`): - Gamma parameter for advantage calculation. - lam (`float`, *optional*, defaults to `0.95`): - Lambda parameter for advantage calculation. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. cliprange (`float`, *optional*, defaults to `0.2`): - Range for clipping in PPO policy gradient loss. - cliprange_value (`float`, *optional*, defaults to `0.2`): - Range for clipping values in loss calculation. + Clip range. vf_coef (`float`, *optional*, defaults to `0.1`): - Scaling factor for value loss. - batch_size (`int`, *optional*, defaults to `128`): - Number of samples per optimisation step. - forward_batch_size (`Optional[int]`, *optional*, defaults to `None`): - DEPRECATED: use `mini_batch_size` instead, which does the same thing. - mini_batch_size (`int`, *optional*, defaults to `128`): - Number of samples optimized in each mini batch. - gradient_accumulation_steps (`int`, *optional*, defaults to `1`): - Number of gradient accumulation steps. - world_size (`Optional[int]`, *optional*, defaults to `None`): - Number of processes to use for distributed training. - ppo_epochs (`int`, *optional*, defaults to `4`): - Number of optimisation epochs per batch of samples. - optimize_device_cache (`bool`, *optional*, defaults to `False`): - Optimize device cache for slightly more memory-efficient training. - early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the PPO optimization loop early is the KL too high. - target_kl (`float`, *optional*, defaults to `1.0`): - Stop early if we exceed this value by over 50%. - compare_steps (`int`, *optional*, defaults to `1`): - Compare the current step with the previous `compare_steps` steps. - ratio_threshold (`float`, *optional*, defaults to `10.0`): - Skip mini-batches with high PPO ratios that can cause loss spikes. - use_score_scaling (`bool`, *optional*, defaults to `False`): - Use score scaling. - use_score_norm (`bool`, *optional*, defaults to `False`): - Use score normalization. Only applicable if `use_score_scaling` is True. - score_clip (`Optional[float]`, *optional*, defaults to `None`): - Score clipping. - whiten_rewards (`bool`, *optional*, defaults to `False`): - Whiten the rewards before computing advantages. - is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`): - When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, - you need to specify if the model returned by the callable is an encoder-decoder model. - is_peft_model (`Optional[bool]`, *optional*, defaults to `None`): - Whether the model is a PEFT model. - backward_batch_size (`Optional[int]`, *optional*, defaults to `None`): - Number of samples optimized in an `optimizer.step()` call. - global_backward_batch_size (`Optional[int]`, *optional*, defaults to `None`): - Effective `backward_batch_size` across all processes. - global_batch_size (`Optional[int]`, *optional*, defaults to `None`): - Effective `batch_size` across all processes. - dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): - Number of processes to use for processing the dataset. + Value function coefficient. + cliprange_value (`float`, *optional*, defaults to `0.2`): + Clip range for the value function. + gamma (`float`, *optional*, defaults to `1.0`): + Discount factor. + lam (`float`, *optional*, defaults to `0.95`): + Lambda value for GAE. """ - exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] - seed: int = 0 - log_with: Optional[Literal["wandb", "tensorboard"]] = None - task_name: Optional[str] = None - model_name: str = "gpt2" - query_dataset: str = "stanfordnlp/imdb" - reward_model: str = "sentiment-analysis:lvwerra/distilbert-imdb" - remove_unused_columns: bool = True - tracker_kwargs: JSONDict = field(default_factory=dict) - accelerator_kwargs: JSONDict = field(default_factory=dict) - project_kwargs: JSONDict = field(default_factory=dict) - tracker_project_name: str = "trl" - push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) - steps: int = 20000 - learning_rate: float = 1.41e-5 - adap_kl_ctrl: bool = True - init_kl_coef: float = 0.2 - kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl" - target: float = 6.0 - horizon: float = 10000.0 - gamma: float = 1.0 - lam: float = 0.95 + exp_name: str = os.path.basename(__file__)[: -len(".py")] + reward_model_path: str = "EleutherAI/pythia-160m" + num_ppo_epochs: int = 4 + whiten_rewards: bool = False + kl_coef: float = 0.05 cliprange: float = 0.2 - cliprange_value: float = 0.2 vf_coef: float = 0.1 - batch_size: int = 128 - forward_batch_size: Optional[int] = None - mini_batch_size: int = 128 - gradient_accumulation_steps: int = 1 - world_size: tyro.conf.Suppress[int] = None - ppo_epochs: int = 4 - max_grad_norm: Optional[float] = None - optimize_cuda_cache: Optional[bool] = None - optimize_device_cache: bool = False - early_stopping: bool = False - target_kl: float = 1.0 - compare_steps: int = 1 - ratio_threshold: float = 10.0 - use_score_scaling: bool = False - use_score_norm: bool = False - score_clip: Optional[float] = None - whiten_rewards: bool = False - gradient_checkpointing: bool = False - is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None - is_peft_model: Optional[tyro.conf.Suppress[bool]] = None - backward_batch_size: tyro.conf.Suppress[int] = None - global_backward_batch_size: Optional[tyro.conf.Suppress[int]] = None - global_batch_size: tyro.conf.Suppress[int] = None - dataset_num_proc: Optional[int] = None - - if optimize_cuda_cache is not None: - warnings.warn( - "The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead." - ) - - if optimize_device_cache is True: - raise ValueError("Both `optimize_device_cache` and `optimize_cuda_cache` were provided") - - optimize_device_cache = optimize_cuda_cache - - def __post_init__(self): - warnings.warn( - "`PPOConfig` is deprecated and will be removed in the future. Please use `PPOv2Config` with `PPOv2Trainer` instead.", - FutureWarning, - ) - if self.forward_batch_size is not None: - warnings.warn( - "Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization." - ) - self.mini_batch_size = self.forward_batch_size - - self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps - exact_div( - self.batch_size, - self.backward_batch_size, - "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`", - ) - - # check if wandb is installed - if self.log_with == "wandb": - # raise error if wandb is not installed - if not is_wandb_available(): - raise ImportError( - "Please install wandb to use wandb logging. You can do this by running `pip install wandb`." - ) - - self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size)) - assert self.kl_penalty in ["kl", "abs", "mse", "full"] - - def to_dict(self): - output_dict = {} - for key, value in self.__dict__.items(): - output_dict[key] = value - return flatten_dict(output_dict) + cliprange_value: float = 0.2 + gamma: float = 1.0 + lam: float = 0.95 diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 42f916ddde..e491b0622a 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,1479 +11,703 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect + +import gc import math import os +import textwrap import time -import typing -import warnings -from contextlib import nullcontext -from typing import Callable, List, Optional, Union +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union -import datasets import numpy as np +import pandas as pd import torch +import torch.nn as nn import torch.nn.functional as F from accelerate import Accelerator -from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available +from accelerate.utils import broadcast, gather_object from datasets import Dataset -from huggingface_hub import whoami -from packaging import version -from torch.optim import Adam +from torch.utils.data import DataLoader from transformers import ( - DataCollatorForLanguageModeling, - PreTrainedTokenizer, + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, PreTrainedTokenizerBase, - PreTrainedTokenizerFast, - is_torch_npu_available, - is_torch_xpu_available, -) - -from ..core import ( - WANDB_PADDING, - PPODecorators, - clip_by_value, - convert_to_scalar, - entropy_from_logits, - flatten_dict, - logprobs_from_logits, - masked_mean, - masked_var, - masked_whiten, - set_seed, - stack_dicts, - stats_to_np, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, ) -from ..import_utils import is_torch_greater_2_0 -from ..models import ( - SUPPORTED_ARCHITECTURES, - PreTrainedModelWrapper, - create_reference_model, - unwrap_model_for_generation, +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback + +from ..core import masked_mean, masked_whiten +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + prepare_deepspeed, + print_rich_table, + truncate_response, ) -from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments - - -if is_deepspeed_available(): - import deepspeed - -MODEL_CARD_TEMPLATE = """--- -license: apache-2.0 -library_name: transformers -tags: -- trl -- ppo -- transformers -- reinforcement-learning ---- - -# {model_name} - -This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to - guide the model outputs according to a value, function, or human feedback. The model can be used for text generation. - -## Usage - -To use this model for inference, first install the TRL library: - -```bash -python -m pip install trl -``` - -You can then generate text as follows: - -```python -from transformers import pipeline - -generator = pipeline("text-generation", model="{model_id}") -outputs = generator("Hello, my llama is cute") -``` - -If you want to use the model for training or to obtain the outputs from the value head, load the model as follows: - -```python -from transformers import AutoTokenizer -from trl import AutoModelForCausalLMWithValueHead - -tokenizer = AutoTokenizer.from_pretrained("{model_id}") -model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}") - -inputs = tokenizer("Hello, my llama is cute", return_tensors="pt") -outputs = model(**inputs, labels=inputs["input_ids"]) -``` -""" - - -class PPOTrainer(BaseTrainer): - """ - The PPOTrainer uses Proximal Policy Optimization to optimise language models. - Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: - https://github.com/openai/summarize-from-feedback - - Attributes: - **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more - details. - **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head. - Check the documentation of `PreTrainedModelWrapper` for more details. - **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face - transformer model with a casual language modelling head. Check the documentation of `PreTrainedModelWrapper` - for more details. If no reference model is provided, the trainer will create a reference model with the same - architecture as the model to be optimized with shared layers. - **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the - data. Check the documentation of `transformers.PreTrainedTokenizer` and - `transformers.PreTrainedTokenizerFast` for more details. - **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging - Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be - created outside the trainer users needs to design their own dataloader and make sure the batch - size that is used is the same as the one specified in the configuration object. - **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is - provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration - object. - **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and - passed along the dataloader - **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference - model, if no reference model is passed. If no number is provided, all the layers will be shared. - **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. - """ - - _tag_names = ["trl", "ppo"] - - def __init__( - self, - config: Optional[PPOConfig] = None, - model: Optional[PreTrainedModelWrapper] = None, - ref_model: Optional[PreTrainedModelWrapper] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, - dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - data_collator: Optional[typing.Callable] = None, - num_shared_layers: Optional[int] = None, - lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - training_data_collator: Optional[typing.Callable] = None, - ): - """ - Initialize PPOTrainer. - - Args: - config (`PPOConfig`): - Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. - model (`PreTrainedModelWrapper`): - Hugging Face transformer model with a value head. - ref_model (`PreTrainedModelWrapper`): - Hugging Face transformer model with a casual language modelling head. Used for KL penalty - tokenizer (`transformers.PreTrainedTokenizerBase`): - Hugging Face tokenizer - dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]): - PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset - will be preprocessed by removing the columns that are not used by the model. If none is passed, - a warning will be raised in a multi-GPU setting. - optimizer (`Optional[torch.optim.Optimizer]`): - Optimizer used for training. If `None`, the `Adam` is used as default. - data_collator (Optional[function]): - Data collator function that is going to be used for `prepare_dataloader` method. Note this collator - is different from the one we use for training. Pass a valid `training_data_collator` instead. - num_shared_layers (Optional[int]): - Number of shared layers between the model and the reference model. If `None`, all layers are shared. - used only if `ref_model` is `None`. - lr_scheduler (`Optional[torch.optim.lr_scheduler]`): - Learning rate scheduler used for training. - training_data_collator (Optional[function]): - Custom data collator used for training. - """ - warnings.warn( - "`PPOTrainer` is deprecated and will be removed in trl v0.12. Please use `PPOv2Trainer` instead.", - FutureWarning, - ) - super().__init__(config) - - # initial seed for reproducible experiments - set_seed(config.seed) - - # Step 0: check positional arguments validity - if not isinstance(config, PPOConfig): - raise ValueError(f"config must be a PPOConfig, got {type(config)}") - if not isinstance(tokenizer, (PreTrainedTokenizerBase)): - raise ValueError( - f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" - ) - if not isinstance(model, (SUPPORTED_ARCHITECTURES)): - raise ValueError( - f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" - ) - # Step 1: Initialize Accelerator - self.accelerator = Accelerator( - log_with=config.log_with, - gradient_accumulation_steps=config.gradient_accumulation_steps, - project_config=ProjectConfiguration(**config.project_kwargs), - **config.accelerator_kwargs, - ) - - # Step 1.1 Runtime variables filled by the accelerator - config.world_size = self.accelerator.num_processes - config.global_backward_batch_size = config.backward_batch_size * config.world_size - config.global_batch_size = config.batch_size * config.world_size - - self.model = model - self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) - self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") - self.is_peft_model = getattr(self.model, "is_peft_model", False) - config.is_encoder_decoder = self.is_encoder_decoder - config.is_peft_model = self.is_peft_model - - is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" - self.accelerator.init_trackers( - config.tracker_project_name, - config=dict(trl_ppo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), - init_kwargs=config.tracker_kwargs, - ) - self.is_using_text_environment = getattr(config, "use_text_environment", False) - - if isinstance(ref_model, SUPPORTED_ARCHITECTURES): - self.ref_model = ref_model - if num_shared_layers is not None: - warnings.warn( - "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " - "model and the reference model and no layers are shared.", - UserWarning, - ) - elif ref_model is None and not self.is_peft_model: - self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers) - elif self.is_peft_model: - self.ref_model = None - else: - raise ValueError( - f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " - f"architectures are: {SUPPORTED_ARCHITECTURES} " - ) - self.optional_peft_ctx = ( - self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter - if self.is_peft_model - else nullcontext - ) +from .ppo_config import PPOConfig +from .utils import generate_model_card - if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)): - raise ValueError( - "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast" - ) - self.tokenizer = tokenizer - - if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)): - raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset") - elif dataset is None: - warnings.warn( - "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", - UserWarning, - ) - self.dataset = dataset - self._signature_columns = None - if self.dataset is not None: - self.dataloader = self.prepare_dataloader(self.dataset, data_collator) - elif self.dataset is None and self.accelerator.num_processes > 1: - warnings.warn( - "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" - " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`" - " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please " - " refer to the documentation for more details.", - UserWarning, - ) - self.dataloader = None - else: - self.dataloader = None - # Step 3: Initialize optimizer and data collator - if training_data_collator is None: - self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) - else: - self.data_collator = training_data_collator - if optimizer is None: - self.optimizer = Adam( - filter(lambda p: p.requires_grad, self.model.parameters()), - lr=self.config.learning_rate, - ) - else: - self.optimizer = optimizer - - self.lr_scheduler = lr_scheduler - if self.lr_scheduler is not None: - lr_scheduler_class = ( - torch.optim.lr_scheduler._LRScheduler - if not is_torch_greater_2_0() - else torch.optim.lr_scheduler.LRScheduler - ) +if is_wandb_available(): + import wandb - if not isinstance(self.lr_scheduler, lr_scheduler_class): - raise ValueError( - "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)" - ) - if self.config.adap_kl_ctrl: - self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon) - else: - self.kl_ctl = FixedKLController(self.config.init_kl_coef) +INVALID_LOGPROB = 1.0 - # Safety checkers for DS integration - is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( - self.accelerator.state, "deepspeed_plugin" - ) - if config.gradient_checkpointing: - self.model.gradient_checkpointing_enable() +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, value_model) -> None: + super().__init__() + self.policy = policy + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) - if hasattr(self.model, "enable_input_require_grads"): - self.model.enable_input_require_grads() - else: - # For backward compatibility with older versions of transformers - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - ( - self.model, - self.optimizer, - self.data_collator, - self.dataloader, - self.lr_scheduler, - ) = self.accelerator.prepare( - self.model, - self.optimizer, - self.data_collator, - self.dataloader, - self.lr_scheduler, + def forward(self, **kwargs): + output = self.critic_backbone( + **kwargs, ) - if is_deepspeed_used: - # Quantized models are already set on the correct device - if not self.is_peft_model and not ( - getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) - or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False) - ): - self.ref_model = self._prepare_deepspeed(self.ref_model) - else: - self.ref_model = self.accelerator.prepare(self.ref_model) - - # In a distributed setup, only logging needs to be performed on the main process - # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html - # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 - self.is_distributed = self.accelerator.num_processes > 1 - - # init the current step - self.current_step = 0 - - # init variables for pushing model to hub - if config.push_to_hub_if_best_kwargs: - if "repo_id" not in config.push_to_hub_if_best_kwargs: - raise ValueError("You have to specify repo_id in order to push the model to the hub!") - self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs - self.compare_step = 0 - self.highest_reward = torch.tensor(-float("inf")) - - # post process for PP - if not getattr(self.model, "is_sequential_parallel", False): - self.current_device = self.accelerator.device - else: - if is_torch_xpu_available(): - self.current_device = torch.device("xpu:0") - elif is_torch_npu_available(): - self.current_device = torch.device("npu:0") - else: - self.current_device = torch.device("cuda:0") - - PPODecorators.optimize_device_cache = self.config.optimize_device_cache + logits = self.value_model.score(output.hidden_states[-1]) + return self.policy(**kwargs), logits - self.running = RunningMoments(self.accelerator) - - def _filter_kwargs(self, kwargs, target_func): - """ - filter the keyword arguments that are supported by the target function. - - Args: - kwargs (dict): - Keyword arguments - target_func (function): - Target function - """ - return {k: v for k, v in kwargs.items() if k in inspect.signature(target_func).parameters.keys()} - def prepare_dataloader(self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None): - """ - Prepare the dataloader for training. - - Args: - dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): - PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset - will be preprocessed by removing the columns that are not used by the model. - data_collator (Optional[function]): - Data collator function. - - Returns: - `torch.utils.data.DataLoader`: PyTorch dataloader - """ - if isinstance(dataset, Dataset): - dataset = self._remove_unused_columns(dataset) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=self.config.batch_size, - collate_fn=data_collator, - shuffle=True, - drop_last=True, - ) - return dataloader - - # Adapted from transformers.Trainer._set_signature_columns_if_needed - def _set_signature_columns_if_needed(self): - if self._signature_columns is None: - # Inspect model forward signature to keep only the arguments it accepts. - signature = inspect.signature(self.model.forward) - self._signature_columns = list(signature.parameters.keys()) - # label => sentiment | we need query and response for logging purpose - self._signature_columns += ["label", "query", "response"] - - # Adapted from transformers.Trainer._remove_unused_columns - def _remove_unused_columns(self, dataset: "Dataset"): - if not self.config.remove_unused_columns: - return dataset - self._set_signature_columns_if_needed() - signature_columns = self._signature_columns - - ignored_columns = list(set(dataset.column_names) - set(signature_columns)) - - columns = [k for k in signature_columns if k in dataset.column_names] - - if version.parse(datasets.__version__) < version.parse("1.4.0"): - dataset.set_format( - type=dataset.format["type"], - columns=columns, - format_kwargs=dataset.format["format_kwargs"], - ) - return dataset - else: - return dataset.remove_columns(ignored_columns) +class PPOTrainer(Trainer): + _tag_names = ["trl", "ppo"] - def generate( + def __init__( self, - query_tensor: Union[torch.Tensor, List[torch.Tensor]], - length_sampler: Optional[Callable] = None, - batch_size: int = 4, - return_prompt: bool = True, - generate_ref_response: bool = False, - **generation_kwargs, - ): - """ - Generate response with the model given the query tensor. - call the `generate` method of the model. - - Args: - query_tensor (`torch.LongTensor`): - A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`). - length_sampler (`Callable`, *optional*): - Callable that returns the number of newly generated tokens. - batch_size (`int`, *optional): - Batch size used for generation, defaults to `4`. - return_prompt (`bool`, *optional*): - If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`. - generate_ref_response (`bool`, *optional*): - If set to `True` the reference response is also generated, defaults to `False`. - generation_kwargs (dict[str, Any]): - Keyword arguments for generation. - - Returns: - `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens. - """ - if generate_ref_response: - ref_model = self.model if self.is_peft_model else self.ref_model - if isinstance(query_tensor, List): - response = self._generate_batched( - self.model, - query_tensor, - length_sampler=length_sampler, - batch_size=batch_size, - return_prompt=return_prompt, - **generation_kwargs, + config: PPOConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ], + policy: nn.Module, + ref_policy: nn.Module, + reward_model: nn.Module, + train_dataset: Dataset, + value_model: Optional[nn.Module] = None, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + # less commonly used + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + if ref_policy is policy: + raise ValueError( + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." ) - if generate_ref_response: - ref_response = self._generate_batched( - ref_model, - query_tensor, - length_sampler=length_sampler, - batch_size=batch_size, - return_prompt=return_prompt, - **generation_kwargs, - ) - - else: - if len(query_tensor.shape) == 2: - raise ValueError( - "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)" - ) - - if length_sampler is not None: - generation_kwargs["max_new_tokens"] = length_sampler() - - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs) - - if generate_ref_response: - with unwrap_model_for_generation( - ref_model, self.accelerator, is_peft_model=self.is_peft_model - ) as unwrapped_model: - ref_response = unwrapped_model.generate( - input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs - ) - - if not return_prompt and not self.is_encoder_decoder: - response = response[:, query_tensor.shape[0] :] - if generate_ref_response: - ref_response = ref_response[:, query_tensor.shape[0] :] - - if generate_ref_response: - return response, ref_response - return response - - def _generate_batched( - self, - model: PreTrainedModelWrapper, - query_tensors: List[torch.Tensor], - length_sampler: Optional[Callable] = None, - batch_size: int = 4, - return_prompt: bool = True, - pad_to_multiple_of: Optional[int] = None, - remove_padding: bool = True, - **generation_kwargs, - ): - outputs = [] - - padding_side_default = self.tokenizer.padding_side - if not self.is_encoder_decoder: - self.tokenizer.padding_side = "left" - - # in case we have fewer examples than bs - batch_size = min(len(query_tensors), batch_size) - - for i in range(0, len(query_tensors), batch_size): - if length_sampler is not None: - generation_kwargs["max_new_tokens"] = length_sampler() - - # prevent overflow if query tensors are not even multiple of bs - end_index = min(len(query_tensors), i + batch_size) - - batch = query_tensors[i:end_index] - batch_mask = [torch.ones_like(element) for element in batch] - inputs = {"input_ids": batch, "attention_mask": batch_mask} - - padded_inputs = self.tokenizer.pad( - inputs, - padding=True, - max_length=None, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ).to(self.current_device) - - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs) - - for generation, mask in zip(generations, padded_inputs["attention_mask"]): - if not self.is_encoder_decoder: - output = generation[(1 - mask).sum() :] # remove padding - else: - output = generation - if not return_prompt and not self.is_encoder_decoder: - output = output[(mask).sum() :] # remove prompt + self.args = config + args = config + self.processing_class = processing_class + self.policy = policy - if remove_padding and self.tokenizer.eos_token_id in output: - pad_mask = output == self.tokenizer.eos_token_id - pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item() - output = output[: pad_start + 1] # keep the eos token at the end - - outputs.append(output) - - self.tokenizer.padding_side = padding_side_default - return outputs - - def _step_safety_checker( - self, - batch_size: int, - queries: List[torch.LongTensor], - responses: List[torch.LongTensor], - scores: List[torch.FloatTensor], - masks: Optional[List[torch.LongTensor]] = None, - ): - """ - Check if the input data is valid for training. - - Args: - batch_size (int): - Batch size from the config file. - queries (List[`torch.LongTensor`]): - List of tensors containing the encoded queries of shape (`query_length`) - responses (List[`torch.LongTensor`]): - List of tensors containing the encoded responses of shape (`response_length`) - scores (List[`torch.FloatTensor`]): - List of tensors containing the scores. - masks (List[`torch.LongTensor`], *optional*): - list of optional tensors containing the masks of shape (`response_length`) - - Returns: - `tuple`: The input processed data. - """ - for name, tensor_list in zip(["queries", "responses", "scores"], [queries, responses, scores]): - if not isinstance(tensor_list, list): - raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") - if not isinstance(tensor_list[0], torch.Tensor): - raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") - if batch_size is not None and len(tensor_list) != batch_size: - raise ValueError( - f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}" - ) - - # add queries, scores and responses on the correct device - queries = [tensor.to(self.current_device) for tensor in queries] - responses = [tensor.to(self.current_device) for tensor in responses] - scores = [tensor.to(self.current_device) for tensor in scores] - masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None - - # squeeze scores if needed - for i, score in enumerate(scores): - if score.dim() > 1: - raise ValueError(f"Scores must be 1-dimensional - got {score.dim()} for {score}") - elif score.dim() == 1: - scores[i] = score.squeeze() - - return queries, responses, scores, masks - - @PPODecorators.empty_device_cache() - def step( - self, - queries: List[torch.LongTensor], - responses: List[torch.LongTensor], - scores: List[torch.FloatTensor], - response_masks: Optional[List[torch.LongTensor]] = None, - ): - """ - Run a PPO optimisation step given a list of queries, model responses, and rewards. - - Args: - queries (List[`torch.LongTensor`]): - List of tensors containing the encoded queries of shape (`query_length`) - responses (List[`torch.LongTensor`]): - List of tensors containing the encoded responses of shape (`response_length`) - scores (List[`torch.FloatTensor`]): - List of tensors containing the scores. - response_masks (List[`torch.FloatTensor`], *optional*)): - List of tensors containing masks of the response tokens. - - Returns: - `dict[str, Any]`: A summary of the training statistics - """ - bs = self.config.batch_size - - queries, responses, scores, response_masks = self._step_safety_checker( - bs, queries, responses, scores, response_masks + self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to ) - scores = torch.tensor(scores, device=self.current_device) - if self.config.use_score_scaling: - # Score scaling - scores_mean, scores_std = self.running.update(scores) - tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) - score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps - if self.config.use_score_norm: - scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor - else: - scores /= score_scaling_factor - - if self.config.score_clip is not None: - # Score clipping - scores_dtype = scores.dtype - scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype) - - # if we want to push best model to the hub - if hasattr(self, "highest_reward"): - if self.compare_step % self.config.compare_steps == 0: - curr_mean_reward = scores.mean() - # if the best reward ever seen - if curr_mean_reward > self.highest_reward: - self.highest_reward = curr_mean_reward - # push model to hub - self.push_to_hub(**self.push_to_hub_kwargs) - self.compare_step += 1 - - timing = dict() - t0 = time.time() - - t = time.time() - - model_inputs = self.prepare_model_inputs(queries, responses) - - if self.is_distributed: - pad_first = self.tokenizer.padding_side == "left" - - model_inputs["input_ids"] = self.accelerator.pad_across_processes( - model_inputs["input_ids"], - dim=1, - pad_index=self.tokenizer.pad_token_id, - pad_first=pad_first, - ) - model_inputs["attention_mask"] = self.accelerator.pad_across_processes( - model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first - ) - if self.is_encoder_decoder: - model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes( - model_inputs["decoder_input_ids"], - dim=1, - pad_index=self.tokenizer.pad_token_id, - pad_first=pad_first, - ) - model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes( - model_inputs["decoder_attention_mask"], - dim=1, - pad_index=0, - pad_first=pad_first, - ) - - model_inputs_names = list(model_inputs.keys()) - - full_kl_penalty = self.config.kl_penalty == "full" - - with torch.no_grad(): - all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( - self.model, - queries, - responses, - model_inputs, - response_masks=response_masks, - return_logits=full_kl_penalty, - ) - with self.optional_peft_ctx(): - ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( - self.model if self.is_peft_model else self.ref_model, - queries, - responses, - model_inputs, - return_logits=full_kl_penalty, - ) - - timing["time/ppo/forward_pass"] = time.time() - t - - with torch.no_grad(): - t = time.time() - if full_kl_penalty: - active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False) - ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False) - - rewards, non_score_reward, kls = self.compute_rewards( - scores, active_full_logprobs, ref_full_logprobs, masks - ) - else: - rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks) - timing["time/ppo/compute_rewards"] = time.time() - t - - t = time.time() - values, advantages, returns = self.compute_advantages(values, rewards, masks) - timing["time/ppo/compute_advantages"] = time.time() - t - - # upcast to float32 to avoid dataset issues - batch_dict = { - "queries": queries, - "responses": responses, - "logprobs": all_logprobs.to(torch.float32), - "values": values.to(torch.float32), - "masks": masks, - "advantages": advantages, - "returns": returns, - } - batch_dict.update(model_inputs) - - t = time.time() - all_stats = [] - early_stop = False - for _ in range(self.config.ppo_epochs): - if early_stop: - break - b_inds = np.random.permutation(bs) - for backward_batch_start in range(0, bs, self.config.backward_batch_size): - backward_batch_end = backward_batch_start + self.config.backward_batch_size - backward_batch_inds = b_inds[backward_batch_start:backward_batch_end] - - for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size): - mini_batch_end = mini_batch_start + self.config.mini_batch_size - mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end] - mini_batch_dict = { - "logprobs": batch_dict["logprobs"][mini_batch_inds], - "values": batch_dict["values"][mini_batch_inds], - "masks": batch_dict["masks"][mini_batch_inds], - # hacks: the queries and responses are ragged. - "queries": [batch_dict["queries"][i] for i in mini_batch_inds], - "responses": [batch_dict["responses"][i] for i in mini_batch_inds], - "advantages": batch_dict["advantages"][mini_batch_inds], - "returns": batch_dict["returns"][mini_batch_inds], - } - for k in model_inputs_names: - mini_batch_dict[k] = batch_dict[k][mini_batch_inds] - with self.accelerator.accumulate(self.model): - model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names} - - logprobs, logits, vpreds, _ = self.batched_forward_pass( - self.model, - mini_batch_dict["queries"], - mini_batch_dict["responses"], - model_inputs, - return_logits=True, - ) - train_stats = self.train_minibatch( - mini_batch_dict["logprobs"], - mini_batch_dict["values"], - logprobs, - logits, - vpreds, - mini_batch_dict["masks"], - mini_batch_dict["advantages"], - mini_batch_dict["returns"], - ) - all_stats.append(train_stats) - - # typically, early stopping is done at the epoch level - if self.config.early_stopping: - policykl = train_stats["policy/policykl"] - early_stop = self._early_stop(policykl) - if early_stop: - break - - timing["time/ppo/optimize_step"] = time.time() - t - - t = time.time() - train_stats = stack_dicts(all_stats) - - # reshape advantages/ratios such that they are not averaged. - train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0) - train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING) - train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0) - - stats = self.record_step_stats( - scores=scores, - logprobs=all_logprobs, - ref_logprobs=ref_logprobs, - non_score_reward=non_score_reward, - train_stats=train_stats, - kl_coef=self.kl_ctl.value, - masks=masks, - queries=queries, - responses=responses, - kls=kls, + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches ) - # Gather/Reduce stats from all processes - if self.is_distributed: - stats = self.gather_stats(stats) - stats = stats_to_np(stats) - timing["time/ppo/calc_stats"] = time.time() - t - stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"] - - # Update the KL control - multiply the batch_size by the number of processes - self.kl_ctl.update( - stats["objective/kl"], - self.config.batch_size * self.accelerator.num_processes, + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, value_model, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = processing_class.eos_token_id + self.model = PolicyAndValueWrapper(policy, value_model) + self.model.config = policy.config # needed for pushing to hub + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - # Log the total ppo time - timing["time/ppo/total"] = time.time() - t0 - stats.update(timing) - - # post-process stats for tensorboard and other loggers - if self.config.log_with != "wandb": - stats = convert_to_scalar(stats) - - if self.lr_scheduler is not None: - self.lr_scheduler.step() - - return stats - - def _early_stop(self, policykl): - r""" - Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed and - the optimization step is skipped. - This also handles the multi-gpu case where the policy KL is averaged across all processes. - - Args: - policy_kl (torch.Tensor): - the policy KL - - Returns: - `bool`: whether to early stop or not - """ - early_stop = False - if not self.config.early_stopping: - return early_stop - - if not self.is_distributed and policykl > 1.5 * self.config.target_kl: - self.optimizer.zero_grad() - early_stop = True - elif self.is_distributed: - import torch.distributed as dist - - # Wait for all processes to finish - dist.barrier() - - # all gather the policykl - dist.all_reduce(policykl, dist.ReduceOp.SUM) - policykl /= self.accelerator.num_processes - - if policykl > 1.5 * self.config.target_kl: - self.optimizer.zero_grad() - early_stop = True - return early_stop - - def gather_stats(self, stats): - """ - Gather stats from all processes. Useful in the context of distributed training. - - Args: - stats (dict[str, Any]): - a dictionary of stats to be gathered. The stats should contain torch tensors. - - Returns: - `dict[str, Any]`: A dictionary of stats with the tensors gathered. - """ - import torch.distributed as dist - - # Wait for all processes to finish - dist.barrier() - - for k, v in stats.items(): - if isinstance(v, torch.Tensor): - dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM) - v /= self.accelerator.num_processes - stats[k] = v - return stats - - def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): - if self.is_encoder_decoder: - input_data = self.data_collator( - [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries] - ).to(self.current_device) - - decoder_inputs = self.data_collator( - [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses] - ).to(self.current_device) - - input_data["decoder_input_ids"] = decoder_inputs["input_ids"] - input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 + ) else: - input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] - input_data = self.data_collator( - [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids] - ).to(self.current_device) - - input_data.pop("labels", None) # we don't want to compute LM losses - return input_data - - @PPODecorators.empty_device_cache() - def batched_forward_pass( - self, - model: PreTrainedModelWrapper, - queries: torch.Tensor, - responses: torch.Tensor, - model_inputs: dict, - return_logits: bool = False, - response_masks: Optional[torch.Tensor] = None, - ): - """ - Calculate model outputs in multiple batches. + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) - Args: - queries (`torch.LongTensor`): - List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) - responses (`torch.LongTensor`): - List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) - return_logits (`bool`, *optional*, defaults to `False`): - Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption. - - Returns: - (tuple): - - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, - shape (`batch_size`, `response_length`) - - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, - shape (`batch_size`, `response_length`) - - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`) - """ - bs = len(queries) - fbs = self.config.mini_batch_size - all_logprobs = [] - all_logits = [] - all_masks = [] - all_values = [] - - model.eval() - - for i in range(math.ceil(bs / fbs)): - input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} - query_batch = queries[i * fbs : (i + 1) * fbs] - response_batch = responses[i * fbs : (i + 1) * fbs] - if response_masks is not None: - response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] - logits, _, values = model(**input_kwargs) - - if self.is_encoder_decoder: - input_ids = input_kwargs["decoder_input_ids"] - attention_mask = input_kwargs["decoder_attention_mask"] + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) else: - input_ids = input_kwargs["input_ids"] - attention_mask = input_kwargs["attention_mask"] - - logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) - masks = torch.zeros_like(attention_mask) - masks[:, :-1] = attention_mask[:, 1:] - - for j in range(len(query_batch)): - if self.is_encoder_decoder: - # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models - start = 1 - end = attention_mask[j, :].sum() - 1 - else: - start = len(query_batch[j]) - 1 # logprobs starts from the second query token - if attention_mask[j, 0] == 0: # offset left padding - start += attention_mask[j, :].nonzero()[0] - end = start + len(response_batch[j]) - - masks[j, :start] = 0 - masks[j, end:] = 0 - if response_masks is not None: - masks[j, start:end] = masks[j, start:end] * response_masks_batch[j] - - if return_logits: - all_logits.append(logits) + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) else: - del logits - all_values.append(values) - all_logprobs.append(logprobs) - all_masks.append(masks) - - return ( - torch.cat(all_logprobs), - torch.cat(all_logits)[:, :-1] if return_logits else None, - torch.cat(all_values)[:, :-1], - torch.cat(all_masks)[:, :-1], - ) - - @PPODecorators.empty_device_cache() - def train_minibatch( - self, - old_logprobs: torch.FloatTensor, - values: torch.FloatTensor, - logprobs: torch.FloatTensor, - logits: torch.FloatTensor, - vpreds: torch.FloatTensor, - mask: torch.LongTensor, - advantages: torch.FloatTensor, - returns: torch.FloatTensor, - ): - """ - Train one PPO minibatch - - Args: - logprobs (`torch.FloatTensor`): - Log probabilities of the model, shape [mini_batch_size, response_length] - values (`torch.FloatTensor`): - Values of the value head, shape [mini_batch_size, response_length] - query (`torch.LongTensor`): - Encoded queries, shape [mini_batch_size, query_length] - response (`torch.LongTensor`): - Encoded responses, shape [mini_batch_size, response_length] - model_input (`torch.LongTensor`): - Concatenated queries and responses, shape [mini_batch_size, query_length+response_length] - - Returns: - train_stats (dict[str, `torch.Tensor`]): - Dictionary of training statistics - """ - self.model.train() - loss_p, loss_v, train_stats = self.loss( - old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns - ) - loss = loss_p + loss_v - self.accelerator.backward(loss) - if self.config.max_grad_norm is not None: - if self.accelerator.sync_gradients: - self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm) - self.optimizer.step() - # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation - # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code - self.optimizer.zero_grad() - return train_stats - - def compute_rewards( - self, - scores: torch.FloatTensor, - logprobs: torch.FloatTensor, - ref_logprobs: torch.FloatTensor, - masks: torch.LongTensor, - ): - """ - Compute per token rewards from scores and KL-penalty. - - Args: - scores (`torch.FloatTensor`): - Scores from the reward model, shape (`batch_size`) - logprobs (`torch.FloatTensor`): - Log probabilities of the model, shape (`batch_size`, `response_length`) - ref_logprobs (`torch.FloatTensor`): - Log probabilities of the reference model, shape (`batch_size`, `response_length`) - - Returns: - `torch.FloatTensor`: Per token rewards, shape (`batch_size`, `response_length`) - `torch.FloatTensor`: Non score rewards, shape (`batch_size`, `response_length`) - `torch.FloatTensor`: KL penalty, shape (`batch_size`, `response_length`) - """ - rewards, non_score_rewards, kls = [], [], [] - for score, logprob, ref_logprob, mask in zip(scores, logprobs, ref_logprobs, masks): - # compute KL penalty (from difference in logprobs) - kl = self._kl_penalty(logprob, ref_logprob) - kls.append(kl) - non_score_reward = -self.kl_ctl.value * kl - non_score_rewards.append(non_score_reward) - reward = non_score_reward.clone() - last_non_masked_index = mask.nonzero()[-1] - - # reward is preference model score + KL penalty - reward[last_non_masked_index] += score - rewards.append(reward) - return torch.stack(rewards), torch.stack(non_score_rewards), torch.stack(kls) - - def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor) -> torch.FloatTensor: - if self.config.kl_penalty == "kl": - return logprob - ref_logprob - - if self.config.kl_penalty == "abs": - return (logprob - ref_logprob).abs() - - if self.config.kl_penalty == "mse": - return 0.5 * (logprob - ref_logprob).square() - - if self.config.kl_penalty == "full": - # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459 - return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1) - - raise NotImplementedError - - def compute_advantages( - self, - values: torch.FloatTensor, - rewards: torch.FloatTensor, - mask: torch.FloatTensor, - ): - lastgaelam = 0 - advantages_reversed = [] - gen_len = rewards.shape[-1] - - values = values * mask - rewards = rewards * mask - - if self.config.whiten_rewards: - rewards = masked_whiten(rewards, mask, shift_mean=False) - - for t in reversed(range(gen_len)): - nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 - delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] - lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) - returns = advantages + values - advantages = masked_whiten(advantages, mask) - advantages = advantages.detach() - return values, advantages, returns + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del logits, all_logprob + torch.cuda.empty_cache() + + ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) - def loss( - self, - old_logprobs: torch.FloatTensor, - values: torch.FloatTensor, - logits: torch.FloatTensor, - vpreds: torch.FloatTensor, - logprobs: torch.FloatTensor, - mask: torch.LongTensor, - advantages: torch.FloatTensor, - returns: torch.FloatTensor, - ): - """ - Calculate policy and value losses. + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, processing_class.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) - Args: - old_logprobs (`torch.FloatTensor`): - Log probabilities of the model, shape (`batch_size`, `response_length`) - values (`torch.FloatTensor`): - Values of the value head, shape (`batch_size`, `response_length`) - rewards (`torch.FloatTensor`): - Rewards from the reward model, shape (`batch_size`, `response_length`) - logits (`torch.FloatTensor`): - Logits of the model, shape (`batch_size`, `response_length`, `vocab_size`) - v_pred (`torch.FloatTensor`): - Values of the value head, shape (`batch_size`, `response_length`) - logprobs (`torch.FloatTensor`): - Log probabilities of the model, shape (`batch_size`, `response_length`) - """ + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. compute rewards + kl = logprobs - ref_logprobs + non_score_reward = -args.kl_coef * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_all_logprobs = F.log_softmax(logits, dim=-1) + new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + vf_clipfrac + ) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, + vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, + mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + torch.cuda.empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + rlhf_reward = mean_non_score_reward + scores.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() + metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) - vpredclipped = clip_by_value( - vpreds, - values - self.config.cliprange_value, - values + self.config.cliprange_value, + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + torch.cuda.empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + logprobs, + ref_logprobs, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + returns, + ) + torch.cuda.empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, ) - vf_losses1 = (vpreds - returns) ** 2 - vf_losses2 = (vpredclipped - returns) ** 2 - vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask) - vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask) - - ratio = torch.exp(logprobs - old_logprobs) - - pg_losses = -advantages * ratio - pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange) - - pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask) - pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask) - - loss = pg_loss + self.config.vf_coef * vf_loss + table = defaultdict(list) + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) - avg_ratio = masked_mean(ratio, mask).item() - if avg_ratio > self.config.ratio_threshold: - warnings.warn( - f"The average ratio of batch ({avg_ratio:.2f}) exceeds threshold {self.config.ratio_threshold:.2f}. Skipping batch." - ) - pg_loss = pg_loss * 0.0 - vf_loss = vf_loss * 0.0 - loss = loss * 0.0 - - entropy = masked_mean(entropy_from_logits(logits), mask) - - approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask) - policykl = masked_mean(old_logprobs - logprobs, mask) - - return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask) - value_mean, value_var = masked_mean(values, mask), masked_var(values, mask) - - stats = dict( - loss=dict(policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()), - policy=dict( - entropy=entropy.detach(), - approxkl=approxkl.detach(), - policykl=policykl.detach(), - clipfrac=pg_clipfrac.detach(), - advantages=advantages.detach(), - advantages_mean=masked_mean(advantages, mask).detach(), - ratio=ratio.detach(), - ), - returns=dict(mean=return_mean.detach(), var=return_var.detach()), - val=dict( - vpred=masked_mean(vpreds, mask).detach(), - error=masked_mean((vpreds - returns) ** 2, mask).detach(), - clipfrac=vf_clipfrac.detach(), - mean=value_mean.detach(), - var=value_var.detach(), - ), - ) - return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats) + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - def record_step_stats(self, kl_coef: float, **data): - """ - Record training step statistics. + if sampling: + break + df = pd.DataFrame(table) + if self.accelerator.is_main_process: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb - Args: - kl_coef (`float`): - KL coefficient - data (`dict`): - Dictionary of training step data - - Returns: - stats (`dict`): - Dictionary of training step statistics - """ - mask = data.pop("masks") - - kls = data.pop("kls") - kl_list = ((kls) * mask).sum(axis=-1) - mean_kl = kl_list.mean() - mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean() - - mean_non_score_reward = masked_mean( - data["non_score_reward"], mask - ) # non_score_reward is size `batch_size`, `response_length` - mean_scores = data["scores"].mean() # scores is size `batch_size` - std_scores = data["scores"].std() - - if mean_kl.item() < -1.0: - # warn users - warnings.warn( - f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor for failed training." - " sometimes this happens because the generation kwargs are not correctly set. Please make sure" - " that the generation kwargs are set correctly, or review your training hyperparameters." - ) + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) - stats = { - "objective/kl": mean_kl, - "objective/kl_dist": kl_list, - "objective/logprobs": data["logprobs"], - "objective/ref_logprobs": data["ref_logprobs"], - "objective/kl_coef": kl_coef, - "objective/entropy": mean_entropy, - "ppo/mean_non_score_reward": mean_non_score_reward, - "ppo/mean_scores": mean_scores, - "ppo/std_scores": std_scores, - } - - # Log text properties - query_lens = torch.tensor([len(query) for query in data["queries"]], dtype=torch.float) - response_lens = torch.tensor([len(response) for response in data["responses"]], dtype=torch.float) - - stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item() - stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item() - stats["tokens/queries_dist"] = query_lens.cpu().numpy() - stats["tokens/responses_len_mean"] = torch.mean(response_lens).cpu().numpy().item() - stats["tokens/responses_len_std"] = torch.std(response_lens).cpu().numpy().item() - stats["tokens/responses_dist"] = response_lens.cpu().numpy() - - for k, v in data["train_stats"].items(): - stats[f"ppo/{k}"] = torch.mean(v, axis=0) - stats["ppo/val/var_explained"] = 1 - stats["ppo/val/error"] / stats["ppo/returns/var"] - return stats - - def log_stats( + def create_model_card( self, - stats: dict, - batch: dict, - rewards: List[torch.FloatTensor], - columns_to_log: typing.Iterable[str] = ("query", "response"), + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, ): """ - A function that logs all the training stats. Call it at the end of each epoch. - - Args: - stats (dict[str, Any]): - A dictionary of training stats. - batch (dict[str, Any]): - A dictionary of batch data, this contains the queries and responses. - rewards (`List[torch.FloatTensor]`): - A tensor of rewards. - """ - - # all gather stats - if not isinstance(rewards, torch.Tensor): - rewards = torch.tensor(rewards).to(self.current_device) - rewards = self.accelerator.gather(rewards).flatten() - - if self.config.log_with == "wandb": - import wandb - - if any(column_to_log not in batch.keys() for column_to_log in columns_to_log): - raise ValueError(f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}.") - - batch_list = [batch[column_to_log] for column_to_log in columns_to_log] - if self.is_distributed: - gathered_batch_list = [] - for b in batch_list: - flattened = gather_object(b) - gathered_batch_list.append(flattened) - batch_list = gathered_batch_list - - # Log only if we are in the main process - if self.accelerator.is_main_process: - logs = {} - - # Log stats - if "query" not in batch.keys() and "response" not in batch.keys(): - # warn the user that the game logs will not be logged - warnings.warn( - "The game logs will not be logged because the batch does not contain the keys 'query' and " - "'response'. " - ) - elif self.config.log_with == "wandb": - table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())] - logs.update({"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)}) - - logs.update(stats) - - # manually cast in fp32 for bf16 torch tensors - for k, v in logs.items(): - if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16: - logs[k] = v.float() - - logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item() - logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item() - logs["env/reward_dist"] = rewards.cpu().numpy() - - if self.config.log_with == "tensorboard": - # update the current step - self.current_step += 1 - - self.accelerator.log( - logs, - step=self.current_step if self.config.log_with == "tensorboard" else None, - ) - - def create_model_card(self, path: str, model_name: Optional[str] = "TRL Model") -> None: - """Creates and saves a model card for a TRL model. + Creates a draft of a model card using the information available to the `Trainer`. Args: - path (`str`): The path to save the model card to. - model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`. + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - try: - user = whoami()["name"] - # handle the offline case - except Exception: - warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + if not self.is_world_process_zero(): return - if not os.path.exists(path): - os.makedirs(path) - - model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") - with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: - f.write(model_card_content) - - def _save_pretrained(self, save_directory: str) -> None: - self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) - self.tokenizer.save_pretrained(save_directory) - self.create_model_card(save_directory) - - def _show_tokens(self, tokens, masks): - from rich import print - from rich.text import Text - - text = Text() - - for _i, (token, mask) in enumerate(zip(tokens, masks)): - if mask == 1: - text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1") - text.append(" ") - else: - text.append(self.tokenizer.decode(token.item()), style="black on cyan3") - text.append(" ") - print(text) - - def _prepare_deepspeed(self, model: PreTrainedModelWrapper): - # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - config_kwargs = deepspeed_plugin.deepspeed_config - if model is not None: - if hasattr(model, "config"): - hidden_size = ( - max(model.config.hidden_sizes) - if getattr(model.config, "hidden_sizes", None) - else getattr(model.config, "hidden_size", None) - ) - if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: - # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` - # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 - config_kwargs.update( - { - "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, - "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, - "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, - } - ) + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or [] + if isinstance(tags, str): + tags = [tags] + + if hasattr(self.model.config, "unsloth_version"): + tags.append("unsloth") + + citation = textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="PPO", + trainer_citation=citation, + paper_title="Fine-Tuning Language Models from Human Preferences", + paper_id="1909.08593", + ) - # If ZeRO-3 is used, we shard both the active and reference model. - # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) - if config_kwargs["zero_optimization"]["stage"] != 3: - config_kwargs["zero_optimization"]["stage"] = 0 - model, *_ = deepspeed.initialize(model=model, config=config_kwargs) - model.eval() - return model + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/ppov2_config.py b/trl/trainer/ppov2_config.py index 8268579c14..e81c6f6db6 100644 --- a/trl/trainer/ppov2_config.py +++ b/trl/trainer/ppov2_config.py @@ -12,51 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from dataclasses import dataclass +import warnings -from ..trainer.utils import OnPolicyConfig +from .ppo_config import PPOConfig -@dataclass -class PPOv2Config(OnPolicyConfig): - r""" - Configuration class for the [`PPOv2Trainer`]. - - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): - Name of this experiment. - reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): - Path to the reward model. - num_ppo_epochs (`int`, *optional*, defaults to `4`): - Number of epochs to train. - whiten_rewards (`bool`, *optional*, defaults to `False`): - Whether to whiten the rewards. - kl_coef (`float`, *optional*, defaults to `0.05`): - KL coefficient. - cliprange (`float`, *optional*, defaults to `0.2`): - Clip range. - vf_coef (`float`, *optional*, defaults to `0.1`): - Value function coefficient. - cliprange_value (`float`, *optional*, defaults to `0.2`): - Clip range for the value function. - gamma (`float`, *optional*, defaults to `1.0`): - Discount factor. - lam (`float`, *optional*, defaults to `0.95`): - Lambda value for GAE. - """ - - exp_name: str = os.path.basename(__file__)[: -len(".py")] - reward_model_path: str = "EleutherAI/pythia-160m" - num_ppo_epochs: int = 4 - whiten_rewards: bool = False - kl_coef: float = 0.05 - cliprange: float = 0.2 - vf_coef: float = 0.1 - cliprange_value: float = 0.2 - gamma: float = 1.0 - lam: float = 0.95 +# Define an alias for PPOv2Config that raises a warning +class PPOv2Config(PPOConfig): + def __init__(self, *args, **kwargs): + warnings.warn( + "`PPOv2Config` is deprecated and has been renamed to `PPOConfig`. Please use `PPOConfig` instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index e35783c342..a0187fce34 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -12,702 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc -import math -import os -import textwrap -import time -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +import warnings -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -import torch.nn.functional as F -from accelerate import Accelerator -from accelerate.utils import broadcast, gather_object -from datasets import Dataset -from torch.utils.data import DataLoader -from transformers import ( - BaseImageProcessor, - DataCollatorWithPadding, - FeatureExtractionMixin, - GenerationConfig, - PreTrainedTokenizerBase, - ProcessorMixin, - Trainer, - TrainerCallback, - TrainerControl, - is_wandb_available, -) -from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK -from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from .ppo_trainer import PPOTrainer -from ..core import masked_mean, masked_whiten -from ..models.utils import unwrap_model_for_generation -from ..trainer.utils import ( - OnlineTrainerState, - batch_generation, - disable_dropout_in_model, - exact_div, - first_true_indices, - forward, - get_reward, - prepare_deepspeed, - print_rich_table, - truncate_response, -) -from .ppov2_config import PPOv2Config -from .utils import generate_model_card - -if is_wandb_available(): - import wandb - - -INVALID_LOGPROB = 1.0 - - -# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 -# we did this we can do a single `model = accelerator.prepare(model)` -class PolicyAndValueWrapper(nn.Module): - def __init__(self, policy, value_model) -> None: - super().__init__() - self.policy = policy - self.value_model = value_model - self.critic_backbone = getattr(value_model, value_model.base_model_prefix) - - def forward(self, **kwargs): - output = self.critic_backbone( - **kwargs, - ) - logits = self.value_model.score(output.hidden_states[-1]) - return self.policy(**kwargs), logits - - -class PPOv2Trainer(Trainer): - _tag_names = ["trl", "ppo"] - - def __init__( - self, - config: PPOv2Config, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ], - policy: nn.Module, - ref_policy: nn.Module, - reward_model: nn.Module, - train_dataset: Dataset, - value_model: Optional[nn.Module] = None, - data_collator: Optional[DataCollatorWithPadding] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - # less commonly used - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - callbacks: Optional[List[TrainerCallback]] = None, - ) -> None: - if ref_policy is policy: - raise ValueError( - "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " - "same as `policy`, you must mass a copy of it, or `None` if you use peft." - ) - - self.args = config - args = config - self.processing_class = processing_class - self.policy = policy - - self.policy.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding - - self.ref_policy = ref_policy - self.reward_model = reward_model - self.train_dataset = train_dataset - self.train_dataset_len = len(train_dataset) - self.value_model = value_model - self.data_collator = data_collator - self.eval_dataset = eval_dataset - self.optimizer, self.lr_scheduler = optimizers - - ######### - # calculate various batch sizes - ######### - if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - self.accelerator = accelerator - args.world_size = accelerator.num_processes - args.local_batch_size = ( - args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches - ) - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.batch_size = int(args.local_batch_size * args.world_size) - args.mini_batch_size = exact_div( - args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" - ) - args.local_mini_batch_size = exact_div( - args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" - ) - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` - args.num_total_batches = math.ceil( - args.total_episodes / args.batch_size - ) # we may train for more than `total_episodes` - time_tensor = torch.tensor(int(time.time()), device=accelerator.device) - time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes - args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" - self.local_seed = args.seed + accelerator.process_index * 100003 # Prime - if args.num_sample_generations > 0: - self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - self.local_dataloader_batch_size = args.local_batch_size - - ######### - # setup model, optimizer, and others - ######### - for module in [policy, ref_policy, value_model, reward_model]: - disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id - self.model = PolicyAndValueWrapper(policy, value_model) - self.model.config = policy.config # needed for pushing to hub - self.create_optimizer_and_scheduler( - num_training_steps=args.num_total_batches - ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level - - ######### - ### trainer specifics - ######### - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = CallbackHandler( - self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler - ) - self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) - self.control = TrainerControl() - self.state = OnlineTrainerState( - is_local_process_zero=self.is_local_process_zero(), - is_world_process_zero=self.is_world_process_zero(), - stateful_callbacks=[ - cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) - ], - ) - self.current_flos = 0 - self.hp_search_backend = None - self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # Create distant repo and output directory if needed - self.hub_model_id = None - if self.args.push_to_hub: - self.init_hf_repo() - if self.args.should_save: - os.makedirs(self.args.output_dir, exist_ok=True) - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - ######### - ### setup dataloader - ######### - self.dataloader = DataLoader( - self.train_dataset, - batch_size=self.local_dataloader_batch_size, - shuffle=True, - collate_fn=DataCollatorWithPadding(self.processing_class), - drop_last=True, # needed; otherwise the last batch will be of ragged shape - ) - # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` - # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c - torch.manual_seed(args.seed) - self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) - torch.manual_seed(self.local_seed) # reset the local seed again - - self.eval_dataloader = DataLoader( - self.eval_dataset, - batch_size=args.per_device_eval_batch_size, - collate_fn=DataCollatorWithPadding(self.processing_class), - drop_last=True, - ) # no need to shuffle eval dataset - self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - - if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - self.ref_policy = prepare_deepspeed( - self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - else: - self.ref_policy = self.ref_policy.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) - - def get_train_dataloader(self) -> DataLoader: - return self.dataloader - - def get_eval_dataloader(self) -> DataLoader: - return self.eval_dataloader - - def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - backup_model = self.model - self.model = self.model.policy # save only the policy - - if self.is_deepspeed_enabled: - backup_deepspeed = self.deepspeed - self.deepspeed = self.model - - super().save_model(output_dir, _internal_call) - - self.model = backup_model - - if self.is_deepspeed_enabled: - self.deepspeed = backup_deepspeed - - def train(self): - args = self.args - accelerator = self.accelerator - optimizer = self.optimizer - model = self.model - ref_policy = self.ref_policy - reward_model = self.reward_model - processing_class = self.processing_class - dataloader = self.dataloader - device = accelerator.device - - def repeat_generator(): - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - generation_config = GenerationConfig( - max_new_tokens=args.response_length, - temperature=(args.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - accelerator.print("===training policy===") - start_time = time.time() - stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - model.train() - - # trainer state initialization - self.state.global_step = 0 - self.state.episode = 0 - self.state.max_steps = args.num_total_batches * args.num_mini_batches - self.state.num_train_epochs = args.total_episodes / self.train_dataset_len - # Compute absolute values for logging, eval, and save if given as ratio - if args.logging_steps is not None: - if args.logging_steps < 1: - self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) - else: - self.state.logging_steps = args.logging_steps - if args.eval_steps is not None: - if args.eval_steps < 1: - self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) - else: - self.state.eval_steps = args.eval_steps - if args.save_steps is not None: - if args.save_steps < 1: - self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) - else: - self.state.save_steps = args.save_steps - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - - # backward compatibility - if self.is_deepspeed_enabled: - self.deepspeed = self.model - self.model_wrapped = self.model - - for update in range(1, args.num_total_batches + 1): - self.state.episode += 1 * args.batch_size - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - values = [] - with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: - query_responses, logitss = batch_generation( - unwrapped_model.policy, - queries, - args.local_rollout_forward_batch_size, - processing_class.pad_token_id, - generation_config, - ) - - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - logits = logitss[i : i + args.local_rollout_forward_batch_size] - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del logits, all_logprob - torch.cuda.empty_cache() - - ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response - ) - - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - unwrapped_value_model = accelerator.unwrap_model(model).value_model - full_value, _, _ = get_reward( - unwrapped_value_model, query_response, processing_class.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - _, score, _ = get_reward( - reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - values = torch.cat(values, 0) - del (logprob, ref_logprob, full_value, value, score, unwrapped_model) - torch.cuda.empty_cache() - gc.collect() - - # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id - # Completions not passing that filter will receive a lower score. - contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) - if self.args.missing_eos_penalty is not None: - scores[~contain_eos_token] -= self.args.missing_eos_penalty - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) - - # 4. compute rewards - kl = logprobs - ref_logprobs - non_score_reward = -args.kl_coef * kl - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[[actual_start, actual_end]] += scores - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # 6. compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) - torch.cuda.empty_cache() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - - output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_all_logprobs = F.log_softmax(logits, dim=-1) - new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) - new_logprobs = torch.masked_fill( - new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB - ) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) - vf_clipfrac = masked_mean( - (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] - ) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss + args.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] - ) - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - pg_clipfrac - ) - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - vf_clipfrac - ) - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # del everything and empty cache - # fmt: off - del ( - output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped, - vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, - mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, - ) - # fmt: on - torch.cuda.empty_cache() - with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - rlhf_reward = mean_non_score_reward + scores.mean() - eps = int(self.state.episode / (time.time() - start_time)) - metrics = {} - metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item() - metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item() - metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() - metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = self.state.episode - self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log - self.state.global_step += 1 - self.log(metrics) - - self.lr_scheduler.step() - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=metrics) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward - torch.cuda.empty_cache() - gc.collect() - - if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: - self.generate_completions(sampling=True) - torch.cuda.empty_cache() - del ( - query_responses, - responses, - postprocessed_responses, - logprobs, - ref_logprobs, - values, - sequence_lengths, - contain_eos_token, - sequence_lengths_p1, - response_idxs, - padding_mask, - padding_mask_p1, - rewards, - actual_start, - actual_end, - advantages, - returns, - ) - torch.cuda.empty_cache() - - # HF trainer specifics - self.control = self.callback_handler.on_train_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None, metrics=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - def generate_completions(self, sampling: bool = False): - args = self.args - processing_class = self.processing_class - generation_config = GenerationConfig( - max_new_tokens=self.args.response_length, - temperature=(0.01 + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - table = defaultdict(list) - with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - query_response, _ = batch_generation( - unwrapped_model.policy, - query, - query.shape[0], - processing_class.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response - ) - table["query"].extend( - gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) - ) - table["model response"].extend( - gather_object(processing_class.batch_decode(postprocessed_response)) - ) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) - - if sampling: - break - df = pd.DataFrame(table) - - if self.accelerator.is_main_process: - print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - def create_model_card( - self, - model_name: Optional[str] = None, - dataset_name: Optional[str] = None, - tags: Union[str, List[str], None] = None, - ): - """ - Creates a draft of a model card using the information available to the `Trainer`. - - Args: - model_name (`str`, *optional*, defaults to `None`): - The name of the model. - dataset_name (`str`, *optional*, defaults to `None`): - The name of the dataset used for training. - tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): - Tags to be associated with the model card. - """ - if not self.is_world_process_zero(): - return - - if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): - base_model = self.model.config._name_or_path - else: - base_model = None - - tags = tags or [] - if isinstance(tags, str): - tags = [tags] - - if hasattr(self.model.config, "unsloth_version"): - tags.append("unsloth") - - citation = textwrap.dedent("""\ - @article{mziegler2019fine-tuning, - title = {{Fine-Tuning Language Models from Human Preferences}}, - author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, - year = 2019, - eprint = {arXiv:1909.08593} - }""") - - model_card = generate_model_card( - base_model=base_model, - model_name=model_name, - hub_model_id=self.hub_model_id, - dataset_name=dataset_name, - tags=tags, - wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, - trainer_name="PPO", - trainer_citation=citation, - paper_title="Fine-Tuning Language Models from Human Preferences", - paper_id="1909.08593", +# Define an alias for PPOv2Trainer that raises a warning +class PPOv2Trainer(PPOTrainer): + def __init__(self, *args, **kwargs): + warnings.warn( + "`PPOv2Trainer` is deprecated and has been renamed to `PPOTrainer`. Please use `PPOTrainer` instead.", + DeprecationWarning, + stacklevel=2, ) - - model_card.save(os.path.join(self.args.output_dir, "README.md")) + super().__init__(*args, **kwargs)