Add a check in PPOExperiment to avoid unintended behaviors. #16

Merged · 6 commits · Jul 3, 2024
3 changes: 2 additions & 1 deletion realhf/api/core/model_api.py
@@ -36,7 +36,8 @@ class GenerationHyperparameters:
     :type temperature: float
     :param num_samples: The number of samples to generate.
     :type num_samples: int
-    :param use_cuda_graph: Whether to use CUDA graph.
+    :param use_cuda_graph: Whether to use CUDA graph to reduce kernel launch overhead
+        during generation. Recommended for pure generation.
     :type use_cuda_graph: bool
     """

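For reference, a minimal sketch of how the documented flag might be toggled; it assumes GenerationHyperparameters can be instantiated directly as a plain dataclass with keyword defaults, which this diff does not show.

# Sketch only: assumes use_cuda_graph is a regular dataclass field with a default,
# as the docstring above suggests.
from realhf.api.core.model_api import GenerationHyperparameters

# Enabling CUDA graph trades capture-time setup for lower kernel launch overhead,
# which pays off in pure generation workloads.
gen_cfg = GenerationHyperparameters(use_cuda_graph=True)
print(gen_cfg.use_cuda_graph)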
13 changes: 1 addition & 12 deletions realhf/api/quickstart/model.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger("Quickstart Model Config")


-@dataclasses.dataclass
+@dataclasses.dataclass(unsafe_hash=True)
 class ParallelismConfig:
     """Model 3D parallelism configuration.

@@ -46,17 +46,6 @@ def __str__(self):
         )


-def parallelism_config_equal(
-    parallel1: ParallelismConfig, parallel2: ParallelismConfig
-) -> bool:
-    # NOTE: Implementing __eq__ in dataclass will cause error in hydra and omegaconf
-    return (
-        parallel1.model_parallel_size == parallel2.model_parallel_size
-        and parallel1.pipeline_parallel_size == parallel2.pipeline_parallel_size
-        and parallel1.data_parallel_size == parallel2.data_parallel_size
-    )
-
-
 @dataclasses.dataclass
 class LoRAConfig:
     dim: int = 32
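The effect of switching to dataclasses.dataclass(unsafe_hash=True) is that Python now generates both __eq__ and __hash__ for ParallelismConfig, so two configurations can be compared field by field with == instead of the removed helper. A small sketch, assuming the three parallelism fields referenced by the deleted function are enough to construct the dataclass:

# Sketch: illustrates the equality that replaces parallelism_config_equal.
# Assumption: ParallelismConfig can be constructed from just these three fields.
from realhf.api.quickstart.model import ParallelismConfig

a = ParallelismConfig(model_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=4)
b = ParallelismConfig(model_parallel_size=2, pipeline_parallel_size=1, data_parallel_size=4)

assert a == b              # dataclass-generated __eq__ compares fields
assert hash(a) == hash(b)  # unsafe_hash=True also generates __hash__,
                           # so configs can serve as dict keys or set members

This is what allows the simplified comparison in realhf/experiments/common/utils.py further down.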
10 changes: 10 additions & 0 deletions realhf/experiments/common/ppo_exp.py
@@ -261,6 +261,16 @@ def __post_init__(self):
                 value_norm_eps=self.ppo.value_norm_eps,
             )

+        if self.ppo.gen.use_cuda_graph and (
+            self.actor_train.parallel != self.actor_gen.parallel
+        ):
+            raise ValueError(
+                "CUDA graph cannot be used with parameter reallocation "
+                "because CUDA graph requires pinned parameter memory. "
+                "Either set use_cuda_graph=False or set identical parallel "
+                "strategies for actor_train and actor_gen."
+            )
+
     @property
     def models(self) -> Dict[str, ModelTrainEvalConfig]:
         # role to config
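The new guard rejects the one combination that would otherwise misbehave: CUDA graph generation together with different parallel strategies for actor training and generation, which triggers parameter reallocation. Below is a standalone sketch of the same condition, using a hypothetical helper name rather than code from the repository:

# Hypothetical helper mirroring the check added to __post_init__ above.
def check_cuda_graph_compat(use_cuda_graph, train_parallel, gen_parallel):
    # CUDA graphs capture fixed device memory addresses, so actor parameters must
    # stay pinned in place; reallocating them between different parallel layouts
    # would invalidate the captured graph.
    if use_cuda_graph and train_parallel != gen_parallel:
        raise ValueError(
            "Either set use_cuda_graph=False or use identical parallel "
            "strategies for actor_train and actor_gen."
        )

# Example: identical strategies pass, mismatched ones raise ValueError.
check_cuda_graph_compat(True, (2, 1, 4), (2, 1, 4))    # OK
# check_cuda_graph_compat(True, (2, 1, 4), (1, 1, 8))  # would raise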
4 changes: 1 addition & 3 deletions realhf/experiments/common/utils.py
@@ -117,7 +117,6 @@ def make_model_config(cfg: ModelTrainEvalConfig):


 def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
-    from realhf.api.quickstart.model import parallelism_config_equal

     role_cnt = collections.defaultdict(int)
     for rpc_alloc in rpc_allocs:
@@ -130,8 +129,7 @@ def resolve_rpc_hooks(rpc_allocs: List[RPCAllocation]):
             if rpc.name == other.rpc.name:
                 continue
             if rpc.model_name.role == other.rpc.model_name.role and not (
-                parallelism_config_equal(parallel, other.parallel)
-                and device_mesh == other.device_mesh
+                parallel == other.parallel and device_mesh == other.device_mesh
             ):
                 other.rpc.model_name = ModelName(
                     rpc.model_name.role, role_cnt[rpc.model_name.role] + 1