diff --git a/rl/llm/data_collator.py b/rl/llm/data_collator.py
index 0058d69..5b7b78b 100644
--- a/rl/llm/data_collator.py
+++ b/rl/llm/data_collator.py
@@ -15,8 +15,8 @@
 @dataclass
 class DataCollatorForCausalLM(object):
     tokenizer: transformers.PreTrainedTokenizer
-    input_max_len: int
-    output_max_len: int
+    # input_max_len: int
+    # output_max_len: int
     train_on_input: bool
     predict_with_generate: bool
@@ -42,8 +42,8 @@ def __call__(
         ]
         tokenized_inputs = self.tokenizer(
             sources,
-            max_length=self.input_max_len,
-            truncation=True,
+            # max_length=self.input_max_len,
+            # truncation=True,
             add_special_tokens=False,
         )
         if not self.predict_with_generate:
@@ -53,8 +53,8 @@
             ]
             tokenized_outputs = self.tokenizer(
                 targets,
-                max_length=self.output_max_len,
-                truncation=True,
+                # max_length=self.output_max_len,
+                # truncation=True,
                 add_special_tokens=False,
             )
         else:
diff --git a/rl/llm/train_llm.py b/rl/llm/train_llm.py
index 24096dc..2260cba 100644
--- a/rl/llm/train_llm.py
+++ b/rl/llm/train_llm.py
@@ -416,8 +416,8 @@ def get_trainer(
         tokenizer,
         train_on_input=False,
         predict_with_generate=False,
-        input_max_len=4096,
-        output_max_len=2048,
+        # input_max_len=4096,
+        # output_max_len=2048,
     ),
 }
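
For context, a minimal sketch (not part of the patch) of what this change does at tokenization time: without max_length/truncation, a Hugging Face tokenizer call returns the full token sequence, so any length cap must now be enforced upstream of the collator. The gpt2 checkpoint and the toy input string below are illustrative assumptions only.

    # Minimal sketch, not from this patch: shows what dropping
    # max_length/truncation changes at tokenization time.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint, an assumption

    source = "a very long prompt " * 2000

    # Old behavior: input hard-capped at input_max_len (4096) tokens.
    capped = tokenizer(source, max_length=4096, truncation=True,
                       add_special_tokens=False)

    # New behavior: the full, untruncated sequence comes back.
    uncapped = tokenizer(source, add_special_tokens=False)

    print(len(capped["input_ids"]))    # 4096
    print(len(uncapped["input_ids"]))  # full length (several thousand here)

One practical consequence: the old defaults (input_max_len=4096, output_max_len=2048, per train_llm.py) silently dropped tokens past the cap, while after this change an unusually long sample grows the padded batch and can raise memory use, so callers presumably need to filter or chunk long examples themselves.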