Fix stupid truncation bug in training
ProbablyFaiz committed Jun 15, 2024
1 parent 4a5a842 commit bfa8d2b
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions rl/llm/data_collator.py
@@ -15,8 +15,8 @@
 @dataclass
 class DataCollatorForCausalLM(object):
     tokenizer: transformers.PreTrainedTokenizer
-    input_max_len: int
-    output_max_len: int
+    # input_max_len: int
+    # output_max_len: int
     train_on_input: bool
     predict_with_generate: bool

@@ -42,8 +42,8 @@ def __call__(
         ]
         tokenized_inputs = self.tokenizer(
             sources,
-            max_length=self.input_max_len,
-            truncation=True,
+            # max_length=self.input_max_len,
+            # truncation=True,
             add_special_tokens=False,
         )
         if not self.predict_with_generate:
@@ -53,8 +53,8 @@ def __call__(
             ]
             tokenized_outputs = self.tokenizer(
                 targets,
-                max_length=self.output_max_len,
-                truncation=True,
+                # max_length=self.output_max_len,
+                # truncation=True,
                 add_special_tokens=False,
             )
         else:
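The hunks above remove per-side truncation from the collator. With max_length and truncation=True set, the tokenizer silently clipped any prompt longer than input_max_len and, worse, any target longer than output_max_len, so the loss was computed against labels that end mid-answer and the model learned to stop early. A minimal sketch of that failure mode, assuming only a generic Hugging Face tokenizer (gpt2 is used purely for illustration, it is not what this repo uses):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

target = "a very long answer " * 200  # stand-in for a training label

# What the collator did before this commit: silently clip the label.
clipped = tokenizer(target, max_length=16, truncation=True, add_special_tokens=False)
# What it does after: keep the full label.
full = tokenizer(target, add_special_tokens=False)

print(len(clipped["input_ids"]), len(full["input_ids"]))  # 16 vs. several hundred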
4 changes: 2 additions & 2 deletions rl/llm/train_llm.py
@@ -416,8 +416,8 @@ def get_trainer(
             tokenizer,
             train_on_input=False,
             predict_with_generate=False,
-            input_max_len=4096,
-            output_max_len=2048,
+            # input_max_len=4096,
+            # output_max_len=2048,
         ),
     }

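Since nothing in the collator bounds sequence length after this change, over-length examples have to be handled upstream instead. One hypothetical way to do that (not part of this commit; the "input"/"output" field names and the 4096-token budget are assumptions mirroring the old defaults) is to filter the dataset before training:

MAX_LEN = 4096  # assumed context budget, matching the old input_max_len

def fits_in_context(example, tokenizer, max_len=MAX_LEN):
    # Count prompt and target tokens the same way the collator tokenizes them.
    n_in = len(tokenizer(example["input"], add_special_tokens=False)["input_ids"])
    n_out = len(tokenizer(example["output"], add_special_tokens=False)["input_ids"])
    return n_in + n_out <= max_len

# With a datasets.Dataset, drop anything that would overflow the context:
# dataset = dataset.filter(lambda ex: fits_in_context(ex, tokenizer))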
