
Commit 0616fee

fix minibatch splitting in PPO and GRPO
garrett4wade committed Sep 3, 2024
1 parent 3b0cfb7 commit 0616fee
Showing 2 changed files with 15 additions and 3 deletions.

examples/new_algorithms/grpo/grpo_interface.py (2 changes: 1 addition & 1 deletion)

@@ -431,7 +431,7 @@ def train_step(
             ids=list(range(input_.bs * self.group_size)),
         )
 
-        # Split mini-batches and run PPO training.
+        # Split mini-batches and run PPO training. Mini-batches have balanced sizes
         datas = data_.split(self.n_minibatches, min_size=data_.bs // self.n_minibatches)
         train_stats = collections.defaultdict(float)
         for data in datas:
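
For context, the min_size=data_.bs // self.n_minibatches argument bounds how uneven the resulting mini-batches can be: with bs samples split n ways, every mini-batch holds at least floor(bs / n) of them. The helper below is a minimal, hypothetical sketch of a split that satisfies this bound; split_balanced is not part of the ReaLHF API.

from typing import List

def split_balanced(indices: List[int], n_minibatches: int, min_size: int) -> List[List[int]]:
    # Split `indices` into `n_minibatches` contiguous chunks, each holding
    # at least `min_size` items; the remainder is spread one-per-chunk.
    bs = len(indices)
    assert bs >= n_minibatches * min_size, "not enough data to satisfy min_size"
    base, rem = divmod(bs, n_minibatches)
    chunks, start = [], 0
    for i in range(n_minibatches):
        size = base + (1 if i < rem else 0)
        chunks.append(indices[start:start + size])
        start += size
    return chunks

# With bs=10 and n_minibatches=4, min_size = 10 // 4 = 2 and the chunk
# sizes come out as [3, 3, 2, 2] -- balanced to within one sequence.
print([len(c) for c in split_balanced(list(range(10)), 4, min_size=10 // 4)])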

realhf/impl/model/interface/ppo_interface.py (16 changes: 14 additions & 2 deletions)

@@ -389,9 +389,15 @@ def train_step(
         )
         # NOTE: We cannot randomly shuffle data here because
         # data must have the same shape across different pipeline stages.
+        if n_mbs is None:
+            n_mbs = 1
         datas = input_.split(
             self.n_minibatches,
-            min_size=constants.pipe_parallel_world_size() * 2,
+            min_size=(
+                constants.pipe_parallel_world_size() * 2 * n_mbs
+                if constants.pipe_parallel_world_size() > 1
+                else constants.pipe_parallel_world_size() * n_mbs
+            ),
         )
 
         ### Logging code starts. ###
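
This is the core of the fix: the old bound, constants.pipe_parallel_world_size() * 2, ignored n_mbs, which appears to be the number of micro-batches each mini-batch is later divided into (treat that reading as an assumption). Below is a hedged sketch of the new rule, with names mirroring the diff and illustrative values; the factor 2 presumably keeps at least two micro-batches per pipeline stage:

def ppo_min_size(pipe_parallel_world_size: int, n_mbs: int) -> int:
    # With pipeline parallelism (> 1 stage), each mini-batch must still
    # divide into n_mbs pieces that feed every stage, hence the extra factor.
    if pipe_parallel_world_size > 1:
        return pipe_parallel_world_size * 2 * n_mbs
    # Without pipeline parallelism the safety factor of 2 is unnecessary.
    return pipe_parallel_world_size * n_mbs

assert ppo_min_size(pipe_parallel_world_size=4, n_mbs=2) == 16  # old rule gave 8
assert ppo_min_size(pipe_parallel_world_size=1, n_mbs=2) == 2   # old rule gave 2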
@@ -799,9 +805,15 @@ def train_step(
         )
         # NOTE: We cannot randomly shuffle data here because
         # data must have the same shape across different pipeline stages.
+        if n_mbs is None:
+            n_mbs = 1
         datas = input_.split(
             self.n_minibatches,
-            min_size=constants.pipe_parallel_world_size() * 2,
+            min_size=(
+                constants.pipe_parallel_world_size() * 2 * n_mbs
+                if constants.pipe_parallel_world_size() > 1
+                else constants.pipe_parallel_world_size() * n_mbs
+            ),
         )
 
         # Logging.
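
The same change is applied to the second train_step in this file. For concreteness, an old-versus-new comparison under an assumed configuration (illustrative numbers only):

# Assumed configuration: 4 pipeline stages; each mini-batch is later
# split into n_mbs = 2 micro-batch groups.
pipe_world_size, n_mbs = 4, 2

old_min_size = pipe_world_size * 2           # 8: ignores n_mbs
new_min_size = pipe_world_size * 2 * n_mbs   # 16: scales with n_mbs

# Under the old bound, an 8-sequence mini-batch split into 2 groups leaves
# only 1 sequence per stage; the new bound preserves 2 per stage.
print(old_min_size, new_min_size)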
