Skip to content

Commit

Permalink
fix forward call + xlmr
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Jun 19, 2024
1 parent cc4871f commit 08a6ed9
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __call__(self, hashed_ids, attention_mask):

return {"logits": logits}


class SaTORTWrapper:
def __init__(self, config, ort_session):
self.config = config
Expand All @@ -43,11 +44,8 @@ def __getattr__(self, name):

def __call__(self, input_ids, attention_mask):
logits = self.ort_session.run(
["logits"],
{
"attention_mask": attention_mask.astype(np.int64),
"input_ids": input_ids.astype(np.int64)
},
output_names=["logits"],
input_feed={"attention_mask": attention_mask.astype(np.int64), "input_ids": input_ids.astype(np.int64)},
)[0]

return {"logits": logits}
Expand All @@ -62,7 +60,7 @@ def __getattr__(self, name):
assert hasattr(self, "model")
return getattr(self.model, name)

def __call__(self, hashed_ids, attention_mask, language_ids=None, input_ids=None):
def __call__(self, attention_mask, hashed_ids=None, language_ids=None, input_ids=None):
try:
import torch
except ImportError:
Expand Down Expand Up @@ -106,9 +104,9 @@ def extract(
if "xlm" in model.config.model_type:
use_subwords = True
tokenizer = AutoTokenizer.from_pretrained(
model.config.base_model if hasattr(model.config, "base_model") else model.config._name_or_path
"facebookAI/xlm-roberta-base",
)
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
# tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False)
# remove CLS and SEP tokens, they are added later anyhow
batch_of_texts = [text[1:-1] for text in tokens["input_ids"]]
Expand Down

0 comments on commit 08a6ed9

Please sign in to comment.