feat: push synth qa dataset to hub, add tokenizer util caching
jamnicki committed Jun 3, 2024 · 1 parent 59cbcbc · commit 4302ae9
Showing 4 changed files with 230 additions and 189 deletions.
(Two of the changed files could not be rendered in the diff view.)
juddges/data/utils.py (24 changes: 19 additions & 5 deletions)
@@ -7,7 +7,13 @@
 from joblib import Memory
 from jsonlines import jsonlines
 from tqdm.auto import tqdm
-from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import (
+    AutoModel,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
 
 from juddges.settings import CACHE_DIR

@@ -36,7 +42,9 @@ def path_safe_udate() -> str:
 
 
 @util_memory.cache(ignore=["model"])
-def _model_call(model: torch.nn.Module, model_name: str, *args, **kwargs) -> Any:
+def _model_call(
+    model: torch.nn.Module | PreTrainedTokenizer | PreTrainedModel, model_name: str, *args, **kwargs
+) -> Any:
     return model(*args, **kwargs)

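The core of the change is this cached wrapper. Judging from the imports above, util_memory is a joblib.Memory store rooted at CACHE_DIR; ignore=["model"] keeps the live model or tokenizer object out of the cache key, so model_name must stand in as its identity. A minimal, self-contained sketch of the pattern (the cache location and Memory arguments here are assumptions, not taken from the repository):

```python
from pathlib import Path
from typing import Any

from joblib import Memory

CACHE_DIR = Path("cache")  # stand-in for juddges.settings.CACHE_DIR
util_memory = Memory(CACHE_DIR / "utils", verbose=0)  # assumed setup


@util_memory.cache(ignore=["model"])
def _model_call(model: Any, model_name: str, *args: Any, **kwargs: Any) -> Any:
    # joblib hashes model_name, *args and **kwargs into the cache key;
    # the model object itself is excluded via ignore=, since it is large
    # and not reliably hashable. model_name must therefore uniquely
    # identify the model, or different models would share cache entries.
    return model(*args, **kwargs)
```

The first call with a given model_name and argument set runs the model and pickles the result to disk; identical later calls are read back without touching the model. The trade-off is that the result must be picklable, and a stale model_name would silently serve another model's outputs.
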
@@ -72,7 +80,9 @@ def get_texts_embeddings(
 
     for text in batch_texts:
         # Tokenize text and split into chunks
-        tokens = tokenizer.encode(text, add_special_tokens=True, truncation=False)
+        tokens = _model_call(
+            tokenizer, model_name, text, add_special_tokens=True, truncation=False
+        ).input_ids
         chunks = [tokens[j : j + max_length] for j in range(0, len(tokens), max_length)]
 
         # Initialize tensor to accumulate embeddings
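
For context, the embedding path above deliberately tokenizes with truncation=False and windows the token ids itself, so documents longer than the model's limit are embedded chunk by chunk. A standalone sketch of that chunking step (the checkpoint name is illustrative, not taken from the repository):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative
max_length = 512

text = "a very long court judgment " * 300
# Tokenize the whole document in one pass, without truncation...
tokens = tokenizer(text, add_special_tokens=True, truncation=False).input_ids
# ...then slice the ids into fixed-size windows the model can accept.
chunks = [tokens[j : j + max_length] for j in range(0, len(tokens), max_length)]
# Each window is embedded separately; the per-chunk embeddings are then
# accumulated into one document vector (see the context lines above).
```
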
@@ -136,7 +146,9 @@ def get_texts_sentiment(
         chunks = [text[j : j + max_length] for j in range(0, len(text), max_length)]
         chunk_scores = []
         for chunk in chunks:
-            encoding = sent_tokenizer(
+            encoding = _model_call(
+                sent_tokenizer,
+                model_name,
                 [chunk],
                 add_special_tokens=True,
                 return_token_type_ids=True,
@@ -191,7 +203,9 @@ def get_texts_formality(
         chunks = [text[j : j + max_length] for j in range(0, len(text), max_length)]
         chunk_scores = []
         for chunk in chunks:
-            encoding = formality_tokenizer(
+            encoding = _model_call(
+                formality_tokenizer,
+                model_name,
                 [chunk],
                 add_special_tokens=True,
                 return_token_type_ids=True,
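
The sentiment and formality hunks apply the same substitution: the direct tokenizer call becomes a call through the cached helper, with model_name as the cache identity. A hedged usage sketch; the checkpoint name is hypothetical and _model_call is the wrapper defined above:

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint; the repository's sentiment model is not shown here.
model_name = "nlptown/bert-base-multilingual-uncased-sentiment"
sent_tokenizer = AutoTokenizer.from_pretrained(model_name)

encoding = _model_call(
    sent_tokenizer,
    model_name,
    ["Some document chunk to score."],
    add_special_tokens=True,
    return_token_type_ids=True,
)
# The first call tokenizes and persists the output under CACHE_DIR;
# identical later calls are served from disk without running the tokenizer.
```
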