Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐾 Process-supervised RM Trainer #2127

Open
wants to merge 102 commits into
base: main
Choose a base branch
from

Conversation

gaetanlop
Copy link
Contributor

@gaetanlop gaetanlop commented Sep 26, 2024

What does this PR do?

Adding support for process-supervised reward training to TRL as requested in #2110 .

List of papers using PRMs: [1], [2], [3], [4]...

Fixes # (issue)

#2110

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@lewtun @kashif

@gaetanlop gaetanlop marked this pull request as draft September 26, 2024 03:15
@lewtun
Copy link
Member

lewtun commented Sep 26, 2024

This is awesome @gaetanlop ! Would you like some early feedback on the PR or would you prefer I wait a bit until it's more polished?

@gaetanlop
Copy link
Contributor Author

Hey @lewtun, thank you for the message. Currently, the only files that are more or less ready are prm_trainer.py and prm_config.py. The rest are just placeholders that I haven’t had the opportunity to work on yet.

Implementing a PRMs seems to be pretty straighforward, it seems to be a token classification task where only prediction for the last token of each step gets assigned a label and other tokens are ignored during loss calculation.

If the dataset isn’t pre-tokenized, I assume it should contain the following columns:

  • prompt: Either a string or past messages
  • steps: A list of strings
  • labels: A list of integers corresponding to the label associated to each step

Are you aware of an HF dataset to train PRMs for the example file? Also, how can I add a new subset to the trl-internal-testing/zen dataset to support stepwise reward models for the unit test of the prm_trainer?

Thanks again for your time!

@gaetanlop gaetanlop marked this pull request as ready for review September 28, 2024 18:34
@gaetanlop
Copy link
Contributor Author

gaetanlop commented Sep 28, 2024

PR ready for review. I have changed the naming conventions that I used before prm to the suggested naming in #2110 stepwise.

Tests: I created a dummy_dataset but we should add a subset to trl-internal-testing/zen as done in other scripts.
Example: The example is currently using a placeholder for the dataset name as to the best of my knowledge trl didn't release a dataset for stepwise reasoning on HF. We should add this too.

Copy link
Member

@lewtun lewtun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the very clean PR @gaetanlop - this looks great! I've left some minor suggestions regarding the structure, but aside from that and having a smallish dataset in the right format we can sanity check that the accuracy goes up, loss goes down etc I think this is quite close to being ready

docs/source/_toctree.yml Outdated Show resolved Hide resolved
docs/source/stepwise_reward_trainer.mdx Show resolved Hide resolved
docs/source/dataset_formats.mdx Outdated Show resolved Hide resolved
examples/scripts/stepwise_reward_modeling.py Outdated Show resolved Hide resolved
trl/trainer/stepwise_reward_config.py Outdated Show resolved Hide resolved
trl/trainer/stepwise_reward_config.py Outdated Show resolved Hide resolved
trl/trainer/stepwise_reward_trainer.py Outdated Show resolved Hide resolved
trl/trainer/stepwise_reward_trainer.py Outdated Show resolved Hide resolved
@gaetanlop gaetanlop changed the title [DRAFT] Process-supervised RM Trainer Process-supervised RM Trainer Oct 1, 2024
@gaetanlop
Copy link
Contributor Author

gaetanlop commented Oct 1, 2024

Thanks for looking at this @lewtun. Seems like trl-internal-testing/zen is the dataset you are using for testing. I have done a PR to trl-lib/zen, should I also PR trl-internal-testing/zen to add 19 samples of PRM800K for testing or are you handling it on your side (it looks like they are both the same dataset)?

@qgallouedec
Copy link
Member

qgallouedec commented Nov 26, 2024

@qgallouedec the term used for the labels column in standard_stepwise_supervision and prm800K differs, prm800K uses labels while standard_stepwise_supervision uses label.

Thanks for spotting it. Generation script fixed in c7cf42f, I just pushed the fixed dataset zen/47aee34.

@qgallouedec
Copy link
Member

@gaetanlop about the tokenize function, it would be better to

  • have it as a static method
  • using it with a single example (not a batch).

In fact take example from the implementation of DPOTrainer. I can fix this quickly. I offer to do a PR on your branch, wdyt?

@qgallouedec qgallouedec changed the title Process-supervised RM Trainer 🐾 Process-supervised RM Trainer Nov 26, 2024
@gaetanlop
Copy link
Contributor Author

@qgallouedec yes of course, as you prefer, I was following the implementation done in the KTOTrainer. You can open a PR otherwise I will make the changes later today

@qgallouedec
Copy link
Member

qgallouedec commented Nov 26, 2024

Should we add the separator token between the prompt and the first step? If you don't (like the current code) you get something like:

prompt = "This is my prompt."
completions = ["This is my first step.", "This is my second step."]
separator = "\n"

# Processing here

result == "This is my prompt.This is my first step. This is my second step."
#                           ^💀

@qgallouedec
Copy link
Member

qgallouedec commented Nov 26, 2024

gaetanlop#1

I still need to add the collator and then it's ready. No collator is needed in fact.

@qgallouedec
Copy link
Member

First trained model: https://huggingface.co/qgallouedec/Qwen2-0.5B-Reward

Screenshot 2024-11-26 at 21 51 58

@gaetanlop
Copy link
Contributor Author

gaetanlop commented Nov 27, 2024

@qgallouedec Thank you for the refactoring work on the tokenize_row function. I have made some adjustments to ensure proper handling of special tokens. Also, I refined the label creation process and updated the tokenize_row function to support truncation based on both max_length and max_completion_ids. I have added some tests to confirm that the updated tokenize_row function behaves as intended.

I also made some experiments. The model gets 99.8% accuracy after just a few steps... It might just be predicting True all the time, I will need to double check

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants