From 6caa52281bf4bdf7a360a453a164cd62c2e725b5 Mon Sep 17 00:00:00 2001
From: Wei Fu <36355462+garrett4wade@users.noreply.github.com>
Date: Wed, 3 Jul 2024 10:53:10 +0800
Subject: [PATCH] Add a check in `PPOExperiment` to avoid unintended
 behaviors. (#16)

* refactor interval ops

* remove pybind in cpp code

* minor fix

* minor fix

* fix ppo cuda graph
---
 realhf/api/core/model_api.py         |  3 ++-
 realhf/api/quickstart/model.py       | 13 +------------
 realhf/experiments/common/ppo_exp.py | 10 ++++++++++
 realhf/experiments/common/utils.py   |  4 +---
 4 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/realhf/api/core/model_api.py b/realhf/api/core/model_api.py
index 1b7b3764..14bb7767 100755
--- a/realhf/api/core/model_api.py
+++ b/realhf/api/core/model_api.py
@@ -36,7 +36,8 @@ class GenerationHyperparameters:
     :type temperature: float
     :param num_samples: The number of samples to generate.
     :type num_samples: int
-    :param use_cuda_graph: Whether to use CUDA graph.
+    :param use_cuda_graph: Whether to use CUDA graph to reduce kernel launch overhead
+        during generation. Recommended for pure generation.
     :type use_cuda_graph: bool
     """
 
diff --git a/realhf/api/quickstart/model.py b/realhf/api/quickstart/model.py
index 63099dea..99064cf1 100755
--- a/realhf/api/quickstart/model.py
+++ b/realhf/api/quickstart/model.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger("Quickstart Model Config")
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(unsafe_hash=True)
 class ParallelismConfig:
     """Model 3D parallelism configuration.
 
@@ -46,17 +46,6 @@ def __str__(self):
         )
 
 
-def parallelism_config_equal(
-    parallel1: ParallelismConfig, parallel2: ParallelismConfig
-) -> bool:
-    # NOTE: Implementing __eq__ in dataclass will cause error in hydra and omegaconf
-    return (
-        parallel1.model_parallel_size == parallel2.model_parallel_size
-        and parallel1.pipeline_parallel_size == parallel2.pipeline_parallel_size
-        and parallel1.data_parallel_size == parallel2.data_parallel_size
-    )
-
-
 @dataclasses.dataclass
 class LoRAConfig:
     dim: int = 32
diff --git a/realhf/experiments/common/ppo_exp.py b/realhf/experiments/common/ppo_exp.py
index 8d47622a..69890289 100755
--- a/realhf/experiments/common/ppo_exp.py
+++ b/realhf/experiments/common/ppo_exp.py
@@ -261,6 +261,16 @@ def __post_init__(self):
             value_norm_eps=self.ppo.value_norm_eps,
         )
 
+        if self.ppo.gen.use_cuda_graph and (
+            self.actor_train.parallel != self.actor_gen.parallel
+        ):
+            raise ValueError(
+                "CUDA graph cannot be used with parameter reallocation "
+                "because CUDA graph requires pinned parameter memory. "
+                "Either set use_cuda_graph=False or set identical parallel "
+                "strategies for actor_train and actor_gen."
+            )
+
     @property
     def models(self) -> Dict[str, ModelTrainEvalConfig]:
         # role to config
diff --git a/realhf/experiments/common/utils.py b/realhf/experiments/common/utils.py
index ae3c6cf4..ee962d75 100644
--- a/realhf/experiments/common/utils.py
+++ b/realhf/experiments/common/utils.py
@@ -117,7 +117,6 @@ def make_model_config(cfg: ModelTrainEvalConfig):
 
 
 def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
-    from realhf.api.quickstart.model import parallelism_config_equal
 
     role_cnt = collections.defaultdict(int)
     for rpc_alloc in rpc_allocs:
@@ -130,8 +129,7 @@ def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
             if rpc.name == other.rpc.name:
                 continue
             if rpc.model_name.role == other.rpc.model_name.role and not (
-                parallelism_config_equal(parallel, other.parallel)
-                and device_mesh == other.device_mesh
+                parallel == other.parallel and device_mesh == other.device_mesh
             ):
                 other.rpc.model_name = ModelName(
                     rpc.model_name.role, role_cnt[rpc.model_name.role] + 1
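The removal of `parallelism_config_equal` above relies on standard dataclass
semantics: `@dataclasses.dataclass` generates a field-wise `__eq__` by default,
so `parallel == other.parallel` compares the three sizes directly, and
`unsafe_hash=True` adds a matching `__hash__` without a hand-written `__eq__`
(which the removed NOTE flags as breaking hydra and omegaconf). A minimal
self-contained sketch of this behavior; the field names come from the removed
helper, but the default values are assumed here for illustration:

import dataclasses

@dataclasses.dataclass(unsafe_hash=True)
class ParallelismConfig:
    # Field names mirror the removed parallelism_config_equal helper;
    # the defaults are illustrative assumptions, not the real config.
    model_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    data_parallel_size: int = 1

a = ParallelismConfig(model_parallel_size=2, pipeline_parallel_size=4)
b = ParallelismConfig(model_parallel_size=2, pipeline_parallel_size=4)
c = ParallelismConfig()

assert a == b              # field-wise __eq__, generated by @dataclass by default
assert a != c              # differing fields compare unequal, as in the ppo_exp.py check
assert hash(a) == hash(b)  # __hash__ generated because of unsafe_hash=True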