From f04446d9bcff37a1215279b1394fc8aff93b47d0 Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Mon, 5 Jun 2023 05:29:32 -0700
Subject: [PATCH 01/36] Implementing support for dense rewards

---
 examples/ppo_redemption.py              | 83 +++++++++++++++++++++++++
 trlx/trainer/accelerate_base_trainer.py | 10 +--
 trlx/trainer/accelerate_ppo_trainer.py  | 46 ++++++++++----
 3 files changed, 120 insertions(+), 19 deletions(-)
 create mode 100644 examples/ppo_redemption.py

diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py
new file mode 100644
index 000000000..b930b2dc7
--- /dev/null
+++ b/examples/ppo_redemption.py
@@ -0,0 +1,83 @@
+# Generates positive movie reviews by tuning a pretrained model on IMDB dataset
+# with a sentiment reward function
+import json
+import os
+import sys
+from typing import List
+
+import torch
+from datasets import load_dataset
+from transformers import pipeline, AutoTokenizer
+
+import trlx
+from trlx.data.default_configs import TRLConfig, default_ppo_config
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+def get_negative_score(scores):
+    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(default_ppo_config().to_dict(), hparams)
+    config.method.cliprange_reward = False
+    config.method.gen_kwargs["max_new_tokens"] = 70
+    config.method.gen_kwargs["temperature"] = 0.3
+    config.train.total_steps = 20000
+    config.train.checkpoint_interval = 10000000
+    #config.method.init_kl_coef = 0
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "lvwerra/distilbert-imdb",
+        top_k=2,
+        truncation=True,
+        batch_size=256,
+        device=device,
+    )
+
+    def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[float]:
+        # Reward positively for initially negative then positive review
+        # Reward functions should never receive padded text except for a singel EOS at the end
+        # Reward function should return token rewards for just the response
+        # Note: To get trajectory length, the reward fn should not tokenize the samples but should instead separately tokenizer prompts and outputs and then combine them
+        # Also note outputs has a single EOS at end of each
+        first_halves = [".".join(sample.split(".")[:len(sample.split(".")) // 2]) for sample in samples]
+        negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves)))
+        second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2:]) for sample in samples]
+        positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves)))
+        text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)]
+        tok_scores = []
+        for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores):
+            toks = model_tok(response).input_ids
+            tok_score = [0] * len(toks)
+            # Hacky way of assigning intermediate score
+            tok_score[len(tok_score) // 2] = text_score[0]
+            tok_score[-1] = text_score[1]
+            tok_scores.append(tok_score)
+        return tok_scores
+
+    # Take few words off of movies reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    trlx.train(
+        reward_fn=dense_reward_fn,
+        prompts=prompts,
+        
eval_prompts=["I don't know much about Hungarian underground"] * 256, + config=config, + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 5c82335c0..18af1333d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -232,9 +232,7 @@ def decode( # or add one if it was trimmed with `self.stop_sequences`. # When a generation ended due to `max_new_tokens` exhaustion, # only then or token would not be present in the original sample at the end - if append_eos_token and ( - trimmed or sample[-1] == self.tokenizer.eos_token_id or sample[-1] == self.tokenizer.pad_token_id - ): + if append_eos_token: str_output += self.tokenizer.eos_token str_prompts.append(str_prompt) @@ -427,10 +425,8 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = torch.tensor( - self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata), - dtype=float, - ) + rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards], dtype=float) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a3af9aa3f..985b79c21 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -6,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence import transformers from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -297,21 +298,24 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) rollout_score_time = time() - all_scores = torch.tensor( - self.reward_fn( - samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, **metadata - ), - dtype=torch.float, - device=device, - ) + # reward_fn should return list of rewards at each token per sample + # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) + all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata) + all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + # Pad 0 reward on the ends + all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1) + max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device) + stats["time/rollout_score"] = time() - rollout_score_time - all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1).unbind()) + all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind()) else: all_scores = None + max_len = torch.tensor(0, dtype=torch.long, device=device) if torch.distributed.is_initialized(): - scores = torch.empty(len(samples), device=device) + torch.distributed.broadcast(max_len, 0) + scores = torch.empty((len(samples), max_len), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() @@ -342,7 +346,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # 
noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.mean(), scores.std() + self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std() all_scores_mean, all_scores_std = self.running_moments.update(scores) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -415,6 +419,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) else: + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) @@ -425,6 +430,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask = sample_outputs != self.tokenizer.pad_token_id start = 0 else: + # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] @@ -436,12 +442,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logprobs = ref_logprobs.cpu() prompt_tensors = prompt_tensors.cpu() sample_outputs = sample_outputs.cpu() + # TODO(dahoas): Why [:, :-1]? Redudant with clipping via start : ends[ix]? + # Actually I think it's just wrong? values = values.cpu()[:, :-1] # Get the logprobs and values, for tokens that are not padding, - # from the start of the prompt up to the token, while also including the latter + # from the end of the prompt up to the token, while also including the latter # (these are taken from the student model and not the reference model) ends = start + attention_mask[:, start:].sum(1) + 1 + # NOTE: values[i] is the value of the state after response token i + # TODO(dahoas): Does it actually make sense to get the rewards one step early? 
all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)] all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)] @@ -451,8 +461,20 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count = 0 for sample_idx in range(n_samples): + # To compute per token reward first add in kl penalties over trajectory + # NOTE: kl_penalty[i] is kl_diff at token i+1 in the output (w/o EOS) rewards = kl_penalty[sample_idx] - rewards[-1] += scores[sample_idx].cpu() + # Then add in rewards + if scores.shape[1] == 1: + # NOTE: Final reward given at EOS token following HHH practice + rewards[-1] += scores[sample_idx][0].cpu() + else: + score = scores[sample_idx] + score_right_padding = torch.sum(score != -1) + score = score[:score_right_padding].cpu() + p_score = torch.zeros_like(rewards) + p_score[:score.shape[0]] += score + rewards += p_score ppo_rl_elements.append( PPORLElement( From 13a01fc6f986e36c77572d7d6732ceadc213b098 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 06:42:41 -0700 Subject: [PATCH 02/36] added "num_return_sequences" param which corresponds to n in Best-of-N sampling --- trlx/data/default_configs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 5277d7010..9f82a5ba3 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -54,6 +54,7 @@ def default_ppo_config(): top_k=0, top_p=1.0, do_sample=True, + num_return_sequences=16, ), ), ) From 5421a73bd680cc328db5b96ce1a3768243da8682 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 07:20:21 -0700 Subject: [PATCH 03/36] updates to "num_return_sequences" param --- trlx/data/default_configs.py | 2 +- trlx/models/modeling_ppo.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 9f82a5ba3..2b9b67b52 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,12 +49,12 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, + num_return_sequences=10, gen_kwargs=dict( max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True, - num_return_sequences=16, ), ), ) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 82d3ec637..eba137802 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -130,6 +130,7 @@ class PPOConfig(MethodConfig): ref_std: Optional[float] cliprange_reward: float gen_kwargs: dict + num_return_sequences: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( From 2f3ac2816e60af5aeb9f2b8eac5e16a8465e9616 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Fri, 16 Jun 2023 07:30:25 -0700 Subject: [PATCH 04/36] BoN implementation --- trlx/trainer/accelerate_ppo_trainer.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 985b79c21..32e6a860c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -274,10 +274,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"]) + samples = self.generate(batch["input_ids"], batch["attention_mask"], 
num_return_sequences=self.config.method.num_return_sequences)
         stats["time/rollout_generate"] = time() - rollout_generate_time
 
-        prompt_tensors = batch.input_ids
+        prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. Change it to a variable
         device = samples.device
         prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
@@ -319,6 +319,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
             torch.distributed.scatter(scores, all_scores)
         else:
             scores = all_scores[0].clone().detach()
+        # Best-of-N Sampling.
+        max_score_indices = self.get_max_indices(scores, self.config.method.num_return_sequences, device)
+        scores = scores.index_select(0, max_score_indices)
+        samples = samples.index_select(0, max_score_indices)
+        prompt_tensors = prompt_tensors.index_select(0, max_score_indices)
 
         str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
 
@@ -507,3 +512,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
 
         # Push samples and rewards to trainer's rollout storage
         self.push_to_store(ppo_rl_elements)
+
+    @staticmethod
+    def get_max_indices(input_tensor, window_size, device):
+        # Use unfold to create the sliding windows
+        unfolded = input_tensor.unfold(0, window_size, window_size)
+
+        # Find the max values and indices along the unfolded dimension
+        values, indices = unfolded.max(dim=2)
+
+        # Adjust indices to be relative to original tensor
+        indices += torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1)
+
+        return indices.squeeze()

From 2f1dace62a637ded875ff7955e16d77e64ac0419 Mon Sep 17 00:00:00 2001
From: Sharath Raparthy
Date: Mon, 19 Jun 2023 03:13:37 -0700
Subject: [PATCH 05/36] Changed back to default.
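
With `num_return_sequences` back at its default of 1, the Best-of-N selection introduced in the previous patch keeps every generated sample; when it is raised again, generations are grouped per prompt in contiguous windows of `num_return_sequences` and only the highest-scoring sample in each window is kept for training. A rough sketch of that windowed selection follows (illustrative only, not part of this patch; `best_of_n_indices` is a hypothetical helper, and it assumes `scores` has already been reduced to one scalar per generated sequence):

    import torch

    def best_of_n_indices(scores: torch.Tensor, num_return_sequences: int) -> torch.Tensor:
        # Assumption: each consecutive block of `num_return_sequences` scores belongs to one prompt
        windows = scores.unfold(0, num_return_sequences, num_return_sequences)
        # Position of the best sample inside each prompt's window
        best_in_window = windows.argmax(dim=1)
        # Shift the window-local indices back to positions in the flat score tensor
        offsets = torch.arange(0, scores.size(0), num_return_sequences)
        return offsets + best_in_window

    # Example: two prompts with three sampled continuations each
    scores = torch.tensor([0.1, 0.7, 0.3, 0.9, 0.2, 0.4])
    print(best_of_n_indices(scores, 3))  # tensor([1, 3])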
--- trlx/data/default_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 2b9b67b52..57adeea8b 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,7 +49,7 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=10, + num_return_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, From f58170dc3022f1c21f7bd53c5c88882984240751 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 19 Jun 2023 05:25:32 -0700 Subject: [PATCH 06/36] TopK sampling instead of Top1 --- trlx/data/default_configs.py | 1 + trlx/models/modeling_ppo.py | 1 + trlx/trainer/accelerate_ppo_trainer.py | 21 +++++++++------------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 57adeea8b..b29202628 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -50,6 +50,7 @@ def default_ppo_config(): ref_std=None, cliprange_reward=10, num_return_sequences=1, + num_train_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index eba137802..45c7780d6 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -131,6 +131,7 @@ class PPOConfig(MethodConfig): cliprange_reward: float gen_kwargs: dict num_return_sequences: int + num_train_sequences: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 32e6a860c..ec3594174 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -320,10 +320,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: scores = all_scores[0].clone().detach() # Best-of-N Sampling. 
- max_score_indices = self.get_max_indices(scores, self.config.method.num_return_sequences, device) - scores = scores.index_select(0, max_score_indices) - samples = samples.index_select(0, max_score_indices) - prompt_tensors = prompt_tensors.index_select(0, max_score_indices) + train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) + scores = scores.index_select(0, train_indices) + samples = samples.index_select(0, train_indices) + prompt_tensors = prompt_tensors.index_select(0, train_indices) str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -514,14 +514,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq self.push_to_store(ppo_rl_elements) @staticmethod - def get_max_indices(input_tensor, window_size, device): + def get_topk_indices(input_tensor, window_size: int, k: int, device): # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) - - # Find the max values and indices along the unfolded dimension - values, indices = unfolded.max(dim=2) - + # Find the topk values and indices along the unfolded dimension + _, indices = torch.topk(unfolded, k, dim=2) # Adjust indices to be relative to original tensor - indices += torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) - - return indices.squeeze() + indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) + return indices.reshape(-1) From be8bc1a27929157dfd50c2f7a053ee91083e0f76 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 03:10:16 -0700 Subject: [PATCH 07/36] summed along dim=1 --- trlx/trainer/accelerate_ppo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index ec3594174..cd7fd90a2 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -515,6 +515,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): + # Sum the scores along dim 1 + input_tensor = input_tensor.sum(1) # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) # Find the topk values and indices along the unfolded dimension From 608d812478bb6e38a2e86296de604572bddcb3cc Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 06:45:56 -0700 Subject: [PATCH 08/36] Generating samples in chunks --- trlx/trainer/accelerate_base_trainer.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 18af1333d..60e6233f7 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -247,8 +247,9 @@ def decode( return str_samples, str_prompts, str_outputs - def generate(self, input_ids, attention_mask=None, **kwargs): + def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" + # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) 
@@ -256,11 +257,23 @@ def generate(self, input_ids, attention_mask=None, **kwargs): kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) - - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) + # Chunk input_ids and attention_mask + + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + samples = torch.cat(samples, 0) + return samples + def generate_eval(self, input_ids, attention_mask=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" From d8557e73002de89764d813753dacac27f5afee82 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:09:10 -0700 Subject: [PATCH 09/36] added gen_chunk_size parameter --- trlx/data/default_configs.py | 5 +++-- trlx/models/modeling_ppo.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index b29202628..e49b46f65 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,8 +49,9 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=1, - num_train_sequences=1, + num_return_sequences=10, + num_train_sequences=10, + gen_chunk_size=4, gen_kwargs=dict( max_new_tokens=40, top_k=0, diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 45c7780d6..5bb808b41 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -132,6 +132,7 @@ class PPOConfig(MethodConfig): gen_kwargs: dict num_return_sequences: int num_train_sequences: int + gen_chunk_size: int gen_experience_kwargs: Optional[dict] = None def get_advantages_and_returns( From 8ef9c36622cab21bab46dd1e3a60250bf587113c Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:09:30 -0700 Subject: [PATCH 10/36] chunking in forward prop --- trlx/trainer/accelerate_ppo_trainer.py | 86 ++++++++++++++++---------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cd7fd90a2..2ac901f03 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -274,7 +274,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], num_return_sequences=self.config.method.num_return_sequences) + samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.gen_chunk_size, num_return_sequences=self.config.method.num_return_sequences) stats["time/rollout_generate"] = time() - rollout_generate_time prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. 
Change it to a variable @@ -395,39 +395,63 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq return_dict=True, ).logits else: + values_chunks = [] + logits_chunks = [] + ref_logits_chunks = [] + log_probs_chunks = [] + ref_logprobs_chunks = [] all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - with torch.no_grad(): - logits, *_, values = self.model( - all_tokens, attention_mask=attention_mask, position_ids=position_ids - ) - # TODO(dahoas): When hydra model works need to also support generation on hydra head - if hasattr(self.model, "frozen_head") or self.model.peft_type: - ref_logits = self.model.forward_hydra( - all_tokens, - attention_mask=attention_mask, + all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) + position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + with torch.no_grad(): + logits, *_, values = self.model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, position_ids=position_ids, - return_dict=True, - ).logits + ) + # TODO(dahoas): When hydra model works need to also support generation on hydra head + if hasattr(self.model, "frozen_head"): + ref_logits = self.model.forward_hydra( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids, + return_dict=True, + ).logits + elif hasattr(self, "ref_model"): + ref_logits = self.ref_model( + all_tokens_chunk, + attention_mask=attention_mask_chunk, + position_ids=position_ids, + return_dict=True, + ).logits + ref_logits = ref_logits.to(device) + else: + ref_logits = logits.clone().detach() + if self.config.model.model_arch_type == "seq2seq": + logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) else: - ref_logits = self.ref_model( - all_tokens, - attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, - ).logits - ref_logits = ref_logits.to(device) - - if self.config.model.model_arch_type == "seq2seq": - logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:]) - else: - # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled - logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:]) - ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:]) - + # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled + logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:]) + ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:]) + + values_chunks.append(values.cpu()) + logits_chunks.append(logits.cpu()) + ref_logits_chunks.append(ref_logits.cpu()) + log_probs_chunks.append(logprobs.cpu()) + ref_logprobs_chunks.append(ref_logprobs.cpu()) + + values = torch.cat(values_chunks, dim=0) + logits = torch.cat(logits_chunks, dim=0) + ref_logits = torch.cat(ref_logits_chunks, dim=0) + logprobs = torch.cat(log_probs_chunks, dim=0) + 
ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) + n_samples: int = samples.shape[0] # Estimate the KL divergence between the model and reference model @@ -437,7 +461,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 - + attention_mask = attention_mask.cpu() log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio mean_kl_per_token = kl.mean() @@ -494,7 +518,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_count += 1 if torch.distributed.is_initialized(): - torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(mean_kl.to(self.accelerator.device), torch.distributed.ReduceOp.AVG) stats["time/rollout_time"] = clock.tick() stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item() @@ -516,7 +540,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): # Sum the scores along dim 1 - input_tensor = input_tensor.sum(1) + input_tensor = input_tensor.sum(1).unsqueeze(1) # Use unfold to create the sliding windows unfolded = input_tensor.unfold(0, window_size, window_size) # Find the topk values and indices along the unfolded dimension From 4c1d82df50884a3d731d43053da0c63b15ef4508 Mon Sep 17 00:00:00 2001 From: Sharath Raparthy Date: Mon, 26 Jun 2023 08:13:52 -0700 Subject: [PATCH 11/36] chunking generations in train and eval --- trlx/trainer/accelerate_base_trainer.py | 58 ++++++++++++++++--------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 60e6233f7..1efdb8195 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -257,25 +257,28 @@ def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) - # Chunk input_ids and attention_mask - - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - - samples = [] - for chunk_idx in range(chunk_size): + if chunk_size is not None: + # Chunk input_ids and attention_mask + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + return torch.cat(samples, 0) + else: with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs ) - samples.append(sample) - # Concat samples - samples = torch.cat(samples, 0) - return samples - + - def generate_eval(self, input_ids, attention_mask=None, **kwargs): + def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's 
defaults""" input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: @@ -283,10 +286,25 @@ def generate_eval(self, input_ids, attention_mask=None, **kwargs): kwargs = dict(self.generate_kwargs, **kwargs) - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) + if chunk_size is not None: + # Chunk input_ids and attention_mask + input_ids = input_ids.chunk(chunk_size, 0) + if attention_mask is not None: + attention_mask = attention_mask.chunk(chunk_size, 0) + samples = [] + for chunk_idx in range(chunk_size): + with torch.no_grad(): + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs + ) + samples.append(sample) + # Concat samples + return torch.cat(samples, 0) + else: + with torch.no_grad(): + return self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for From ecd5107e3f119d6f84951cdc1f59ced2d819862b Mon Sep 17 00:00:00 2001 From: Dahoas Date: Mon, 5 Jun 2023 05:29:32 -0700 Subject: [PATCH 12/36] Implementing support for dense rewards --- trlx/trainer/accelerate_base_trainer.py | 7 ++++++- trlx/trainer/accelerate_ppo_trainer.py | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 1efdb8195..3e11d0c53 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -457,7 +457,12 @@ def evaluate(self): # noqa: C901 if self.reward_fn: logger.info("Computing rewards") rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) - rewards = torch.tensor([sum(r) if type(r) is list else r for r in rewards], dtype=float) + if type(rewards[0]) is torch.Tensor: + rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) + elif type(rewards[0]) is list: + rewards = torch.tensor([sum(reward) for reward in rewards]) + else: + rewards = torch.tensor(rewards) mean_reward = rewards.mean().item() columns.append("reward") if not isinstance(rewards, list): diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2ac901f03..ad2d5167c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -319,11 +319,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Best-of-N Sampling. + # Best-of-N Sampling. 
+ scores_mask = scores != -1 train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) scores = scores.index_select(0, train_indices) samples = samples.index_select(0, train_indices) prompt_tensors = prompt_tensors.index_select(0, train_indices) + scores_mask = scores_mask[train_indices] + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -351,7 +354,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = scores.sum(dim=1).mean(), scores.sum(dim=1).std() + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() all_scores_mean, all_scores_std = self.running_moments.update(scores) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -499,7 +502,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards[-1] += scores[sample_idx][0].cpu() else: score = scores[sample_idx] - score_right_padding = torch.sum(score != -1) + score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() p_score = torch.zeros_like(rewards) p_score[:score.shape[0]] += score From 4071604cbff3983d6d447ee9f64ecc947341f502 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Thu, 15 Jun 2023 08:20:11 -0700 Subject: [PATCH 13/36] Fix distributed ref_mean, ref_var bug for dense rewards --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- trlx/utils/modeling.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index ad2d5167c..d772a8653 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -355,7 +355,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() - all_scores_mean, all_scores_std = self.running_moments.update(scores) + all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() stats["rollout_scores/running_mean"] = self.running_moments.mean.item() diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 47688f553..c6f3dd8ee 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Dict, MutableMapping, Tuple, Union +from typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional import accelerate import numpy as np @@ -276,8 +276,11 @@ def __init__(self): self.var = 1 self.count = 1e-24 - def update(self, xs: torch.Tensor) -> Tuple[float, float]: + def update(self, xs: torch.Tensor, xs_mask: Optional[torch.Tensor] = None) -> Tuple[float, float]: """Updates running moments from batch's moments computed across ranks""" + if xs_mask is None: + xs_mask = torch.ones_like(xs) + xs = torch.sum(xs * xs_mask, dim=1) if dist.is_initialized(): xs_mean, xs_var, xs_count = get_global_statistics(xs) else: From 5f41413bb3b1d8e788f13a76bfac75fe5355c4f9 Mon Sep 17 00:00:00 2001 From: Dahoas 
Date: Fri, 23 Jun 2023 07:57:15 -0700 Subject: [PATCH 14/36] Make generation respect max seq length --- trlx/trainer/accelerate_base_trainer.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 3e11d0c53..780759422 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -253,10 +253,16 @@ def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) + # Update max_new_tokens to respect max_seq_length + prompt_length = input_ids.shape[1] if self.generate_experience_kwargs is not None: kwargs = dict(self.generate_experience_kwargs, **kwargs) else: kwargs = dict(self.generate_kwargs, **kwargs) + if kwargs.get("max_new_tokens") is not None: + kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) + else: + kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) if chunk_size is not None: # Chunk input_ids and attention_mask input_ids = input_ids.chunk(chunk_size, 0) @@ -286,6 +292,11 @@ def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwarg kwargs = dict(self.generate_kwargs, **kwargs) + if kwargs.get("max_new_tokens") is not None: + kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) + else: + kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) + if chunk_size is not None: # Chunk input_ids and attention_mask input_ids = input_ids.chunk(chunk_size, 0) From 22ae83f5e1ffd96a0afb8eddf440bf4f6340d13c Mon Sep 17 00:00:00 2001 From: Dahoas Date: Fri, 23 Jun 2023 08:26:37 -0700 Subject: [PATCH 15/36] Make experience before first round of training --- trlx/trainer/accelerate_ppo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index d772a8653..2f0cc8dd7 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -229,6 +229,8 @@ def prepare_learning(self): self.train_dataloader = self.store.create_loader(self.config.train.batch_size, shuffle=False) + self.make_experience(self.config.method.num_rollouts) + self.n_updates_per_batch = self.config.method.ppo_epochs self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) From 7d0a4be143530167f2e9f7061dfbe30d838b2d12 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 27 Jun 2023 04:33:44 -0700 Subject: [PATCH 16/36] Refactoring .generate/.generate_eval --- trlx/data/default_configs.py | 5 +- trlx/models/modeling_ppo.py | 10 ++- trlx/trainer/accelerate_base_trainer.py | 89 +++++++++---------------- trlx/trainer/accelerate_ppo_trainer.py | 25 ++++--- 4 files changed, 56 insertions(+), 73 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index e49b46f65..3acee97ab 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,14 +49,13 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_return_sequences=10, - num_train_sequences=10, - gen_chunk_size=4, + num_train_sequences=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True, + num_return_sequences=1, ), ), ) diff --git 
a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 5bb808b41..bd0f57f88 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -112,6 +112,12 @@ class PPOConfig(MethodConfig): :param gen_experience_kwargs: if this is not None, then the experience is generated using this :type gen_experience_kwargs: Dict[str, Any] + + :param num_train_sequences: top_k of n sampled sequences from prompt + :type num_train_sequences: int + + :param mix_sft: if this is True, then SFT gradients will be mixed into PPO traininig + :type mix_sft: bool """ ppo_epochs: int @@ -130,10 +136,8 @@ class PPOConfig(MethodConfig): ref_std: Optional[float] cliprange_reward: float gen_kwargs: dict - num_return_sequences: int - num_train_sequences: int - gen_chunk_size: int gen_experience_kwargs: Optional[dict] = None + num_train_sequences: int = 1 def get_advantages_and_returns( self, diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 780759422..1b7ad8423 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from time import time from typing import Dict, List, Optional, Tuple +from copy import copy import ray import torch @@ -247,75 +248,49 @@ def decode( return str_samples, str_prompts, str_outputs - def generate(self, input_ids, attention_mask=None, chunk_size=4, **kwargs): + def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) + + generate_kwargs = copy(self.generate_kwargs) + generate_kwargs.update(kwargs) + # Update max_new_tokens to respect max_seq_length prompt_length = input_ids.shape[1] - if self.generate_experience_kwargs is not None: - kwargs = dict(self.generate_experience_kwargs, **kwargs) - else: - kwargs = dict(self.generate_kwargs, **kwargs) - if kwargs.get("max_new_tokens") is not None: - kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) - else: - kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - if chunk_size is not None: - # Chunk input_ids and attention_mask - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - samples = [] - for chunk_idx in range(chunk_size): - with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs - ) - samples.append(sample) - # Concat samples - return torch.cat(samples, 0) + if generate_kwargs.get("max_new_tokens") is not None: + generate_kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]) else: - with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs - ) - + generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - def generate_eval(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): - """Wraps hf's `generate` adding some specific method's defaults""" - input_ids = input_ids.to(self.accelerator.device) + # Repeat prompts, attention_masks for chunking if returning multiple sequences + if 
generate_kwargs.get("num_return_sequences") is None: + generate_kwargs["num_return_sequences"] = 1 + + num_return_sequences = generate_kwargs.pop("num_return_sequences") # Pop to hide from model.generate call + input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0) if attention_mask is not None: - attention_mask = attention_mask.to(self.accelerator.device) + attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - kwargs = dict(self.generate_kwargs, **kwargs) + if chunk_size is None: + chunk_size = input_ids.shape[0] - if kwargs.get("max_new_tokens") is not None: - kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), kwargs["max_new_tokens"]) - else: - kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) - - if chunk_size is not None: - # Chunk input_ids and attention_mask - input_ids = input_ids.chunk(chunk_size, 0) - if attention_mask is not None: - attention_mask = attention_mask.chunk(chunk_size, 0) - samples = [] - for chunk_idx in range(chunk_size): - with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **kwargs - ) - samples.append(sample) - # Concat samples - return torch.cat(samples, 0) - else: + # Chunk input_ids and attention_mask + input_ids = input_ids.split(chunk_size, dim=0) + if attention_mask is not None: + attention_mask = attention_mask.split(chunk_size, dim=0) + samples = [] + for chunk_idx in range(len(input_ids)): with torch.no_grad(): - return self.accelerator.unwrap_model(self.model).generate( - input_ids=input_ids, attention_mask=attention_mask, **kwargs + sample = self.accelerator.unwrap_model(self.model).generate( + input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs ) + samples.append(sample) + # Concat samples + samples = torch.cat(samples, 0) + return samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for @@ -417,11 +392,11 @@ def evaluate(self): # noqa: C901 for i_prompt, prompts in enumerate(self.eval_dataloader): metadata = {k: v for k, v in prompts.items() if k != "input_ids" and k != "attention_mask"} if self.generate_sweep_kwarg: - samples = self.generate_eval( + samples = self.generate( prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} ) else: - samples = self.generate_eval(prompts["input_ids"], prompts["attention_mask"]) + samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2f0cc8dd7..cdbe12472 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,6 +3,7 @@ import uuid from time import time from typing import Callable, List +from copy import copy import torch import torch.nn.functional as F @@ -268,6 +269,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] + # Require chunk_size * num_train_sequences divides num_rollouts + assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_train_sequences) == 0 + while len(ppo_rl_elements) < num_rollouts: stats = {} # Get next batch in prompt dataset @@ -276,10 
+280,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.gen_chunk_size, num_return_sequences=self.config.method.num_return_sequences) + samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.chunk_size, **self.generate_experience_kwargs) stats["time/rollout_generate"] = time() - rollout_generate_time - prompt_tensors = batch.input_ids.repeat_interleave(self.config.method.num_return_sequences, dim=0) # TODO: It is hard-coded to 10 here. Change it to a variable + num_return_sequences = self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not None else 1 + prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device) @@ -323,12 +328,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = all_scores[0].clone().detach() # Best-of-N Sampling. scores_mask = scores != -1 - train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences,k=self.config.method.num_train_sequences, device=device) - scores = scores.index_select(0, train_indices) - samples = samples.index_select(0, train_indices) - prompt_tensors = prompt_tensors.index_select(0, train_indices) + train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) + scores = scores[train_indices] scores_mask = scores_mask[train_indices] - + samples = samples[train_indices] + prompt_tensors = prompt_tensors[train_indices] str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) @@ -417,21 +421,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq logits, *_, values = self.model( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, ) # TODO(dahoas): When hydra model works need to also support generation on hydra head if hasattr(self.model, "frozen_head"): ref_logits = self.model.forward_hydra( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, return_dict=True, ).logits elif hasattr(self, "ref_model"): ref_logits = self.ref_model( all_tokens_chunk, attention_mask=attention_mask_chunk, - position_ids=position_ids, + position_ids=position_ids_chunk, return_dict=True, ).logits ref_logits = ref_logits.to(device) @@ -466,6 +470,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: # NOTE: -1 because kl[prompt_tensors.shape[1]] is kl of the second token in the response start = prompt_tensors.shape[1] - 1 + attention_mask = attention_mask.cpu() log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1] kl = log_ratio.exp() - 1 - log_ratio From b79dd19915cb93fadb752d8f7740166bed303091 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Thu, 29 Jun 2023 10:03:10 -0700 Subject: [PATCH 17/36] Fix BoN metric support --- trlx/trainer/accelerate_base_trainer.py | 19 ++++++++++++++++++- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 2 files 
changed, 19 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 1b7ad8423..9695ebb04 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -398,6 +398,14 @@ def evaluate(self): # noqa: C901 else: samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) + # Repeat prompts, metadata num_return_sequence times + num_return_sequences = 1 + if self.generate_kwargs.get("num_return_sequences") is not None: + num_return_sequences = self.generate_kwargs["num_return_sequences"] + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(num_return_sequences, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(num_return_sequences, dim=0) + metadata = {k: self.repeat_interleave(v, num_return_sequences) for k, v in metadata.items()} + # TODO(reciprocated): this should be moved into `decode` # but that needs to be synced with indexing in `make_experience` if self.config.model.model_arch_type == "seq2seq": @@ -460,7 +468,7 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata) + metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -651,6 +659,15 @@ def learn(self): # noqa: C901 self.post_epoch_callback() tbar.close() + @staticmethod + def repeat_interleave(l, n): + if type(l) is torch.Tensor: + l = l.repeat_interleave(n, dim=0) + elif type(l) is list: + l = [[s]*n for s in l] + l = [item for sublist in l for item in sublist] + return l + @abstractmethod def get_arch(self, config: TRLConfig): """Returns a specific wrapper of the decoder architecture""" diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cdbe12472..a14c0752d 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -297,7 +297,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict({k: self.repeat_interleave(v, num_return_sequences) for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) if self.accelerator.is_main_process: all_str_samples, all_str_prompts, all_str_outputs = self.decode( From cb49dc538c592b78651157947d601741e4967247 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Mon, 3 Jul 2023 04:27:55 -0700 Subject: [PATCH 18/36] Enforce chunk_size param for eval generation when present --- trlx/trainer/accelerate_base_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 9695ebb04..60c2ae622 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -396,7 +396,8 @@ def evaluate(self): # noqa: C901 prompts["input_ids"], prompts["attention_mask"], **{gen_sweep_arg: gen_sweep_value} ) else: - samples = self.generate(prompts["input_ids"], prompts["attention_mask"]) + chunk_size = 
self.config.method.chunk_size if hasattr(self.config.method, "chunk_size") else None + samples = self.generate(prompts["input_ids"], prompts["attention_mask"], chunk_size=chunk_size) # Repeat prompts, metadata num_return_sequence times num_return_sequences = 1 From e290412541409206d51ece8f81309c28143af44f Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 4 Jul 2023 07:22:50 -0700 Subject: [PATCH 19/36] Fix: Don't shuffle prompt dataset --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index a14c0752d..cb553432f 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -238,7 +238,7 @@ def prepare_learning(self): def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" - prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) + prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=False) prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) self.prompt_iterator = infinite_dataloader(prompt_dataloader) From 391d04cd51a1ba3d63d4b4421fe4f6295c4be654 Mon Sep 17 00:00:00 2001 From: dahoas Date: Tue, 18 Jul 2023 11:59:36 +0000 Subject: [PATCH 20/36] Move inputs to device --- trlx/trainer/accelerate_ppo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index cb553432f..dec81d90c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -417,6 +417,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + all_tokens_chunk = all_tokens_chunk.to(device) + attention_mask_chunk = attention_mask_chunk.to(device) + position_ids_chunk = position_ids_chunk.to(device) with torch.no_grad(): logits, *_, values = self.model( all_tokens_chunk, From 8de84e42e572721bb5e1a08b0f61a6c0583b6463 Mon Sep 17 00:00:00 2001 From: dahoas Date: Tue, 18 Jul 2023 12:19:47 +0000 Subject: [PATCH 21/36] Fix style --- examples/ppo_redemption.py | 11 ++-- trlx/trainer/accelerate_base_trainer.py | 30 +++++++---- trlx/trainer/accelerate_ppo_trainer.py | 70 +++++++++++++++++++------ trlx/utils/modeling.py | 2 +- 4 files changed, 80 insertions(+), 33 deletions(-) diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py index b930b2dc7..84435b225 100644 --- a/examples/ppo_redemption.py +++ b/examples/ppo_redemption.py @@ -7,7 +7,7 @@ import torch from datasets import load_dataset -from transformers import pipeline, AutoTokenizer +from transformers import pipeline import trlx from trlx.data.default_configs import TRLConfig, default_ppo_config @@ -17,6 +17,7 @@ def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + def get_negative_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] @@ -29,7 +30,7 @@ def main(hparams={}): 
config.method.gen_kwargs["temperature"] = 0.3 config.train.total_steps = 20000 config.train.checkpoint_interval = 10000000 - #config.method.init_kl_coef = 0 + # config.method.init_kl_coef = 0 if torch.cuda.is_available(): device = int(os.environ.get("LOCAL_RANK", 0)) @@ -49,11 +50,9 @@ def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], # Reward positively for initially negative then positive review # Reward functions should never receive padded text except for a singel EOS at the end # Reward function should return token rewards for just the response - # Note: To get trajectory length, the reward fn should not tokenize the samples but should instead separately tokenizer prompts and outputs and then combine them - # Also note outputs has a single EOS at end of each - first_halves = [".".join(sample.split(".")[:len(sample.split(".")) // 2]) for sample in samples] + first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples] negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) - second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2:]) for sample in samples] + second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples] positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] tok_scores = [] diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 60c2ae622..ff19a4288 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -4,9 +4,9 @@ import sys from abc import abstractmethod from contextlib import contextmanager +from copy import copy from time import time from typing import Dict, List, Optional, Tuple -from copy import copy import ray import torch @@ -221,13 +221,11 @@ def decode( str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True) str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True) # Trim outputs up to `self.stop_sequences` if any are present - trimmed = False if self.stop_sequences: for stop in self.stop_sequences: stop_ix = str_output.find(stop) if stop_ix >= 0: str_output = str_output[:stop_ix].rstrip() - trimmed = True # Recover the last if it was present in the original sample # or add one if it was trimmed with `self.stop_sequences`. 
@@ -250,18 +248,20 @@ def decode( def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): """Wraps hf's `generate` adding some specific method's defaults""" - # Decide into chunk sizes and generate saples + # Decide into chunk sizes and generate saples input_ids = input_ids.to(self.accelerator.device) if attention_mask is not None: attention_mask = attention_mask.to(self.accelerator.device) generate_kwargs = copy(self.generate_kwargs) generate_kwargs.update(kwargs) - + # Update max_new_tokens to respect max_seq_length prompt_length = input_ids.shape[1] if generate_kwargs.get("max_new_tokens") is not None: - generate_kwargs["max_new_tokens"] = min(max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"]) + generate_kwargs["max_new_tokens"] = min( + max(self.max_length - prompt_length, 0), generate_kwargs["max_new_tokens"] + ) else: generate_kwargs["max_new_tokens"] = max(self.max_length - prompt_length, 0) @@ -451,7 +451,13 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: logger.info("Computing rewards") - rewards = self.reward_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + rewards = self.reward_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + model_tok=self.tokenizer, + **metadata, + ) if type(rewards[0]) is torch.Tensor: rewards = torch.tensor([reward.sum().item() for reward in rewards], dtype=float) elif type(rewards[0]) is list: @@ -469,7 +475,13 @@ def evaluate(self): # noqa: C901 if self.metric_fn: logger.info("Computing metrics") metric_time = time() - metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, model_tok=self.tokenizer, **metadata) + metrics = self.metric_fn( + samples=str_samples, + prompts=str_prompts, + outputs=str_outputs, + model_tok=self.tokenizer, + **metadata, + ) stats["time/metric"] = time() - metric_time mean_metrics = { @@ -665,7 +677,7 @@ def repeat_interleave(l, n): if type(l) is torch.Tensor: l = l.repeat_interleave(n, dim=0) elif type(l) is list: - l = [[s]*n for s in l] + l = [[s] * n for s in l] l = [item for sublist in l for item in sublist] return l diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index dec81d90c..2fbb2068c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -3,12 +3,11 @@ import uuid from time import time from typing import Callable, List -from copy import copy import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence import transformers +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -280,10 +279,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() # Generate samples from the language model (similar to using HuggingFace `generate` method) - samples = self.generate(batch["input_ids"], batch["attention_mask"], chunk_size=self.config.method.chunk_size, **self.generate_experience_kwargs) + samples = self.generate( + batch["input_ids"], + batch["attention_mask"], + chunk_size=self.config.method.chunk_size, + **self.generate_experience_kwargs, + ) stats["time/rollout_generate"] = time() - rollout_generate_time - num_return_sequences = self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not 
None else 1 + num_return_sequences = ( + self.generate_experience_kwargs["num_return_sequences"] + if self.generate_experience_kwargs.get("num_return_sequences") is not None + else 1 + ) prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device @@ -297,7 +305,13 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq gathered_samples = self.accelerator.gather(padded_samples) gathered_prompts = self.accelerator.gather(padded_prompts) gathered_prompt_sizes = self.accelerator.gather(prompt_sizes) - metadata = gather_dict({k: self.repeat_interleave(v, num_return_sequences) for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}) + metadata = gather_dict( + { + k: self.repeat_interleave(v, num_return_sequences) + for k, v in batch.items() + if k != "input_ids" and k != "attention_mask" + } + ) if self.accelerator.is_main_process: all_str_samples, all_str_prompts, all_str_outputs = self.decode( @@ -307,8 +321,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_score_time = time() # reward_fn should return list of rewards at each token per sample # NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed) - all_scores = self.reward_fn(samples=all_str_samples, prompts=all_str_prompts, outputs=all_str_outputs, model_tok=self.tokenizer, **metadata) - all_scores = [torch.tensor(score, dtype=torch.float, device=device).view(-1,) for score in all_scores] + all_scores = self.reward_fn( + samples=all_str_samples, + prompts=all_str_prompts, + outputs=all_str_outputs, + model_tok=self.tokenizer, + **metadata, + ) + all_scores = [ + torch.tensor(score, dtype=torch.float, device=device).view( + -1, + ) + for score in all_scores + ] # Pad 0 reward on the ends all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-1) max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device) @@ -326,9 +351,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() - # Best-of-N Sampling. + # Best-of-N Sampling. 
scores_mask = scores != -1 - train_indices = self.get_topk_indices(input_tensor=scores_mask*scores, window_size=num_return_sequences, k=self.config.method.num_train_sequences, device=device) + train_indices = self.get_topk_indices( + input_tensor=scores_mask * scores, + window_size=num_return_sequences, + k=self.config.method.num_train_sequences, + device=device, + ) scores = scores[train_indices] scores_mask = scores_mask[train_indices] samples = samples[train_indices] @@ -360,7 +390,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # store statistics of the initial rollout as reference if self.ref_mean is None: - self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(dim=1).std() + self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum( + dim=1 + ).std() all_scores_mean, all_scores_std = self.running_moments.update(scores, scores_mask) stats["rollout_scores/mean"] = all_scores_mean.item() stats["rollout_scores/std"] = all_scores_std.item() @@ -416,7 +448,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) - for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip(all_tokens_chunks, attention_mask_chunks, position_ids_chunks): + for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( + all_tokens_chunks, attention_mask_chunks, position_ids_chunks + ): all_tokens_chunk = all_tokens_chunk.to(device) attention_mask_chunk = attention_mask_chunk.to(device) position_ids_chunk = position_ids_chunk.to(device) @@ -451,19 +485,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens_chunk[:, 1:]) ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens_chunk[:, 1:]) - + values_chunks.append(values.cpu()) logits_chunks.append(logits.cpu()) ref_logits_chunks.append(ref_logits.cpu()) log_probs_chunks.append(logprobs.cpu()) ref_logprobs_chunks.append(ref_logprobs.cpu()) - + values = torch.cat(values_chunks, dim=0) logits = torch.cat(logits_chunks, dim=0) ref_logits = torch.cat(ref_logits_chunks, dim=0) logprobs = torch.cat(log_probs_chunks, dim=0) ref_logprobs = torch.cat(ref_logprobs_chunks, dim=0) - + n_samples: int = samples.shape[0] # Estimate the KL divergence between the model and reference model @@ -515,7 +549,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq score_right_padding = torch.sum(scores_mask[sample_idx]) score = score[:score_right_padding].cpu() p_score = torch.zeros_like(rewards) - p_score[:score.shape[0]] += score + p_score[: score.shape[0]] += score rewards += p_score ppo_rl_elements.append( @@ -549,7 +583,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) - + @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): # Sum the scores along dim 1 @@ -559,5 +593,7 @@ def get_topk_indices(input_tensor, window_size: int, k: int, device): # Find the topk values and indices 
along the unfolded dimension _, indices = torch.topk(unfolded, k, dim=2) # Adjust indices to be relative to original tensor - indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1) + indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to( + device + ).unsqueeze(1) return indices.reshape(-1) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index c6f3dd8ee..b0036b3f6 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, List, MutableMapping, Tuple, Union, Optional +from typing import Dict, MutableMapping, Optional, Tuple, Union import accelerate import numpy as np From 3d7e0d5dd16fe88e4f0f5e73304146b8fa2c70e3 Mon Sep 17 00:00:00 2001 From: dahoas Date: Fri, 21 Jul 2023 17:06:43 +0000 Subject: [PATCH 22/36] Fix chunked generation --- trlx/trainer/accelerate_base_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ff19a4288..50977d397 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -10,6 +10,7 @@ import ray import torch +from torch.nn.utils.rnn import pad_sequence from accelerate import Accelerator # type: ignore from ray.air import session from rich.console import Console @@ -288,8 +289,8 @@ def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs ) samples.append(sample) - # Concat samples - samples = torch.cat(samples, 0) + # Concat padded samples + samples = pad_sequence(samples, batch_first=True, self.tokenizer.pad_token_id) return samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): From 1fda0ce312d26e8e81c8c0ecf259b708e2c64abc Mon Sep 17 00:00:00 2001 From: Max <56548574+maxreciprocate@users.noreply.github.com> Date: Sat, 22 Jul 2023 17:52:22 +0300 Subject: [PATCH 23/36] fix(accelerate_base_trainer): order of keyword arguments --- trlx/trainer/accelerate_base_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 50977d397..81b31fdf4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -290,7 +290,7 @@ def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): ) samples.append(sample) # Concat padded samples - samples = pad_sequence(samples, batch_first=True, self.tokenizer.pad_token_id) + samples = pad_sequence(samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) return samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): From 3ce3c2bc9dae6ed67d62cfa3c2d02e8647e23f1d Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 15:28:32 +0000 Subject: [PATCH 24/36] Removing old example --- examples/ppo_redemption.py | 82 -------------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 examples/ppo_redemption.py diff --git a/examples/ppo_redemption.py b/examples/ppo_redemption.py deleted file mode 100644 index 84435b225..000000000 --- a/examples/ppo_redemption.py +++ /dev/null @@ -1,82 +0,0 @@ -# Generates positive movie reviews by tuning a pretrained model on IMDB dataset -# with a sentiment reward function -import json -import os -import 
sys -from typing import List - -import torch -from datasets import load_dataset -from transformers import pipeline - -import trlx -from trlx.data.default_configs import TRLConfig, default_ppo_config - - -def get_positive_score(scores): - "Extract value associated with a positive sentiment from pipeline's output" - return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] - - -def get_negative_score(scores): - return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] - - -def main(hparams={}): - # Merge sweep config with default config if given - config = TRLConfig.update(default_ppo_config().to_dict(), hparams) - config.method.cliprange_reward = False - config.method.gen_kwargs["max_new_tokens"] = 70 - config.method.gen_kwargs["temperature"] = 0.3 - config.train.total_steps = 20000 - config.train.checkpoint_interval = 10000000 - # config.method.init_kl_coef = 0 - - if torch.cuda.is_available(): - device = int(os.environ.get("LOCAL_RANK", 0)) - else: - device = -1 - - sentiment_fn = pipeline( - "sentiment-analysis", - "lvwerra/distilbert-imdb", - top_k=2, - truncation=True, - batch_size=256, - device=device, - ) - - def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], model_tok, **kwargs) -> List[float]: - # Reward positively for initially negative then positive review - # Reward functions should never receive padded text except for a singel EOS at the end - # Reward function should return token rewards for just the response - first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples] - negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) - second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples] - positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) - text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] - tok_scores = [] - for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores): - toks = model_tok(response).input_ids - tok_score = [0] * len(toks) - # Hacky way of assigning intermediate score - tok_score[len(tok_score) // 2] = text_score[0] - tok_score[-1] = text_score[1] - tok_scores.append(tok_score) - return tok_scores - - # Take few words off of movies reviews as prompts - imdb = load_dataset("imdb", split="train+test") - prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] - - trlx.train( - reward_fn=dense_reward_fn, - prompts=prompts, - eval_prompts=["I don't know much about Hungarian underground"] * 256, - config=config, - ) - - -if __name__ == "__main__": - hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) - main(hparams) From 2635de50ee451a08975a204172e31d4eb2c8c1d7 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 15:29:54 +0000 Subject: [PATCH 25/36] Fix: remove extraneous method args --- trlx/models/modeling_ppo.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 22d1cad0f..48db7d263 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -115,9 +115,6 @@ class PPOConfig(MethodConfig): :param num_train_sequences: top_k of n sampled sequences from prompt :type num_train_sequences: int - - :param mix_sft: if this is True, then SFT gradients will be mixed into PPO traininig - :type mix_sft: bool """ ppo_epochs: int From 1be2c3cffe6a0f88c2916955e5ce900bbe1cb19f Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 
Aug 2023 15:36:53 +0000 Subject: [PATCH 26/36] Fix: Always set generate_experience_kwargs --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 340903594..18b30d1f4 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -97,7 +97,7 @@ def __init__(self, config: TRLConfig, **kwargs): if config.method.gen_experience_kwargs is not None: self.generate_experience_kwargs = {**generate_kwargs, **config.method.gen_experience_kwargs} else: - self.generate_experience_kwargs = None + self.generate_experience_kwargs = {**self.generate_kwargs} # Setup stats tracker self.running_moments = RunningMoments() From 3cba0dbff33a1a21ffd13b3caa660eea5d7d0e2c Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 15:50:23 +0000 Subject: [PATCH 27/36] Fix: Remove mask from RunningMoments update call --- trlx/utils/modeling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index d220cd2a0..70b0c6f81 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -282,11 +282,8 @@ def __init__(self): self.var = 1 self.count = 1e-24 - def update(self, xs: torch.Tensor, xs_mask: Optional[torch.Tensor] = None) -> Tuple[float, float]: + def update(self, xs: torch.Tensor) -> Tuple[float, float]: """Updates running moments from batch's moments computed across ranks""" - if xs_mask is None: - xs_mask = torch.ones_like(xs) - xs = torch.sum(xs * xs_mask, dim=1) if dist.is_initialized(): xs_mean, xs_var, xs_count = get_global_statistics(xs) else: From 0cb91c42d756ad477463c8c558c9863df018d27a Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 15:55:54 +0000 Subject: [PATCH 28/36] Fix: style --- trlx/utils/modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 70b0c6f81..6e737c080 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -1,5 +1,5 @@ import functools -from typing import Dict, MutableMapping, Optional, Tuple, Union +from typing import Dict, MutableMapping, Tuple, Union import accelerate import numpy as np From cc92911f5302a90ee69ce210a2252f9d8d2e6fb4 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 16:23:24 +0000 Subject: [PATCH 29/36] Fix: rename 'gen_chunk_size' to 'chunk_size' --- trlx/trainer/accelerate_ppo_trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 18b30d1f4..d9666b596 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -282,6 +282,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rollout_generate_time = time() + print("ONE") + # Generate samples from the language model (similar to using HuggingFace `generate` method) samples = self.generate( batch["input_ids"], @@ -291,6 +293,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) stats["time/rollout_generate"] = time() - rollout_generate_time + print("TWO") + num_return_sequences = ( self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not None @@ -349,12 +353,16 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_scores = None max_len = 
torch.tensor(0, dtype=torch.long, device=device) + print("THREE") + if torch.distributed.is_initialized(): torch.distributed.broadcast(max_len, 0) scores = torch.empty((len(samples), max_len), device=device) torch.distributed.scatter(scores, all_scores) else: scores = all_scores[0].clone().detach() + + print("FOUR") # Best-of-N Sampling. scores_mask = scores != -np.inf train_indices = self.get_topk_indices( @@ -449,9 +457,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device) position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.gen_chunk_size, dim=0) - attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.gen_chunk_size, dim=0) - position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.gen_chunk_size, dim=0) + all_tokens_chunks = torch.chunk(all_tokens, chunks=self.config.method.chunk_size, dim=0) + attention_mask_chunks = torch.chunk(attention_mask, chunks=self.config.method.chunk_size, dim=0) + position_ids_chunks = torch.chunk(position_ids, chunks=self.config.method.chunk_size, dim=0) for all_tokens_chunk, attention_mask_chunk, position_ids_chunk in zip( all_tokens_chunks, attention_mask_chunks, position_ids_chunks ): From 4297f98fdeabb5ea73a0f94e2f50117fbb13e173 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 16:40:03 +0000 Subject: [PATCH 30/36] Fix: generated samples padding --- trlx/trainer/accelerate_base_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 21a6cd649..ac8135cff 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -282,16 +282,16 @@ def generate(self, input_ids, attention_mask=None, chunk_size=None, **kwargs): input_ids = input_ids.split(chunk_size, dim=0) if attention_mask is not None: attention_mask = attention_mask.split(chunk_size, dim=0) - samples = [] + all_samples = [] for chunk_idx in range(len(input_ids)): with torch.no_grad(): - sample = self.accelerator.unwrap_model(self.model).generate( + samples = self.accelerator.unwrap_model(self.model).generate( input_ids=input_ids[chunk_idx], attention_mask=attention_mask[chunk_idx], **generate_kwargs ) - samples.append(sample) - # Concat padded samples - samples = pad_sequence(samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) - return samples + all_samples += [sample for sample in samples] + # Pad all_samples into one tensor + all_samples = pad_sequence(all_samples, batch_first=True, padding_value=self.tokenizer.pad_token_id) + return all_samples def save_pretrained(self, directory: Optional[str] = None, **kwargs): """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for From 36f06af66bb30ce00eda5072553f5ae650bbc457 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 7 Aug 2023 16:43:46 +0000 Subject: [PATCH 31/36] Remove prints --- trlx/trainer/accelerate_ppo_trainer.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index d9666b596..0a0dd56a7 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -282,8 +282,6 @@ def make_experience(self, num_rollouts: int = 1024, 
iter_count: int = 0): # noq rollout_generate_time = time() - print("ONE") - # Generate samples from the language model (similar to using HuggingFace `generate` method) samples = self.generate( batch["input_ids"], @@ -293,8 +291,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) stats["time/rollout_generate"] = time() - rollout_generate_time - print("TWO") - num_return_sequences = ( self.generate_experience_kwargs["num_return_sequences"] if self.generate_experience_kwargs.get("num_return_sequences") is not None @@ -353,8 +349,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq all_scores = None max_len = torch.tensor(0, dtype=torch.long, device=device) - print("THREE") - if torch.distributed.is_initialized(): torch.distributed.broadcast(max_len, 0) scores = torch.empty((len(samples), max_len), device=device) @@ -362,7 +356,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq else: scores = all_scores[0].clone().detach() - print("FOUR") # Best-of-N Sampling. scores_mask = scores != -np.inf train_indices = self.get_topk_indices( From a2980ddb2e8ff5b45a38aaa208fb0638a4aaa10a Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 21 Aug 2023 11:10:14 +0000 Subject: [PATCH 32/36] Rename 'num_train_sequences' to 'num_topk_samples' --- trlx/data/default_configs.py | 2 +- trlx/models/modeling_ppo.py | 6 +++--- trlx/trainer/accelerate_ppo_trainer.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 3acee97ab..864b5abf8 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -49,7 +49,7 @@ def default_ppo_config(): ref_mean=None, ref_std=None, cliprange_reward=10, - num_train_sequences=1, + num_topk_samples=1, gen_kwargs=dict( max_new_tokens=40, top_k=0, diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 48db7d263..6c00815c2 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -113,8 +113,8 @@ class PPOConfig(MethodConfig): :param gen_experience_kwargs: if this is not None, then the experience is generated using this :type gen_experience_kwargs: Dict[str, Any] - :param num_train_sequences: top_k of n sampled sequences from prompt - :type num_train_sequences: int + :param num_topk_samples: top_k of n sampled sequences from prompt + :type num_topk_samples: int """ ppo_epochs: int @@ -134,7 +134,7 @@ class PPOConfig(MethodConfig): cliprange_reward: float gen_kwargs: dict gen_experience_kwargs: Optional[dict] = None - num_train_sequences: int = 1 + num_topk_samples: int = 1 num_value_layers_unfrozen: int = 0 def get_advantages_and_returns( diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 0a0dd56a7..2af09eed3 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -272,8 +272,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] - # Require chunk_size * num_train_sequences divides num_rollouts - assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_train_sequences) == 0 + # Require chunk_size * num_topk_samples divides num_rollouts + assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_topk_samples) == 0 while len(ppo_rl_elements) < num_rollouts: stats = {} @@ -361,7 +361,7 @@ def make_experience(self, num_rollouts: int = 1024, 
iter_count: int = 0): # noq train_indices = self.get_topk_indices( input_tensor=scores_mask * scores, window_size=num_return_sequences, - k=self.config.method.num_train_sequences, + k=self.config.method.num_topk_samples, device=device, ) scores = scores[train_indices] From 3d5a63952ce0079f57fb817bc454e0d52bf9237f Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 21 Aug 2023 11:18:25 +0000 Subject: [PATCH 33/36] Address nits --- trlx/trainer/accelerate_base_trainer.py | 2 +- trlx/trainer/accelerate_ppo_trainer.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ac8135cff..913b642bd 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -480,7 +480,7 @@ def evaluate(self): # noqa: C901 samples=str_samples, prompts=str_prompts, outputs=str_outputs, - model_tok=self.tokenizer, + tokenizer=self.tokenizer, **metadata, ) stats["time/metric"] = time() - metric_time diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 2af09eed3..7bcd2871c 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -272,8 +272,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] + num_return_sequences = ( + self.generate_experience_kwargs["num_return_sequences"] + if self.generate_experience_kwargs.get("num_return_sequences") is not None + else 1 + ) + # Require chunk_size * num_topk_samples divides num_rollouts assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_topk_samples) == 0 + assert self.config.method.num_topk_samples <= num_return_sequence while len(ppo_rl_elements) < num_rollouts: stats = {} @@ -291,11 +298,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ) stats["time/rollout_generate"] = time() - rollout_generate_time - num_return_sequences = ( - self.generate_experience_kwargs["num_return_sequences"] - if self.generate_experience_kwargs.get("num_return_sequences") is not None - else 1 - ) prompt_tensors = batch.input_ids.repeat_interleave(num_return_sequences, dim=0) device = samples.device @@ -588,6 +590,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): + """Computes the indices of the top_k values among `input_tensor` on chunks of size `window_size` + """ # Sum the scores along dim 1 input_tensor = input_tensor.sum(1).unsqueeze(1) # Use unfold to create the sliding windows From 87837b632e8e4c23345092ab819c72b35ef97cb7 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 21 Aug 2023 11:25:10 +0000 Subject: [PATCH 34/36] Fix: style --- trlx/trainer/accelerate_ppo_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 7bcd2871c..81796332b 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -590,8 +590,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq @staticmethod def get_topk_indices(input_tensor, window_size: int, k: int, device): - """Computes the indices of the top_k values among `input_tensor` on chunks of size `window_size` - """ + """Computes the indices of the top_k values among `input_tensor` on chunks of 
size `window_size`""" # Sum the scores along dim 1 input_tensor = input_tensor.sum(1).unsqueeze(1) # Use unfold to create the sliding windows From ed93be857c48b121d8db898fb86c1cf2671e3628 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 21 Aug 2023 14:00:46 +0000 Subject: [PATCH 35/36] Set 'num_return_sequences' to 1 by default --- trlx/trainer/accelerate_ppo_trainer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 81796332b..859815eca 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -94,6 +94,9 @@ def __init__(self, config: TRLConfig, **kwargs): ) self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs} + if self.generate_kwargs.get("num_return_sequences") is None: + self.generate_kwargs["num_return_sequences"] = 1 + if config.method.gen_experience_kwargs is not None: self.generate_experience_kwargs = {**generate_kwargs, **config.method.gen_experience_kwargs} else: @@ -272,11 +275,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ppo_rl_elements = [] accumulated_stats = [] - num_return_sequences = ( - self.generate_experience_kwargs["num_return_sequences"] - if self.generate_experience_kwargs.get("num_return_sequences") is not None - else 1 - ) + num_return_sequences = self.generate_experience_kwargs["num_return_sequences"] # Require chunk_size * num_topk_samples divides num_rollouts assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_topk_samples) == 0 From 24925c83279c5517411e2695d11e4f5c1a02a3c4 Mon Sep 17 00:00:00 2001 From: dahoas Date: Mon, 21 Aug 2023 14:21:28 +0000 Subject: [PATCH 36/36] Fix: typo --- trlx/trainer/accelerate_ppo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 859815eca..23d18c5c9 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -279,7 +279,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq # Require chunk_size * num_topk_samples divides num_rollouts assert num_rollouts % (self.config.method.chunk_size * self.config.method.num_topk_samples) == 0 - assert self.config.method.num_topk_samples <= num_return_sequence + assert self.config.method.num_topk_samples <= num_return_sequences while len(ppo_rl_elements) < num_rollouts: stats = {}