Skip to content

Commit

Permalink
🤝 Mixture of judges (#2159)
Browse files Browse the repository at this point in the history
* base judge

* adding mixture of judges

* update doc

* update doc

* formatting

* fix small typo in doc

* fix random constraint judge

* replace arxiv by hf papers

Co-authored-by: Quentin Gallouédec <[email protected]>

* formatting

Co-authored-by: Quentin Gallouédec <[email protected]>

* fix naming in __init__

* run pre-commit

* adding gold answers to judges

* cgpo llm judges

* fix init

* output type

* adjust booleans in test

* adapt moj doc

* renaming and removing factuality and safety judges

* fix typo in import

* fix small typo in naming

* formatting

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* update parameter name

* update tests

* update doc

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* Update doc

Co-authored-by: Quentin Gallouédec <[email protected]>

* fix alltruejudge type

* Refactor judge variable names and update test names

* Clarify judgment logic

* Fix invalid binary judgment check in AllTrueJudge class

* Fix invalid binary judgment check in AllTrueJudge class

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2024
1 parent cbf9abc commit b5eabbe
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 7 deletions.
12 changes: 12 additions & 0 deletions docs/source/judges.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,18 @@ judge.judge(
) # Outputs: [0, 1]
```

## AllTrueJudge

[[autodoc]] AllTrueJudge

## BaseJudge

[[autodoc]] BaseJudge

## BaseBinaryJudge

[[autodoc]] BaseBinaryJudge

## BaseRankJudge

[[autodoc]] BaseRankJudge
Expand All @@ -64,6 +72,10 @@ judge.judge(

[[autodoc]] BasePairwiseJudge

## RandomBinaryJudge

[[autodoc]] RandomBinaryJudge

## RandomRankJudge

[[autodoc]] RandomRankJudge
Expand Down
40 changes: 33 additions & 7 deletions tests/test_judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,53 @@
import time
import unittest

from trl import HfPairwiseJudge, PairRMJudge, RandomPairwiseJudge, RandomRankJudge
from trl import (
AllTrueJudge,
HfPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
)

from .testing_utils import require_llm_blender


class TestJudges(unittest.TestCase):
def _get_prompts_and_completions(self):
def _get_prompts_and_pairwise_completions(self):
prompts = ["The capital of France is", "The biggest planet in the solar system is"]
completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
return prompts, completions

def _get_prompts_and_single_completions(self):
prompts = ["What's the capital of France?", "What's the color of the sky?"]
completions = ["Marseille", "blue"]
return prompts, completions

def test_all_true_judge(self):
    """AllTrueJudge should emit one judgment per prompt, each in {0, 1, -1}."""
    judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
    prompts, completions = self._get_prompts_and_single_completions()
    judgements = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(judgements), 2)
    for judgement in judgements:
        self.assertIn(judgement, {0, 1, -1})

def test_random_binary_judge(self):
    """RandomBinaryJudge should emit one judgment per prompt, each in {0, 1, -1}."""
    judge = RandomBinaryJudge()
    prompts, completions = self._get_prompts_and_single_completions()
    judgements = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(judgements), 2)
    for judgement in judgements:
        self.assertIn(judgement, {0, 1, -1})

def test_random_pairwise_judge(self):
    """RandomPairwiseJudge should return one integer rank per prompt."""
    judge = RandomPairwiseJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    for rank in ranks:
        self.assertIsInstance(rank, int)

def test_random_rank_judge(self):
    """RandomRankJudge should return one list of ranks per prompt."""
    judge = RandomRankJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    for rank in ranks:
        self.assertIsInstance(rank, list)
Expand All @@ -44,7 +70,7 @@ def test_random_rank_judge(self):
@unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
    """HfPairwiseJudge should return one integer rank per prompt (requires an HF API key)."""
    judge = HfPairwiseJudge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    for rank in ranks:
        self.assertIsInstance(rank, int)
Expand All @@ -62,7 +88,7 @@ def load_pair_rm_judge(self):
@require_llm_blender
def test_pair_rm_judge(self):
    """PairRMJudge should return one integer rank per prompt."""
    judge = self.load_pair_rm_judge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    ranks = judge.judge(prompts=prompts, completions=completions)
    self.assertEqual(len(ranks), 2)
    for rank in ranks:
        self.assertIsInstance(rank, int)
Expand All @@ -71,7 +97,7 @@ def test_pair_rm_judge(self):
@require_llm_blender
def test_pair_rm_judge_return_scores(self):
    """With return_scores=True, PairRMJudge should return one float score per prompt."""
    judge = self.load_pair_rm_judge()
    prompts, completions = self._get_prompts_and_pairwise_completions()
    probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
    self.assertEqual(len(probs), 2)
    for prob in probs:
        self.assertIsInstance(prob, float)
Expand Down
6 changes: 6 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
"trainer": [
"AlignPropConfig",
"AlignPropTrainer",
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"BCOConfig",
Expand Down Expand Up @@ -79,6 +81,7 @@
"PairRMJudge",
"PPOConfig",
"PPOTrainer",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
"RewardConfig",
Expand Down Expand Up @@ -138,6 +141,8 @@
from .trainer import (
AlignPropConfig,
AlignPropTrainer,
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
Expand Down Expand Up @@ -168,6 +173,7 @@
PairRMJudge,
PPOConfig,
PPOTrainer,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
RewardConfig,
Expand Down
6 changes: 6 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
"gkd_trainer": ["GKDTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"HfPairwiseJudge",
"OpenAIPairwiseJudge",
"PairRMJudge",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
],
Expand Down Expand Up @@ -98,12 +101,15 @@
from .gkd_trainer import GKDTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .judges import (
AllTrueJudge,
BaseBinaryJudge,
BaseJudge,
BasePairwiseJudge,
BaseRankJudge,
HfPairwiseJudge,
OpenAIPairwiseJudge,
PairRMJudge,
RandomBinaryJudge,
RandomPairwiseJudge,
RandomRankJudge,
)
Expand Down
92 changes: 92 additions & 0 deletions trl/trainer/judges.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,54 @@ def judge(self, prompts: List[str], completions: List[List[str]], shuffle_order:
raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BaseBinaryJudge(BaseJudge):
    """
    Base class for binary judges.
    """

    @abstractmethod
    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[int]:
        """
        Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.

        This base class should be used to implement binary evaluations as done in section 4.1.4 of the
        [CGPO paper](https://huggingface.co/papers/2409.20370).
        It is relevant for assessing whether or not a prompt completion pair satisfies a specific constraint.

        Args:
            prompts (`List[str]`): List of prompts.
            completions (`List[str]`): List of completions.
            gold_completions (`List[str]`, `optional`): List of gold completions if it exists.
            shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            List[int]: A list of binary labels:
                - 1 indicates that the completion satisfies the evaluated constraint.
                - 0 indicates that the completion does not satisfy the evaluated constraint.

        Note:
            If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed.
            For instance, this could occur if the underlying language model or rule-based constraint returned an invalid answer.
            In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class RandomBinaryJudge(BaseBinaryJudge):
    """
    Random binary judge, for testing purposes.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        # One uniformly random label per prompt; -1 simulates a failed judgment.
        labels = [0, 1, -1]
        return [random.choice(labels) for _ in prompts]


class RandomRankJudge(BaseRankJudge):
"""
Random rank, for testing purposes.
Expand Down Expand Up @@ -392,3 +440,47 @@ def get_rank(prompt, candidates):

# Return the ranks
return ranks


class AllTrueJudge(BaseBinaryJudge):
    """
    Unify the decision of multiple [`BaseBinaryJudge`] instances.

    Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`.
    If any judge returns `-1`, indicating a failure in its process, this judge will also return `-1`.

    Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370).

    Args:
        judges (`List[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`] instances whose decisions will be unified.
    """

    def __init__(self, judges: List[BaseBinaryJudge]):
        self.judges = judges

    def judge(
        self,
        prompts: List[str],
        completions: List[str],
        gold_completions: Optional[List[str]] = None,
        shuffle_order: bool = True,
    ) -> List[int]:
        # Gather one judgment list per inner judge, then walk example-by-example.
        per_judge_judgments = [
            inner_judge.judge(prompts, completions, gold_completions, shuffle_order)
            for inner_judge in self.judges
        ]
        unified_judgments = []
        for binary_judgments in zip(*per_judge_judgments):
            # Reject anything outside the documented label set {0, 1, -1}.
            if not all(judgment in {0, 1, -1} for judgment in binary_judgments):
                raise ValueError(
                    f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}."
                )

            if -1 in binary_judgments:
                # Any inner failure makes the unified judgment a failure.
                unified_judgments.append(-1)
            elif all(judgment == 1 for judgment in binary_judgments):
                unified_judgments.append(1)
            else:
                unified_judgments.append(0)
        return unified_judgments

0 comments on commit b5eabbe

Please sign in to comment.