diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a7b2cf8033..21487b0d8e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,7 @@ There are several ways you can contribute to TRL: * Fix outstanding issues with the existing code. * Submit issues related to bugs or desired new features. * Implement trainers for new post-training algorithms. -* Contribute to the examples or to the documentation. +* Contribute to the examples or the documentation. If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of @@ -74,19 +74,19 @@ If there is a new feature you'd like to see in TRL, please open an issue and des Whatever it is, we'd love to hear about it! 2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you. -3. Provide a *code snippet* that demonstrates the features usage. +3. Provide a *code snippet* that demonstrates the feature's usage. 4. If the feature is related to a paper, please include a link. If your issue is well written we're already 80% of the way there by the time you create it. ## Do you want to implement a new trainer? -New post-training methods are published on a frequent basis and those which satisfy the following criteria are good candidates to be integrated in TRL: +New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL: -* **Simplicity:** does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods. -* **Efficiency:** does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilises a similar objective as DPO, but requires half the GPU VRAM. +* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods. +* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM. -Methods which only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL. +Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL. If you want to implement a trainer for a new post-training method, first open an issue and provide the following information: @@ -102,7 +102,7 @@ Based on the community and maintainer feedback, the next step will be to impleme ## Do you want to add documentation? -We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links and any missing, unclear or inaccurate content.. We'll be happy to make the changes or help you make a contribution if you're interested! +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested! ## Submitting a pull request (PR) @@ -133,7 +133,7 @@ Follow these steps to start contributing: 3. Create a new branch to hold your development changes, and do this for every new PR you work on. - Start by synchronizing your `main` branch with the `upstream/main` branch (ore details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)): + Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)): ```bash $ git checkout main @@ -204,7 +204,7 @@ Follow these steps to start contributing: Please write [good commit messages](https://chris.beams.io/posts/git-commit/). It is a good idea to sync your copy of the code with the original - repository regularly. This way you can quickly account for changes: + Repository regularly. This way you can quickly account for changes: ```bash $ git fetch upstream @@ -221,10 +221,7 @@ Follow these steps to start contributing: webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. -7. It's ok if maintainers ask you for changes. It happens to core contributors - too! So everyone can see the changes in the Pull request, work in your local - branch and push the changes to your fork. They will automatically appear in - the pull request. +7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request. ### Checklist @@ -245,14 +242,14 @@ Follow these steps to start contributing: An extensive test suite is included to test the library behavior and several examples. Library tests can be found in the [tests folder](https://github.com/huggingface/trl/tree/main/tests). -We use `pytest` in order to run the tests. From the root of the -repository, here's how to run tests with `pytest` for the library: +We use `pytest` to run the tests. From the root of the +repository here's how to run tests with `pytest` for the library: ```bash $ python -m pytest -sv ./tests ``` -In fact, that's how `make test` is implemented (sans the `pip install` line)! +That's how `make test` is implemented (sans the `pip install` line)! -You can specify a smaller set of tests in order to test only the feature +You can specify a smaller set of tests to test only the feature you're working on. diff --git a/README.md b/README.md index e61b79c152..8a408cd07f 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ ## What is it? -TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). +TRL is a library that post-trains LLMs and diffusion models using methods such as Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). The library is built on top of [🤗 Transformers](https://github.com/huggingface/transformers) and is compatible with any model architecture available there. @@ -31,8 +31,8 @@ The library is built on top of [🤗 Transformers](https://github.com/huggingfac ## Highlights - **`Efficient and scalable`**: - - [🤗 Accelerate](https://github.com/huggingface/accelerate) is the backbone of TRL that model training to scale from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed. - - [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA. + - [🤗 Accelerate](https://github.com/huggingface/accelerate) is the backbone of TRL that models training to scale from a single GPU to a large-scale multi-node cluster with methods such as DDP and DeepSpeed. + - [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantization and methods such as LoRA or QLoRA. - [Unsloth](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels. - **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system. - **`Trainers`**: The trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/ppov2_trainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer). @@ -51,7 +51,7 @@ pip install trl ### From source -If you want to use the latest features before an official release you can install from source: +If you want to use the latest features before an official release, you can install TRL from source: ```bash pip install git+https://github.com/huggingface/trl.git @@ -67,7 +67,7 @@ git clone https://github.com/huggingface/trl.git ## Command Line Interface (CLI) -You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI: +You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI: **SFT:** @@ -178,7 +178,7 @@ trainer.train() ### `DPOTrainer` -`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example on how to use the `DPOTrainer`: +`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`: ```python from datasets import load_dataset @@ -195,7 +195,7 @@ trainer.train() ## Development -If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install: +If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install: ```bash git clone https://github.com/huggingface/trl.git @@ -214,4 +214,4 @@ make dev journal = {GitHub repository}, howpublished = {\url{https://github.com/huggingface/trl}} } -``` \ No newline at end of file +``` diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx index d63231374d..92fc11a3aa 100644 --- a/docs/source/clis.mdx +++ b/docs/source/clis.mdx @@ -96,24 +96,26 @@ python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org The chat CLI lets you quickly load the model and talk to it. Simply run the following: -```bash -trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat -``` +
$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
+<quentin_gallouedec>:
+What is the best programming language?
-> [!TIP]
-> To use the chat CLI with the developer installation, you must run `make dev`
->
+<Qwen/Qwen1.5-0.5B-Chat>:
+There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
+languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
+and scalability. Ultimately, it depends on personal preference, needs, and goals.
+
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
-- **clear**: clears the current conversation and start a new one
-- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
-- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
-- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
-- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
-- **exit**: closes the interface
+- `clear`: clears the current conversation and start a new one
+- `example {NAME}`: load example named `{NAME}` from the config and use it as the user input
+- `set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
+- `reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
+- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
+- `exit`: closes the interface
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.
diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py
index 99139f209c..d29200055c 100644
--- a/examples/scripts/chat.py
+++ b/examples/scripts/chat.py
@@ -273,7 +273,7 @@ def chat_cli():
user = args.user
model, tokenizer = load_model_and_tokenizer(args)
- generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+ generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
diff --git a/scripts/log_example_reports.py b/scripts/log_example_reports.py
index b49a608826..fa31fbf589 100644
--- a/scripts/log_example_reports.py
+++ b/scripts/log_example_reports.py
@@ -55,7 +55,7 @@ def main(text_file_name, slack_channel_name=None):
"type": "section",
"text": {
"type": "plain_text",
- "text": "🔴 Something is wrong with the workflow please check ASAP!"
+ "text": "❌ Something is wrong with the workflow please check ASAP!"
"Something went wrong there is no text file being produced. Please check ASAP.",
"emoji": True,
},
@@ -82,7 +82,7 @@ def main(text_file_name, slack_channel_name=None):
for test_name, failed in final_results.items():
failed_table = tabulate(
- [[test_name, "🟢" if not failed else "🔴"]],
+ [[test_name, "✅" if not failed else "❌"]],
headers=["Test Name", "Status"],
showindex="always",
tablefmt="grid",
diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py
index e7d25c3d1c..6538c56b74 100644
--- a/tests/slow/test_dpo_slow.py
+++ b/tests/slow/test_dpo_slow.py
@@ -85,9 +85,9 @@ def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits):
model=model,
ref_model=None,
args=training_args,
+ train_dataset=self.dataset["train"],
+ eval_dataset=self.dataset["test"],
processing_class=tokenizer,
- train_dataset=self.dataset,
- eval_dataset=self.dataset,
)
# train the model
@@ -142,9 +142,9 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_
model=model,
ref_model=None,
args=training_args,
+ train_dataset=self.dataset["train"],
+ eval_dataset=self.dataset["test"],
processing_class=tokenizer,
- train_dataset=self.dataset,
- eval_dataset=self.dataset,
peft_config=self.peft_config,
)
@@ -206,9 +206,9 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra
model=model,
ref_model=None,
args=training_args,
+ train_dataset=self.dataset["train"],
+ eval_dataset=self.dataset["test"],
processing_class=tokenizer,
- train_dataset=self.dataset,
- eval_dataset=self.dataset,
peft_config=self.peft_config,
)
diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py
index 03a7ea7709..bb5bb8f2c9 100644
--- a/tests/test_rloo_trainer.py
+++ b/tests/test_rloo_trainer.py
@@ -13,8 +13,14 @@
# limitations under the License.
import platform
import subprocess
+import tempfile
+import unittest
import torch
+from datasets import Dataset
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+
+from trl import RLOOConfig, RLOOTrainer
def test():
@@ -26,6 +32,8 @@ def test():
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
+ --sft_model_path EleutherAI/pythia-14m \
+ --reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos
@@ -71,3 +79,42 @@ def test_rloo_reward():
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)
+
+
+class RLOOTrainerTester(unittest.TestCase):
+ def setUp(self):
+ self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+ self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
+
+ self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
+ self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
+ self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
+
+ self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
+ self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+ def test_rloo_checkpoint(self):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ training_args = RLOOConfig(
+ output_dir=tmp_dir,
+ per_device_train_batch_size=2,
+ total_episodes=1,
+ report_to="none",
+ )
+
+ dummy_text = {"content": "Hello World!", "role": "user"}
+ dummy_data = self.tokenizer.apply_chat_template(dummy_text)
+ dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})
+
+ trainer = RLOOTrainer(
+ config=training_args,
+ policy=self.policy_model,
+ reward_model=self.reward_model,
+ ref_policy=self.policy_ref_model,
+ processing_class=self.tokenizer,
+ train_dataset=dummy_dataset,
+ eval_dataset=dummy_dataset,
+ )
+
+ trainer._save_checkpoint(trainer.model, trial=None)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d23e18c841..e79edc755f 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -158,7 +158,7 @@ def test_val_none(self):
model_name="my_model",
hub_model_id="username/my_hub_model",
dataset_name=None,
- tags=None,
+ tags=[],
wandb_url=None,
trainer_name="My Trainer",
trainer_citation=None,
diff --git a/trl/commands/cli.py b/trl/commands/cli.py
index 3a9f8f83a3..dcc13aaff6 100644
--- a/trl/commands/cli.py
+++ b/trl/commands/cli.py
@@ -96,6 +96,7 @@ def train(command_name):
encoding="utf-8",
cwd=os.getcwd(),
env=os.environ.copy(),
+ capture_output=True,
)
except (CalledProcessError, ChildProcessError) as exc:
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py
index 19342597da..84776a026b 100644
--- a/trl/trainer/alignprop_trainer.py
+++ b/trl/trainer/alignprop_trainer.py
@@ -415,6 +415,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{prabhudesai2024aligning,
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py
index 3054b971d1..f101512cb5 100644
--- a/trl/trainer/bco_trainer.py
+++ b/trl/trainer/bco_trainer.py
@@ -1483,6 +1483,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py
index ebd70bf93f..ec62cbce13 100644
--- a/trl/trainer/cpo_trainer.py
+++ b/trl/trainer/cpo_trainer.py
@@ -1018,6 +1018,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{xu2024contrastive,
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
index df6cd94af1..412a461a30 100644
--- a/trl/trainer/ddpo_trainer.py
+++ b/trl/trainer/ddpo_trainer.py
@@ -617,6 +617,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{black2024training,
title = {{Training Diffusion Models with Reinforcement Learning}},
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index f92a5cc98e..0a313420dd 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1727,6 +1727,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index d9c20308cf..20174dd0cc 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -343,6 +343,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{agarwal2024on-policy,
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py
index 1cddd37346..6e81a3586f 100644
--- a/trl/trainer/iterative_sft_trainer.py
+++ b/trl/trainer/iterative_sft_trainer.py
@@ -423,6 +423,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 0e2d60b8aa..b7965b812c 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -1459,6 +1459,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{ethayarajh2024kto,
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py
index 36f48a48da..db0c3046b3 100644
--- a/trl/trainer/nash_md_trainer.py
+++ b/trl/trainer/nash_md_trainer.py
@@ -424,6 +424,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{munos2024nash,
title = {Nash Learning from Human Feedback},
diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 95c1f826da..ffc407b57d 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -580,6 +580,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{guo2024direct,
title = {{Direct Language Model Alignment from Online AI Feedback}},
diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 8a41758d02..7837949598 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -1026,6 +1026,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{hong2024orpo,
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py
index 2114cf60de..e35783c342 100644
--- a/trl/trainer/ppov2_trainer.py
+++ b/trl/trainer/ppov2_trainer.py
@@ -682,6 +682,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 64f873ef34..787c6cbd54 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -391,6 +391,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 2532fa9167..18066976ca 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -43,7 +43,7 @@
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
-from transformers.trainer_callback import CallbackHandler, PrinterCallback
+from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@@ -158,10 +158,6 @@ def __init__(
#########
### trainer specifics
#########
- self.state = OnlineTrainerState(
- is_local_process_zero=self.is_local_process_zero(),
- is_world_process_zero=self.is_world_process_zero(),
- )
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
@@ -169,6 +165,14 @@ def __init__(
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
+ self.state = OnlineTrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
@@ -572,6 +576,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@inproceedings{ahmadian2024back,
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 6875f95a7b..781c5f0331 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -15,7 +15,6 @@
import inspect
import os
import warnings
-from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union
import datasets
@@ -52,7 +51,6 @@
DataCollatorForCompletionOnlyLM,
generate_model_card,
peft_module_casting_to_bf16,
- trl_sanitze_kwargs_for_tagging,
)
@@ -435,21 +433,6 @@ def make_inputs_require_grad(module, input, output):
elif self.args.max_steps == -1 and args.packing:
self.train_dataset.infinite = False
- @wraps(Trainer.push_to_hub)
- def push_to_hub(
- self,
- commit_message: Optional[str] = "End of training",
- blocking: bool = True,
- **kwargs,
- ) -> str:
- """
- Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
- model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
- Unlike the parent class, we don't use the `token` argument to mitigate security risks.
- """
- kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
- return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
-
def _prepare_dataset(
self,
dataset,
@@ -639,6 +622,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 2f44d5208f..c47e078014 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1389,7 +1389,7 @@ def generate_model_card(
model_name: str,
hub_model_id: str,
dataset_name: Optional[str],
- tags: Union[str, List[str], None],
+ tags: List[str],
wandb_url: Optional[str],
trainer_name: str,
trainer_citation: Optional[str] = None,
@@ -1408,7 +1408,7 @@ def generate_model_card(
Hub model ID as `username/model_id`.
dataset_name (`str` or `None`):
Dataset name.
- tags (`str`, `List[str]`, or `None`):
+ tags (`List[str]`):
Tags.
wandb_url (`str` or `None`):
Weights & Biases run URL.
@@ -1425,10 +1425,6 @@ def generate_model_card(
`ModelCard`:
A ModelCard object.
"""
- if tags is None:
- tags = []
- elif isinstance(tags, str):
- tags = [tags]
card_data = ModelCardData(
base_model=base_model,
datasets=dataset_name,
diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py
index 581acbf3db..0255e6206f 100644
--- a/trl/trainer/xpo_trainer.py
+++ b/trl/trainer/xpo_trainer.py
@@ -481,6 +481,13 @@ def create_model_card(
else:
base_model = None
+ tags = tags or []
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if hasattr(self.model.config, "unsloth_version"):
+ tags.append("unsloth")
+
citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},