diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 45411219..c7a5aa16 100755 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -551,7 +551,7 @@ async def model_rpc_request_func( assert sample.bs % dp_size == 0 min_n_seqs_per_dp = sample.bs // dp_size else: - min_n_seqs_per_dp = 1 + min_n_seqs_per_dp = rpc.n_mbs split_spec = sample.get_split_spec(dp_size, min_size=min_n_seqs_per_dp) partitions = split_spec.partitions target_mapping = {i: list(range(v[0], v[1])) for i, v in enumerate(partitions)}