diff --git a/tests/test_cli.py b/tests/test_cli.py
index b1b9b8aa67..be5b2ab321 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -14,6 +14,9 @@
 import subprocess
 import sys
 import unittest
+from pathlib import Path
+
+from trl.commands.cli_utils import populate_supported_commands
 
 
 @unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
@@ -43,3 +46,21 @@ def test_dpo_cli():
 def test_env_cli():
     output = subprocess.run("trl env", capture_output=True, text=True, shell=True, check=True)
     assert "- Python version: " in output.stdout
+
+
+def test_populate_supported_commands():
+    commands = populate_supported_commands()
+
+    # Commands known to be backed by a script in examples/scripts
+    assert "sft" in commands, "SFT command not found"
+    assert "dpo" in commands, "DPO command not found"
+
+    # Every command is a bare name: a string without the .py extension
+    for cmd in commands:
+        assert isinstance(cmd, str), f"Command {cmd} is not a string"
+        assert not cmd.endswith(".py"), f"Command {cmd} should not have .py extension"
+
+    # Compare the names themselves (not only their count) with the scripts in examples/scripts
+    scripts_dir = Path(__file__).resolve().parent.parent / "examples" / "scripts"
+    expected = sorted(script.stem for script in scripts_dir.glob("*.py") if script.is_file())
+    assert sorted(commands) == expected, f"Commands {sorted(commands)} don't match the scripts {expected}"
diff --git a/trl/commands/cli.py b/trl/commands/cli.py
index 3a9f8f83a3..4bf9b82853 100644
--- a/trl/commands/cli.py
+++ b/trl/commands/cli.py
@@ -31,11 +31,13 @@
     is_liger_kernel_available,
     is_llmblender_available,
 )
-from .cli_utils import get_git_commit_hash
+from .cli_utils import get_git_commit_hash, populate_supported_commands
 
 
-SUPPORTED_COMMANDS = ["sft", "dpo", "chat", "kto", "env"]
+# Keep the previously hard-coded commands (e.g. "env" appears to be served by print_env() and may have
+# no script to discover -- TODO confirm) and add every script found in examples/scripts, without duplicates.
+SUPPORTED_COMMANDS = sorted({"sft", "dpo", "chat", "kto", "env", *populate_supported_commands()})
 
 
 def print_env():
     accelerate_config = accelerate_config_str = "not found"
diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py
index 44918af961..7751e75b3c 100644
--- a/trl/commands/cli_utils.py
+++ b/trl/commands/cli_utils.py
@@ -17,8 +17,9 @@
 import inspect
 import logging
 import os
 import subprocess
 import sys
 from argparse import Namespace
 from dataclasses import dataclass, field
+from pathlib import Path
 
@@ -305,3 +306,23 @@ def get_git_commit_hash(package_name):
             return None
     except Exception as e:
         return f"Error: {str(e)}"
+
+
+def populate_supported_commands():
+    """
+    Discover the CLI commands backed by a script in `examples/scripts`.
+
+    Every `*.py` file in the `examples/scripts` directory found three levels above this module yields one
+    command, named after the file without its `.py` extension.
+
+    Returns:
+        `list` of `str`: Command names in alphabetical order; empty when the directory is missing or contains
+        no scripts.
+    """
+    # NOTE(review): this location is relative to a source checkout; confirm the scripts are also reachable
+    # here once the package is installed.
+    scripts_dir = Path(__file__).resolve().parents[2] / "examples" / "scripts"
+
+    # `Path.stem` drops only the final suffix, whereas `str.replace(".py", "")` would also strip a ".py"
+    # found in the middle of a file name. Sorting makes the result independent of directory listing order.
+    return sorted(script.stem for script in scripts_dir.glob("*.py") if script.is_file())