Skip to content

Commit

Permalink
Merge branch 'main' into wpo
Browse files — browse the repository at this point in the history
  • Loading branch information
kashif authored Oct 1, 2024
2 parents 4d0162b + de38765 commit 84269e0
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 11 deletions.
8 changes: 3 additions & 5 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer, GenerationConfig
from transformers import AutoTokenizer

from trl import (
GKDConfig,
Expand Down Expand Up @@ -93,6 +93,7 @@

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
padding_side="left",
)
Expand Down Expand Up @@ -124,10 +125,7 @@
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
)
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
trainer.train()

Expand Down
20 changes: 20 additions & 0 deletions tests/test_gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,23 @@ def test_gkd_trainer(self):
self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

def test_generation_config_init(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GKDConfig(output_dir=tmp_dir)
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

trainer = GKDTrainer(
model=self.model_id,
teacher_model=self.model_id,
args=training_args,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
tokenizer=self.tokenizer,
)

self.assertEqual(trainer.generation_config.pad_token_id, self.tokenizer.eos_token_id)
self.assertEqual(trainer.generation_config.eos_token_id, self.model.generation_config.eos_token_id)
self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens)
self.assertEqual(trainer.generation_config.temperature, training_args.temperature)
self.assertEqual(trainer.generation_config.top_k, 0)
10 changes: 10 additions & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,17 @@ def __init__(
do_sample=True,
top_k=0,
use_cache=False if args.gradient_checkpointing else True,
pad_token_id=self.tokenizer.pad_token_id,
)
# Set custom EOS tokens if they are specified by the model's generation
# config. This is important for models with the Llama 3 chat template,
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
# turns or messages.
if (
hasattr(self.model.generation_config, "eos_token_id")
and self.model.generation_config.eos_token_id is not None
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

@staticmethod
def generalized_jsd_loss(
Expand Down
12 changes: 6 additions & 6 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,9 +1380,9 @@ def generate_model_card(
tags: Union[str, List[str], None],
wandb_url: Optional[str],
trainer_name: str,
trainer_citation: Optional[str],
paper_title: Optional[str],
paper_id: Optional[str],
trainer_citation: Optional[str] = None,
paper_title: Optional[str] = None,
paper_id: Optional[str] = None,
) -> ModelCard:
"""
Generate a `ModelCard` from a template.
Expand All @@ -1402,11 +1402,11 @@ def generate_model_card(
Weights & Biases run URL.
trainer_name (`str`):
Trainer name.
trainer_citation (`str` or `None`):
trainer_citation (`str` or `None`, defaults to `None`):
Trainer citation as a BibTeX entry.
paper_title (`str` or `None`):
paper_title (`str` or `None`, defaults to `None`):
Paper title.
paper_id (`str` or `None`):
paper_id (`str` or `None`, defaults to `None`):
ArXiv paper ID as `YYMM.NNNNN`.
Returns:
Expand Down

0 comments on commit 84269e0

Please sign in to comment.