[Bug Fixes] Fix model worker stuck under some special circumstances. #67

Merged (2 commits, Sep 5, 2024)
realhf/experiments/common/dpo_exp.py (4 changes: 0 additions & 4 deletions)

@@ -55,10 +55,6 @@ class DPOConfig(CommonExperimentConfig):
     :type dataset: PairedComparisonDatasetConfig
     :param beta: KL regularization coefficient.
     :type beta: float
-    :param actor_train_n_mbs: Number of microbatches for training the primary LLM.
-    :type actor_train_n_mbs: int
-    :param ref_inf_n_mbs: Number of microbatches for inference on the reference LLM.
-    :type ref_inf_n_mbs: int
     """

     is_sft_lora: bool = False
realhf/experiments/common/ppo_exp.py (12 changes: 0 additions & 12 deletions)

@@ -171,18 +171,6 @@ class PPOConfig(CommonExperimentConfig):
     :type dataset: PromptOnlyDatasetConfig
     :param ppo: Configuration for the PPO algorithm.
     :type ppo: PPOHyperparameters
-    :param actor_train_n_mbs: Number of minibatches for TrainActor.
-    :type actor_train_n_mbs: int
-    :param critic_train_n_mbs: Number of minibatches for TrainCritic.
-    :type critic_train_n_mbs: int
-    :param actor_gen_n_mbs: Number of minibatches for Rollout.
-    :type actor_gen_n_mbs: int
-    :param critic_inf_n_mbs: Number of minibatches for InfValues.
-    :type critic_inf_n_mbs: int
-    :param rew_inf_n_mbs: Number of minibatches for InfReward.
-    :type rew_inf_n_mbs: int
-    :param ref_inf_n_mbs: Number of minibatches for InfRef.
-    :type ref_inf_n_mbs: int
     """

     is_sft_lora: bool = False
realhf/experiments/common/rw_exp.py (2 changes: 0 additions & 2 deletions)

@@ -37,8 +37,6 @@ class RWConfig(CommonExperimentConfig):
     :type allocation: MFCConfig
     :param dataset: Dataset configuration.
     :type dataset: PairedComparisonDatasetConfig
-    :param n_mbs: Number of microbatches.
-    :type n_mbs: int
     """

     is_sft_lora: bool = False
realhf/experiments/common/sft_exp.py (2 changes: 0 additions & 2 deletions)

@@ -28,8 +28,6 @@ class SFTConfig(CommonExperimentConfig):
     :type allocation: MFCConfig
     :param dataset: Dataset configuration
     :type dataset: PromptAnswerDatasetConfig
-    :param n_mbs: Number of microbatches.
-    :type n_mbs: int
     """

     model: ModelTrainEvalConfig = dataclasses.field(
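The four hunks above are documentation-only cleanups: each deletes :param/:type entries for *_n_mbs fields with zero additions, which suggests those per-experiment micro-batch knobs had already been removed from the config dataclasses and the Sphinx docstrings had gone stale. The runtime change that replaces this mechanism follows in inference.py below.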
realhf/impl/model/backend/inference.py (8 changes: 6 additions & 2 deletions)

@@ -93,7 +93,9 @@ def forward(
         if num_micro_batches is None:
             num_micro_batches = 1
         outputs = []
-        for mb_input in input_.split(num_micro_batches):
+        for mb_input in input_.split(
+            num_micro_batches, min_size=constants.pipe_parallel_world_size()
+        ):
             if constants.pipe_parallel_world_size() > 1:
                 model_output = self.pipe_runner.forward(
                     input_=mb_input,

@@ -138,7 +140,9 @@ def generate(
         if num_micro_batches is None:
             num_micro_batches = 1
         sequences, scores, logits_mask = [], [], []
-        for mb_input in input_.split(num_micro_batches):
+        for mb_input in input_.split(
+            num_micro_batches, min_size=constants.pipe_parallel_world_size()
+        ):
             if constants.pipe_parallel_world_size() > 1:
                 res = self.pipe_runner.generate(
                     input_=mb_input,
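Both forward and generate now pass min_size=constants.pipe_parallel_world_size() when splitting a batch into micro-batches. This is plausibly part of the "stuck" fix: with pipeline parallelism, each micro-batch is partitioned again across the pipeline stages, and a micro-batch holding fewer sequences than there are stages can leave a stage with no work, stalling the pipe runner. Below is a runnable sketch of one plausible min_size semantics; split_microbatches is illustrative and realhf's actual SequenceSample.split may differ:

    from typing import List, Sequence, TypeVar

    T = TypeVar("T")

    def split_microbatches(
        seqs: Sequence[T], n_mbs: int, min_size: int = 1
    ) -> List[List[T]]:
        # Split into at most `n_mbs` chunks while guaranteeing that every
        # chunk holds at least `min_size` items, shrinking the chunk count
        # for small batches instead of emitting undersized chunks.
        n_mbs = max(1, min(n_mbs, len(seqs) // min_size))
        base, rem = divmod(len(seqs), n_mbs)
        chunks, start = [], 0
        for i in range(n_mbs):
            size = base + (1 if i < rem else 0)
            chunks.append(list(seqs[start : start + size]))
            start += size
        return chunks

    # With 4 pipeline stages but only 6 sequences, asking for 4 micro-batches
    # would yield chunks of 1-2 sequences, too few for each to be re-split
    # across 4 stages; min_size=4 caps the split at one micro-batch of 6.
    print(split_microbatches(list(range(6)), n_mbs=4, min_size=4))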
realhf/system/master_worker.py (32 changes: 18 additions & 14 deletions)

@@ -367,7 +367,7 @@ def _attach_payloads_with_hooks(
     return payloads, mwids


-async def _request_model_function_call(
+def _request_model_function_call(
     rpc: dfg.MFCDef,
     stream: request_reply_stream.NameResolvingRequestClient,
     msid2mwid: Dict[config_pkg.ModelShardID, int],
@@ -438,14 +438,10 @@ async def _request_model_function_call(
     )
     req_ids = [stream.post(p) for h, p in payloads.items() if h in handlers]
     other_req_ids = [stream.post(p) for h, p in payloads.items() if h not in handlers]
-    await asyncio.gather(
-        *[
-            _awaitable_response(
-                stream, pattern=create_exact_match_pattern([p.syn_reply_id])
-            )
-            for p in payloads.values()
-        ]
-    )
+    [
+        stream.poll(block=True, pattern=create_exact_match_pattern([p.syn_reply_id]))
+        for p in payloads.values()
+    ]
     [
         stream.post(
             request_reply_stream.Payload(
@@ -585,7 +581,7 @@ async def model_rpc_request_func(
             producer_mappings[names[0], k] = producer_mapping

     # send partitioned data to model workers
-    req_ids, other_req_ids = await _request_model_function_call(
+    req_ids, other_req_ids = _request_model_function_call(
         rpc=rpc,
         stream=stream,
         msid2mwid=msid2mwid,
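_request_model_function_call is converted from a coroutine into a plain function: the asyncio.gather over awaitable SYN replies becomes blocking stream.poll calls, and the call site in model_rpc_request_func drops its await accordingly. The handshake still posts every request before waiting on any reply, so workers are never starved; it just no longer depends on the event loop interleaving several in-flight MFC coroutines. A runnable sketch of the pattern follows; FakeStream and synchronous_handshake are stand-ins for illustration, not realhf's actual request_reply_stream API:

    import queue
    import threading

    class FakeStream:
        """Stand-in for a request-reply stream client (illustrative only)."""

        def __init__(self):
            self._replies = queue.Queue()
            self._stashed = set()

        def post(self, payload):
            # Simulate a model worker acknowledging the request a moment later.
            threading.Timer(
                0.01, self._replies.put, args=(payload["syn_reply_id"],)
            ).start()
            return payload["id"]

        def poll(self, block=True, pattern=None):
            # Block until the reply matching `pattern` arrives, stashing
            # replies that belong to other outstanding requests.
            while pattern not in self._stashed:
                self._stashed.add(self._replies.get(block=block))
            self._stashed.discard(pattern)
            return pattern

    def synchronous_handshake(stream, payloads):
        # Post every request first, then block on each SYN reply in turn.
        # Unlike the awaited gather, this runs entirely on the caller's thread.
        req_ids = [stream.post(p) for p in payloads]
        for p in payloads:
            stream.poll(block=True, pattern=p["syn_reply_id"])
        return req_ids

    if __name__ == "__main__":
        s = FakeStream()
        reqs = [{"id": i, "syn_reply_id": f"syn-{i}"} for i in range(4)]
        print(synchronous_handshake(s, reqs))  # [0, 1, 2, 3] once all SYNs land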
@@ -984,9 +980,17 @@ def __lazy_init(self):

         # Build some data required for subsequent model function calls.
         self.__all_model_handlers: List[config_pkg.ModelShardID] = []
+        self.__all_mw_handlers: List[config_pkg.ModelShardID] = []
+        _covered_mws = set()
         self.__dp0_model_handlers: List[config_pkg.ModelShardID] = []
         self.__trainable_model_handlers: List[config_pkg.ModelShardID] = []
         for model_name, topo in self.config.model_topos.items():
+            for j in range(topo.world_size()):
+                h = config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
+                _mw_id = self.config.msid2mwid[h]
+                if _mw_id not in _covered_mws:
+                    _covered_mws.add(_mw_id)
+                    self.__all_mw_handlers.append(h)
             num_dp = topo.get_dim("data")
             self.__all_model_handlers += [
                 config_pkg.ModelShardID.from_parallelism_rank(model_name, topo, j)
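Several model shards can live on the same model-worker process (for example when two models share a device mesh), so a handler list keyed by shard would address the same worker more than once. __all_mw_handlers instead keeps the first shard handler encountered per worker, giving exactly one address per process. A small self-contained illustration of the deduplication; the mapping values are made up:

    from typing import Dict, List, Tuple

    ShardID = Tuple[str, int]  # (model_name, parallelism rank)

    # Two models whose shards are co-located on the same two workers.
    msid2mwid: Dict[ShardID, int] = {
        ("actor", 0): 0, ("actor", 1): 1,
        ("critic", 0): 0, ("critic", 1): 1,
    }

    all_mw_handlers: List[ShardID] = []
    covered_mws = set()
    for shard, mw_id in msid2mwid.items():
        if mw_id not in covered_mws:  # first shard seen on a worker wins
            covered_mws.add(mw_id)
            all_mw_handlers.append(shard)

    # One handler per worker process: [('actor', 0), ('actor', 1)]
    print(all_mw_handlers)

The two broadcast call sites below then use this deduplicated list.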
@@ -1487,9 +1491,9 @@ def _log_training_stats(self, e2e_time: float, time_since_configure: float):
     def _clear_gpu_cache(self):
         request_all(
             self.__stream,
-            [vs[0] for vs in self.__mwid2msids.values()],
+            self.__all_mw_handlers,
             "clear_data_cache",
-            [self.__rpc_ctrl.ids_to_clear for _ in self.__all_model_handlers],
+            [self.__rpc_ctrl.ids_to_clear for _ in self.__all_mw_handlers],
         )
         self.__rpc_ctrl.ids_to_clear.clear()
@@ -1515,9 +1519,9 @@ def experiment_complete_exit(self):
         # Model workers will not respond to this message.
         request_all(
             self.__stream,
-            handlers=self.__all_model_handlers,
+            handlers=self.__all_mw_handlers,
             handle_type="reset",
-            datas=[None for _ in self.__all_model_handlers],
+            datas=[None for _ in self.__all_mw_handlers],
         )
         self.__stream.close()
         constants.reset_run()
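Both broadcasts now target __all_mw_handlers, so clear_data_cache and reset each reach every model worker exactly once. Previously _clear_gpu_cache derived one handler per worker ad hoc ([vs[0] for vs in self.__mwid2msids.values()]) while experiment_complete_exit iterated __all_model_handlers, which addresses a worker once per shard it hosts; unifying on one list also keeps the datas argument the same length as the handler list, as request_all evidently expects.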
realhf/system/model_worker.py (24 changes: 10 additions & 14 deletions)

@@ -365,7 +365,7 @@ def __lazy_setup(self):
             eval_dataloader = None
             self.__eval_dataloaders[s.id.model_name] = eval_dataloader

-        self.__request_cache = []
+        self.__request_cache = {}
         self.__ack_cache = {}

         self.__request_queue = queue.Queue(maxsize=8)
@@ -607,6 +607,7 @@ def model_poll_step(
                 del self.__data_sent_worker_indices[_id]
             if _id in self.__data_received_worker_indices:
                 del self.__data_received_worker_indices[_id]
+            gc.collect()
             if (
                 self.config.cuda_cache_cleanliness
                 and self.__clear_cache_frequency.check()
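The new gc.collect() runs right before the CUDA-cache cleanliness check. The likely rationale, assuming the branch below ends in torch.cuda.empty_cache() as the cuda_cache_cleanliness flag suggests, is that the caching allocator can only release blocks no live tensor references; collecting first breaks reference cycles that may still pin otherwise-dead tensors. The generic pairing looks like this (free_cuda_cache is an illustrative name, not realhf code):

    import gc

    import torch

    def free_cuda_cache():
        # Break reference cycles so lingering tensor objects are destroyed...
        gc.collect()
        # ...then return the now-unreferenced cached blocks to the CUDA driver.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()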
@@ -852,24 +853,19 @@ def __maybe_receive_one_request(self):
                     handle_name="syn",
                 ),
             )
-            self.__request_cache.append(r)
+            self.__request_cache[r.ack_reply_id] = r
         except request_reply_stream.NoMessage:
             return

     @cuda_tmark("receive_request", CUDATimeMarkType.misc)
     def maybe_receive_requests(self):
         for _ in range(16):
-            self.__maybe_receive_one_request()
-
-        while len(self.__request_cache) > 0:
-            request: request_reply_stream.Payload = self.__request_cache[0]
-            while request.ack_reply_id not in self.__ack_cache:
-                self.__maybe_receive_one_request()
-
-            self.__ack_cache.pop(request.ack_reply_id)
-            self.__request_cache.pop(0)
-
-            self.__request_queue.put_nowait((request, request.data, False, None))
+            self.__maybe_receive_one_request()
+            cur_ack_ids = list(self.__ack_cache.keys())
+            for ack_id in cur_ack_ids:
+                if ack_id in self.__request_cache:
+                    self.__ack_cache.pop(ack_id)
+                    req = self.__request_cache.pop(ack_id)
+                    self.__request_queue.put_nowait((req, req.data, False, None))

     def _poll(self):
         if not self.__dist_env_resolved:
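This hunk is the fix the PR title refers to. The old maybe_receive_requests treated __request_cache as a FIFO and spun in a nested loop until the ack for the oldest cached request arrived; if that ack was delayed indefinitely or acks arrived out of order, the worker never left the loop to do anything else, i.e., it got stuck. The new version keys the cache by ack_reply_id and, after every receive attempt, non-blockingly matches whatever acks have arrived against the cached requests. A runnable sketch of the matching logic, with the message transport simplified away:

    import queue

    request_cache = {}             # ack_reply_id -> request payload
    ack_cache = {}                 # ack_reply_id -> ack payload
    request_queue = queue.Queue()  # requests ready to execute

    def on_request(req):
        # Cache the request until its ack arrives (the worker has already
        # replied to the master with a SYN at this point).
        request_cache[req["ack_reply_id"]] = req

    def on_ack(ack_id):
        ack_cache[ack_id] = True

    def drain_ready_requests():
        # Match arrived acks against cached requests without blocking:
        # an out-of-order ack no longer stalls the head of a FIFO.
        for ack_id in list(ack_cache.keys()):
            if ack_id in request_cache:
                ack_cache.pop(ack_id)
                req = request_cache.pop(ack_id)
                request_queue.put_nowait(req)

    # Acks may arrive in any order relative to their requests:
    on_request({"ack_reply_id": "a", "data": 1})
    on_request({"ack_reply_id": "b", "data": 2})
    on_ack("b")                                 # "b" is acked first...
    drain_ready_requests()
    print(request_queue.get_nowait()["data"])   # ...and runs first: prints 2

Because nothing blocks, a request whose ack never arrives simply stays cached while the worker keeps polling, instead of wedging the whole poll loop.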