diff --git a/xtuner/dataset/collate_fns/preference_collate_fn.py b/xtuner/dataset/collate_fns/preference_collate_fn.py
index 8a6060410..ca21613bb 100644
--- a/xtuner/dataset/collate_fns/preference_collate_fn.py
+++ b/xtuner/dataset/collate_fns/preference_collate_fn.py
@@ -58,7 +58,7 @@ def preference_collate_fn(instances: Sequence[Dict],
         labels = torch.stack(labels)
 
     if use_varlen_attn:
-        attention_mask = None
+        attention_mask = torch.ones_like(input_ids).bool()
         position_ids = torch.stack(position_ids, dim=0)
     else:
         # Some tokenizers have the same eos token and pad token, so input_ids
@@ -74,8 +74,10 @@ def preference_collate_fn(instances: Sequence[Dict],
         input_ids = pad_for_sequence_parallel(input_ids, pad_index)
         labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
         position_ids = pad_for_sequence_parallel(position_ids, 0)
-        if attention_mask is not None:
-            attention_mask = pad_for_sequence_parallel(attention_mask, 0)
+        # We use attention_mask to distinguish `input_ids` from the
+        # (sequence parallel) pad tokens in the `get_var_len_atten_logps`
+        # method of the `DPO` and `ORPO` classes.
+        attention_mask = pad_for_sequence_parallel(attention_mask, 0)
         if use_varlen_attn:
             (cumulative_len, attention_mask
              ) = pad_cumulative_len_for_sequence_parallel(cumulative_len)
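
Note on the change: with varlen attention the collate fn now returns an all-ones boolean `attention_mask` instead of `None`, so the zeros appended later by `pad_for_sequence_parallel` mark exactly the sequence-parallel pad positions that `get_var_len_atten_logps` in `DPO`/`ORPO` must skip. A minimal sketch of the idea, using a hypothetical `pad_to_multiple` as a stand-in for `pad_for_sequence_parallel` (not the xtuner implementation):

import torch

def pad_to_multiple(x: torch.Tensor, multiple: int, pad_value) -> torch.Tensor:
    # Hypothetical stand-in: right-pad the last dim so its length is
    # divisible by `multiple`, which is what sequence-parallel padding needs.
    pad_len = (-x.shape[-1]) % multiple
    if pad_len == 0:
        return x
    pad = x.new_full((*x.shape[:-1], pad_len), pad_value)
    return torch.cat([x, pad], dim=-1)

input_ids = torch.tensor([[11, 12, 13, 14, 15]])
attention_mask = torch.ones_like(input_ids).bool()  # every position is a real token

# Pretend the sequence-parallel world size is 4, so the length is padded to 8.
input_ids = pad_to_multiple(input_ids, 4, 0)
attention_mask = pad_to_multiple(attention_mask, 4, False)

# Downstream, the mask recovers which positions are genuine tokens and which
# are sequence-parallel padding.
print(attention_mask)             # tensor([[ True,  True,  True,  True,  True, False, False, False]])
print(input_ids[attention_mask])  # tensor([11, 12, 13, 14, 15])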