[Bugs] Fix attn mask (#852)
* [WIP]: Fix sequence parallel memory bottleneck in DPO

* loss mask before split

* refactor orpo

* fix attention_mask in preference_collate_fn

---------

Co-authored-by: RangiLyu <[email protected]>
HIT-cwh and RangiLyu authored Jul 19, 2024
1 parent ba7afc7 commit 30133d5
Showing 1 changed file with 5 additions and 3 deletions.
xtuner/dataset/collate_fns/preference_collate_fn.py (5 additions & 3 deletions)
@@ -58,7 +58,7 @@ def preference_collate_fn(instances: Sequence[Dict],
     labels = torch.stack(labels)
 
     if use_varlen_attn:
-        attention_mask = None
+        attention_mask = torch.ones_like(input_ids).bool()
         position_ids = torch.stack(position_ids, dim=0)
     else:
         # Some tokenizers have the same eos token and pad token, so input_ids
@@ -74,8 +74,10 @@ def preference_collate_fn(instances: Sequence[Dict],
         input_ids = pad_for_sequence_parallel(input_ids, pad_index)
         labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
         position_ids = pad_for_sequence_parallel(position_ids, 0)
-        if attention_mask is not None:
-            attention_mask = pad_for_sequence_parallel(attention_mask, 0)
+        # We use attention_mask to distinguish `input_ids` from
+        # (sequence parallel) pad tokens in `get_var_len_atten_logps` method of
+        # class `DPO` and `ORPO`
+        attention_mask = pad_for_sequence_parallel(attention_mask, 0)
         if use_varlen_attn:
             (cumulative_len, attention_mask
              ) = pad_cumulative_len_for_sequence_parallel(cumulative_len)
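
In short, under use_varlen_attn the collate function now builds an all-True boolean attention_mask instead of leaving it as None, so that once the tensors are padded for sequence parallelism the trailing pad positions carry False and can be filtered out when the DPO/ORPO loss gathers per-token log-probs. The sketch below illustrates that idea in plain PyTorch: the pad_to_multiple helper is a hypothetical stand-in for xtuner's pad_for_sequence_parallel, and the masked log-prob sum only approximates what get_var_len_atten_logps needs the mask for, so treat both as illustrative assumptions rather than the library's actual code.

import torch


def pad_to_multiple(x: torch.Tensor, multiple: int, pad_value) -> torch.Tensor:
    """Right-pad the last dimension so its length is divisible by `multiple`.

    Hypothetical stand-in for `pad_for_sequence_parallel` (assumption only).
    """
    pad_len = (-x.size(-1)) % multiple
    if pad_len == 0:
        return x
    pad = x.new_full((*x.shape[:-1], pad_len), pad_value)
    return torch.cat([x, pad], dim=-1)


sp_world_size = 4                           # sequence-parallel group size
input_ids = torch.randint(0, 100, (1, 10))  # toy batch, seq_len = 10
per_token_logps = torch.randn(1, 10)        # pretend per-token log-probs

# Old behaviour: attention_mask stayed None under use_varlen_attn, so nothing
# marked the sequence-parallel pad tokens added below.
# New behaviour: start from an all-True mask, then pad it with 0 (False).
attention_mask = torch.ones_like(input_ids).bool()

input_ids = pad_to_multiple(input_ids, sp_world_size, pad_value=0)
attention_mask = pad_to_multiple(attention_mask, sp_world_size, pad_value=0)
per_token_logps = pad_to_multiple(per_token_logps, sp_world_size, pad_value=0.0)

# Downstream, the padded positions are False in the mask and drop out of the
# per-sequence log-prob sum instead of corrupting it.
valid_logp_sum = (per_token_logps * attention_mask).sum(-1)
print(attention_mask)   # trailing padded positions are False
print(valid_logp_sum)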