deps: bump versions
asawczyn committed Apr 23, 2024
1 parent 15fb3e4 commit 580dd71
Showing 3 changed files with 11 additions and 9 deletions.
8 changes: 5 additions & 3 deletions juddges/data/datasets/context_truncator.py
@@ -13,9 +13,11 @@ def __init__(self, tokenizer: BaseTokenizer, max_length: int):
             {"role": "assistant", "content": ""},
         ]

-        self.empty_messages_length = self.tokenizer.apply_chat_template(
-            empty_messages, tokenize=True, return_dict=True, return_length=True
-        )["length"][0]
+        self.empty_messages_length = len(
+            self.tokenizer.apply_chat_template(
+                empty_messages, tokenize=True, return_dict=True
+            ).data["input_ids"]
+        )

     def __call__(self, prompt: str, context: str, output: str) -> str:
         prompt_length, output_length = self.tokenizer(
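The change above drops the `return_length=True` / `["length"][0]` pattern, which the commit suggests no longer behaves as expected after the transformers bump, and instead measures the chat-template overhead directly from the tokenized ids. A minimal sketch of that pattern follows; the model id is a placeholder and not necessarily one the repository uses.

from transformers import AutoTokenizer

# Placeholder chat-capable tokenizer; any model shipping a chat template works the same way.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

empty_messages = [
    {"role": "user", "content": ""},
    {"role": "assistant", "content": ""},
]

# Tokenize the empty conversation scaffold and read its length straight from the
# token ids rather than from a returned "length" field.
encoded = tokenizer.apply_chat_template(empty_messages, tokenize=True, return_dict=True)
empty_messages_length = len(encoded["input_ids"])  # equivalent to the .data["input_ids"] access in the diff
print(empty_messages_length)  # tokens consumed by the template alone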
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,4 +1,4 @@
-accelerate==0.28.0
+accelerate==0.29.3
 bitsandbytes==0.43.0
 chardet==5.2.0
 datasets==2.18.0
@@ -23,8 +23,8 @@ tenacity==8.2.3
 tensorboard==2.16.2
 tiktoken==0.6.0
 torch==2.2.1
-transformers==4.38.2
-trl==0.8.1
+transformers==4.40.0
+trl==0.8.6
 typer==0.9.0
 wandb==0.16.5
 xmltodict==0.13.0
6 changes: 3 additions & 3 deletions tests/test_context_truncator.py
@@ -49,9 +49,9 @@ def _check(self, model_id: str, max_length: int):
         ]

         original_tokenized = tokenizer.apply_chat_template(
-            messages, tokenize=True, return_dict=True, return_length=True
+            messages, tokenize=True, return_dict=True
         )
-        original_length = original_tokenized["length"][0]
+        original_length = len(original_tokenized.data["input_ids"])

         truncated_context = ContextTruncator(tokenizer, max_length)(prompt, context, output)

@@ -65,7 +65,7 @@ def _check(self, model_id: str, max_length: int):
         truncated_tokenized = tokenizer.apply_chat_template(
             messages, tokenize=True, return_dict=True, return_length=True
         )
-        truncated_length = truncated_tokenized["length"][0]
+        truncated_length = len(truncated_tokenized.data["input_ids"])

         self.assertLess(truncated_length, original_length)
         self.assertLess(truncated_length, max_length)
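For reference, a hedged end-to-end sketch of exercising the truncator with the updated length measurement; the model id, the example strings, and the message layout are assumptions for illustration and are not taken from the repository.

from transformers import AutoTokenizer

from juddges.data.datasets.context_truncator import ContextTruncator

# Placeholder model and inputs, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
max_length = 128

prompt = "Summarise the ruling."
context = "Very long court document text. " * 200
output = "A short summary."

# Truncate the context so the full conversation fits within max_length tokens.
truncated_context = ContextTruncator(tokenizer, max_length)(prompt, context, output)

# Re-tokenize a conversation built from the truncated context (layout assumed here)
# and measure its length from the token ids, as the updated test does.
messages = [
    {"role": "user", "content": f"{prompt}\n{truncated_context}"},
    {"role": "assistant", "content": output},
]
encoded = tokenizer.apply_chat_template(messages, tokenize=True, return_dict=True)
print(len(encoded["input_ids"]), max_length)  # the test asserts the first stays below the second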
