🎲 Move random judges in testing utilities (#2365)
* Update judges and testing utilities

* Update judges in test files

* Update judges in test files
qgallouedec authored Nov 18, 2024
1 parent b5eabbe commit b80c1a6
Showing 10 changed files with 56 additions and 109 deletions.
46 changes: 19 additions & 27 deletions docs/source/judges.mdx
@@ -11,7 +11,7 @@ TRL provides judges to easily compare two completions.
Make sure to have installed the required dependencies by running:

```bash
-pip install trl[llm_judge]
+pip install trl[judges]
```

## Using the provided judges
@@ -52,46 +52,38 @@ judge.judge(
) # Outputs: [0, 1]
```

-## AllTrueJudge
+## Provided judges

-[[autodoc]] AllTrueJudge

-## BaseJudge

-[[autodoc]] BaseJudge

-## BaseBinaryJudge
+### PairRMJudge

-[[autodoc]] BaseBinaryJudge
+[[autodoc]] PairRMJudge

-## BaseRankJudge
+### HfPairwiseJudge

-[[autodoc]] BaseRankJudge

-## BasePairwiseJudge
+[[autodoc]] HfPairwiseJudge

-[[autodoc]] BasePairwiseJudge
+### OpenAIPairwiseJudge

-## RandomBinaryJudge
+[[autodoc]] OpenAIPairwiseJudge

-[[autodoc]] RandomBinaryJudge
+### AllTrueJudge

-## RandomRankJudge
+[[autodoc]] AllTrueJudge

-[[autodoc]] RandomRankJudge
+## Base classes

-## RandomPairwiseJudge
+### BaseJudge

-[[autodoc]] RandomPairwiseJudge
+[[autodoc]] BaseJudge

-## PairRMJudge
+### BaseBinaryJudge

-[[autodoc]] PairRMJudge
+[[autodoc]] BaseBinaryJudge

-## HfPairwiseJudge
+### BaseRankJudge

-[[autodoc]] HfPairwiseJudge
+[[autodoc]] BaseRankJudge

-## OpenAIPairwiseJudge
+### BasePairwiseJudge

-[[autodoc]] OpenAIPairwiseJudge
+[[autodoc]] BasePairwiseJudge
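
For reference, the usage pattern the reorganized page documents looks like this minimal sketch (assuming `trl[judges]` and its `llm-blender` dependency are installed and the PairRM model download succeeds; the prompts and printed output are purely illustrative):

```python
from trl import PairRMJudge

# PairRMJudge compares two completions per prompt and returns, for each prompt,
# the index of the completion it prefers.
judge = PairRMJudge()
results = judge.judge(
    prompts=["What is the capital of France?", "What is the smallest prime number?"],
    completions=[["Paris", "Lyon"], ["2", "1"]],
)
print(results)  # e.g. [0, 0]
```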
2 changes: 1 addition & 1 deletion setup.py
@@ -87,7 +87,7 @@
"diffusers": ["diffusers>=0.18.0"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"llm_judge": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
34 changes: 3 additions & 31 deletions tests/test_judges.py
@@ -15,16 +15,9 @@
import time
import unittest

-from trl import (
-AllTrueJudge,
-HfPairwiseJudge,
-PairRMJudge,
-RandomBinaryJudge,
-RandomPairwiseJudge,
-RandomRankJudge,
-)
+from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge

-from .testing_utils import require_llm_blender
+from .testing_utils import RandomBinaryJudge, require_llm_blender


class TestJudges(unittest.TestCase):
@@ -45,28 +38,6 @@ def test_all_true_judge(self):
self.assertEqual(len(judgements), 2)
self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements))

-def test_random_binary_judge(self):
-judge = RandomBinaryJudge()
-prompts, completions = self._get_prompts_and_single_completions()
-judgements = judge.judge(prompts=prompts, completions=completions)
-self.assertEqual(len(judgements), 2)
-self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements))
-
-def test_random_pairwise_judge(self):
-judge = RandomPairwiseJudge()
-prompts, completions = self._get_prompts_and_pairwise_completions()
-ranks = judge.judge(prompts=prompts, completions=completions)
-self.assertEqual(len(ranks), 2)
-self.assertTrue(all(isinstance(rank, int) for rank in ranks))
-
-def test_random_rank_judge(self):
-judge = RandomRankJudge()
-prompts, completions = self._get_prompts_and_pairwise_completions()
-ranks = judge.judge(prompts=prompts, completions=completions)
-self.assertEqual(len(ranks), 2)
-self.assertTrue(all(isinstance(rank, list) for rank in ranks))
-self.assertTrue(all(all(isinstance(rank, int) for rank in ranks) for ranks in ranks))

@unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.")
def test_hugging_face_judge(self):
judge = HfPairwiseJudge()
@@ -84,6 +55,7 @@ def load_pair_rm_judge(self):
return PairRMJudge()
except ValueError:
time.sleep(5)
+raise ValueError("Failed to load PairRMJudge")

@require_llm_blender
def test_pair_rm_judge(self):
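
The added `raise` above closes out a retry around `PairRMJudge()` construction. A sketch of such a helper, assuming a bounded retry loop (only the `return`, `except`, `sleep`, and `raise` lines come from the diff; the loop bound, signature, and comments are assumptions):

```python
import time

from trl import PairRMJudge


def load_pair_rm_judge(retries: int = 5, delay: float = 5.0):
    # Constructing PairRMJudge downloads the PairRM model, which can fail
    # transiently (e.g. on CI), so retry a few times before giving up.
    for _ in range(retries):
        try:
            return PairRMJudge()
        except ValueError:
            time.sleep(delay)
    raise ValueError("Failed to load PairRMJudge")
```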
6 changes: 3 additions & 3 deletions tests/test_nash_md_trainer.py
@@ -20,9 +20,9 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

-from trl import NashMDConfig, NashMDTrainer, PairRMJudge
+from trl import NashMDConfig, NashMDTrainer

-from .testing_utils import require_llm_blender
+from .testing_utils import RandomPairwiseJudge, require_llm_blender


if is_peft_available():
@@ -174,7 +174,7 @@ def test_nash_md_trainer_judge_training(self, config_name):
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
-judge = PairRMJudge()
+judge = RandomPairwiseJudge()

trainer = NashMDTrainer(
model=self.model,
4 changes: 3 additions & 1 deletion tests/test_online_dpo_trainer.py
@@ -20,9 +20,11 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

-from trl import OnlineDPOConfig, OnlineDPOTrainer, RandomPairwiseJudge, is_llm_blender_available
+from trl import OnlineDPOConfig, OnlineDPOTrainer, is_llm_blender_available
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

+from .testing_utils import RandomPairwiseJudge


if is_peft_available():
from peft import LoraConfig, get_peft_model
4 changes: 3 additions & 1 deletion tests/test_xpo_trainer.py
@@ -20,7 +20,9 @@
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

-from trl import RandomPairwiseJudge, XPOConfig, XPOTrainer, is_llm_blender_available
+from trl import XPOConfig, XPOTrainer, is_llm_blender_available

+from .testing_utils import RandomPairwiseJudge


if is_peft_available():
24 changes: 23 additions & 1 deletion tests/testing_utils.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
import unittest

from transformers import is_sklearn_available, is_wandb_available

-from trl import is_diffusers_available, is_llm_blender_available
+from trl import BaseBinaryJudge, BasePairwiseJudge, is_diffusers_available, is_llm_blender_available


def require_diffusers(test_case):
@@ -44,3 +45,24 @@ def require_llm_blender(test_case):
Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available.
"""
return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case)


+class RandomBinaryJudge(BaseBinaryJudge):
+"""
+Random binary judge, for testing purposes.
+"""
+
+def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
+return [random.choice([0, 1, -1]) for _ in range(len(prompts))]
+
+
+class RandomPairwiseJudge(BasePairwiseJudge):
+"""
+Random pairwise judge, for testing purposes.
+"""
+
+def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
+if not return_scores:
+return [random.randint(0, len(completion) - 1) for completion in completions]
+else:
+return [random.random() for _ in range(len(prompts))]
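
As a quick illustration, the relocated test judges can be exercised like this minimal sketch (it assumes the `tests` package is importable from the repository root; the printed values are random and purely illustrative):

```python
from tests.testing_utils import RandomBinaryJudge, RandomPairwiseJudge

# One completion per prompt; each judgement is 0, 1, or -1.
binary_judge = RandomBinaryJudge()
print(binary_judge.judge(prompts=["Q1", "Q2"], completions=["A1", "A2"]))  # e.g. [1, -1]

# A list of candidate completions per prompt; each result is the index of a
# randomly "preferred" completion, or a random score when return_scores=True.
pairwise_judge = RandomPairwiseJudge()
print(pairwise_judge.judge(prompts=["Q1", "Q2"], completions=[["A", "B"], ["C", "D"]]))  # e.g. [0, 1]
print(pairwise_judge.judge(prompts=["Q1"], completions=[["A", "B"]], return_scores=True))  # e.g. [0.42]
```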
8 changes: 1 addition & 7 deletions trl/__init__.py
@@ -50,8 +50,8 @@
"AlignPropConfig",
"AlignPropTrainer",
"AllTrueJudge",
"BaseJudge",
"BaseBinaryJudge",
"BaseJudge",
"BasePairwiseJudge",
"BaseRankJudge",
"BCOConfig",
@@ -81,9 +81,6 @@
"PairRMJudge",
"PPOConfig",
"PPOTrainer",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
"RewardConfig",
"RewardTrainer",
"RLOOConfig",
@@ -173,9 +170,6 @@
PairRMJudge,
PPOConfig,
PPOTrainer,
-RandomBinaryJudge,
-RandomPairwiseJudge,
-RandomRankJudge,
RewardConfig,
RewardTrainer,
RLOOConfig,
8 changes: 0 additions & 8 deletions trl/trainer/__init__.py
@@ -42,9 +42,6 @@
"HfPairwiseJudge",
"OpenAIPairwiseJudge",
"PairRMJudge",
"RandomBinaryJudge",
"RandomPairwiseJudge",
"RandomRankJudge",
],
"kto_config": ["KTOConfig"],
"kto_trainer": ["KTOTrainer"],
@@ -109,9 +106,6 @@
HfPairwiseJudge,
OpenAIPairwiseJudge,
PairRMJudge,
-RandomBinaryJudge,
-RandomPairwiseJudge,
-RandomRankJudge,
)
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
@@ -124,8 +118,6 @@
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
-from .ppov2_config import PPOv2Config
-from .ppov2_trainer import PPOv2Trainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .rloo_config import RLOOConfig
29 changes: 0 additions & 29 deletions trl/trainer/judges.py
@@ -14,7 +14,6 @@

import concurrent.futures
import logging
-import random
from abc import ABC, abstractmethod
from typing import List, Optional, Union

@@ -183,34 +182,6 @@ def judge(
raise NotImplementedError("Judge subclasses must implement the `judge` method.")


-class RandomBinaryJudge(BaseBinaryJudge):
-"""
-Random binary judge, for testing purposes.
-"""
-
-def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
-return [random.choice([0, 1, -1]) for _ in range(len(prompts))]
-
-
-class RandomRankJudge(BaseRankJudge):
-"""
-Random rank, for testing purposes.
-"""
-
-def judge(self, prompts, completions, shuffle_order=True):
-num_completions = [len(completions[i]) for i in range(len(prompts))]
-return [random.sample(range(n), n) for n in num_completions]
-
-
-class RandomPairwiseJudge(BasePairwiseJudge):
-"""
-Random pairwise judge, for testing purposes.
-"""
-
-def judge(self, prompts, completions, shuffle_order=True):
-return [random.randint(0, len(completion) - 1) for completion in completions]


class PairRMJudge(BasePairwiseJudge):
"""
LLM judge based on the PairRM model from AllenAI.
