diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 875c893f55..f9c0d34e98 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -209,15 +209,33 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> D ): human_token_ids_idxs = [0] + human_token_ids_idxs - for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): - # Make pytorch loss function ignore all non response tokens - if idx != 0: - batch["labels"][i, start:end] = self.ignore_index + pointer_human = 0 + pointer_response = 0 + mask_start = -1 + mask_end = -1 + + + while pointer_response <= len(response_token_ids_idxs) - 1 and pointer_human <= len(human_token_ids_idxs) - 1: + if mask_start == -1: + mask_start = 0 if response_token_ids_idxs[0] != 0 else human_token_ids_idxs[pointer_human] + if mask_end == -1: + mask_end = response_token_ids_idxs[0] + if response_token_ids_idxs[pointer_response] > human_token_ids_idxs[pointer_human]: + if mask_end < mask_start: + mask_end = response_token_ids_idxs[pointer_response] + pointer_human += 1 + elif response_token_ids_idxs[pointer_response] < human_token_ids_idxs[pointer_human]: + if mask_start < mask_end: + batch["labels"][i, mask_start:mask_end] = self.ignore_index + mask_start = human_token_ids_idxs[pointer_human] + pointer_response += 1 else: - batch["labels"][i, :end] = self.ignore_index - - if len(response_token_ids_idxs) < len(human_token_ids_idxs): - batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index + raise Exception("response_token_id and human_token_id are impossible to be the same. Please check your response and human template ids") + if pointer_human < len(human_token_ids_idxs) - 1: + while human_token_ids_idxs[pointer_human] < mask_end: + pointer_human += 1 + if pointer_human <= len(human_token_ids_idxs) - 1: + batch["labels"][i, human_token_ids_idxs[pointer_human]:] = self.ignore_index if self.padding_free: # remove padding, `attention_mask` and add `position_ids`