
Add a check in PPOExperiment to avoid unintended behaviors. (#16)
* refactor interval ops

* remove pybind in cpp code

* minor fix

* minor fix

* fix ppo cuda graph
garrett4wade authored Jul 3, 2024
1 parent 50d508f commit 6caa522
Showing 4 changed files with 14 additions and 16 deletions.
realhf/api/core/model_api.py: 2 additions & 1 deletion
@@ -36,7 +36,8 @@ class GenerationHyperparameters:
     :type temperature: float
     :param num_samples: The number of samples to generate.
     :type num_samples: int
-    :param use_cuda_graph: Whether to use CUDA graph.
+    :param use_cuda_graph: Whether to use CUDA graph to reduce kernel launch overhead
+        during generation. Recommended for pure generation.
     :type use_cuda_graph: bool
     """

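The flag documented above is a plain dataclass field, so enabling graph capture is a matter of setting it at construction time. A minimal usage sketch, assuming `GenerationHyperparameters` accepts the fields named in its docstring as keyword arguments and leaving everything else at its defaults:

    from realhf.api.core.model_api import GenerationHyperparameters

    # Hypothetical usage sketch; field names are taken from the docstring above.
    gen_args = GenerationHyperparameters(
        temperature=1.0,
        num_samples=1,
        use_cuda_graph=True,  # capture decoding kernels once, replay them each step
    )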
realhf/api/quickstart/model.py: 1 addition & 12 deletions
@@ -7,7 +7,7 @@
 logger = logging.getLogger("Quickstart Model Config")


-@dataclasses.dataclass
+@dataclasses.dataclass(unsafe_hash=True)
 class ParallelismConfig:
     """Model 3D parallelism configuration.
@@ -46,17 +46,6 @@ def __str__(self):
         )


-def parallelism_config_equal(
-    parallel1: ParallelismConfig, parallel2: ParallelismConfig
-) -> bool:
-    # NOTE: Implementing __eq__ in dataclass will cause error in hydra and omegaconf
-    return (
-        parallel1.model_parallel_size == parallel2.model_parallel_size
-        and parallel1.pipeline_parallel_size == parallel2.pipeline_parallel_size
-        and parallel1.data_parallel_size == parallel2.data_parallel_size
-    )
-
-
 @dataclasses.dataclass
 class LoRAConfig:
     dim: int = 32
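For context on why the standalone helper could be dropped: `@dataclasses.dataclass` already generates a field-wise `__eq__` (its `eq` parameter defaults to `True`), and `unsafe_hash=True` additionally generates a matching `__hash__`, which `eq=True` alone would set to `None`. A minimal sketch with a hypothetical stand-in class, independent of the ReaL codebase:

    import dataclasses

    # Hypothetical stand-in for ParallelismConfig; not the real class.
    @dataclasses.dataclass(unsafe_hash=True)
    class Parallel3D:
        model_parallel_size: int = 1
        pipeline_parallel_size: int = 1
        data_parallel_size: int = 1

    a = Parallel3D(model_parallel_size=2, pipeline_parallel_size=4)
    b = Parallel3D(model_parallel_size=2, pipeline_parallel_size=4)

    # eq=True (the default) generates a field-wise __eq__, which makes the
    # removed parallelism_config_equal helper redundant:
    assert a == b
    # unsafe_hash=True also generates __hash__; with eq=True alone,
    # __hash__ would be None and instances would be unhashable:
    assert hash(a) == hash(b)
    assert len({a, b}) == 1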
realhf/experiments/common/ppo_exp.py: 10 additions
@@ -261,6 +261,16 @@ def __post_init__(self):
             value_norm_eps=self.ppo.value_norm_eps,
         )

+        if self.ppo.gen.use_cuda_graph and (
+            self.actor_train.parallel != self.actor_gen.parallel
+        ):
+            raise ValueError(
+                "CUDA graph cannot be used with parameter reallocation "
+                "because CUDA graph requires pinned parameter memory. "
+                "Either set use_cuda_graph=False or set identical parallel "
+                "strategies for actor_train and actor_gen."
+            )
+
     @property
     def models(self) -> Dict[str, ModelTrainEvalConfig]:
         # role to config
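The guard runs in `__post_init__`, so a bad configuration is rejected before any worker launches. A minimal sketch of the same logic using hypothetical, stripped-down configs (only the fields the check touches are modeled; the real experiment config carries many more):

    import dataclasses

    # Hypothetical stand-in; mirrors only the parallelism fields.
    @dataclasses.dataclass(unsafe_hash=True)
    class Parallel:
        model_parallel_size: int = 1
        pipeline_parallel_size: int = 1
        data_parallel_size: int = 1

    def check_cuda_graph_compat(use_cuda_graph: bool, train: Parallel, gen: Parallel) -> None:
        # Mirrors the guard above: CUDA graph capture pins parameter memory,
        # which rules out reallocating parameters between distinct train and
        # generate parallel strategies.
        if use_cuda_graph and train != gen:
            raise ValueError(
                "CUDA graph cannot be used with parameter reallocation. "
                "Either set use_cuda_graph=False or use identical parallel "
                "strategies for actor_train and actor_gen."
            )

    check_cuda_graph_compat(True, Parallel(2, 1, 4), Parallel(2, 1, 4))   # OK: identical
    check_cuda_graph_compat(False, Parallel(2, 1, 4), Parallel(1, 1, 8))  # OK: no CUDA graph
    try:
        check_cuda_graph_compat(True, Parallel(2, 1, 4), Parallel(1, 1, 8))
    except ValueError as e:
        print(e)  # rejected: CUDA graph + differing strategies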
realhf/experiments/common/utils.py: 1 addition & 3 deletions
@@ -117,7 +117,6 @@ def make_model_config(cfg: ModelTrainEvalConfig):


 def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
-    from realhf.api.quickstart.model import parallelism_config_equal

     role_cnt = collections.defaultdict(int)
     for rpc_alloc in rpc_allocs:
@@ -130,8 +129,7 @@ def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
             if rpc.name == other.rpc.name:
                 continue
             if rpc.model_name.role == other.rpc.model_name.role and not (
-                parallelism_config_equal(parallel, other.parallel)
-                and device_mesh == other.device_mesh
+                parallel == other.parallel and device_mesh == other.device_mesh
             ):
                 other.rpc.model_name = ModelName(
                     rpc.model_name.role, role_cnt[rpc.model_name.role] + 1
