Skip to content

Commit

Permalink
feat: add context truncator
Browse files Browse the repository at this point in the history
  • Loading branch information
asawczyn committed Apr 19, 2024
1 parent e0b9b20 commit da03629
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions juddges/data/datasets/context_truncator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import warnings

from tokenizers.implementations import BaseTokenizer


class ContextTruncator:
def __init__(self, tokenizer: BaseTokenizer, max_length: int):
self.tokenizer = tokenizer
self.max_length = max_length

empty_messages = [
{"role": "user", "content": ""},
{"role": "assistant", "content": ""},
]

self.empty_messages_length = self.tokenizer.apply_chat_template(
empty_messages, tokenize=True, return_dict=True, return_length=True
)["length"][0]

def __call__(self, prompt: str, context: str, output: str) -> str:
prompt_length, output_length = self.tokenizer(
[prompt, output], return_length=True, add_special_tokens=False
)["length"]

context_length = (
self.max_length - prompt_length - output_length - self.empty_messages_length
)
if context_length <= 0:
warnings.warn(
f"Context was truncated to 0 tokens. "
f"The prompt and output are too long for the max_length of {self.max_length}."
)
return ""
context_ids = self.tokenizer(
context, max_length=context_length, truncation=True, add_special_tokens=False
)["input_ids"]
return self.tokenizer.decode(context_ids)

0 comments on commit da03629

Please sign in to comment.