diff --git a/model2vec/model.py b/model2vec/model.py
index f6d9f65..44e146b 100644
--- a/model2vec/model.py
+++ b/model2vec/model.py
@@ -56,6 +56,7 @@ def __init__(
         else:
             self.unk_token_id = None
 
+        self.median_token_length = int(np.median([len(token) for token in self.tokens]))
         self.config = config
         self.base_model_name = base_model_name
         self.language = language
@@ -123,6 +124,10 @@ def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple
         :param max_length: The maximum length of the sentence.
         :return: The tokens.
         """
+        if max_length is not None:
+            m = max_length * self.median_token_length
+            sentences = [sentence[:m] for sentence in sentences]
+
         encodings: list[Encoding] = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
         encodings_ids = [encoding.ids for encoding in encodings]
 