Skip to content

Commit

Permalink
make encoder attributes private
Browse files (browse the repository at this point in the history)
  • Loading branch information
mrjleo committed Dec 12, 2024
1 parent 9857799 commit e0bff48
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/fast_forward/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,20 @@ def __init__(
:param **tokenizer_args: Additional tokenizer arguments.
"""
super().__init__()
- self.model = AutoModel.from_pretrained(model)
- self.model.to(device)
- self.model.eval()
- self.tokenizer = AutoTokenizer.from_pretrained(model)
- self.device = device
- self.tokenizer_args = tokenizer_args
+ self._model = AutoModel.from_pretrained(model)
+ self._model.to(device)
+ self._model.eval()
+ self._tokenizer = AutoTokenizer.from_pretrained(model)
+ self._device = device
+ self._tokenizer_args = tokenizer_args

def _encode(self, texts: "Sequence[str]") -> np.ndarray:
- inputs = self.tokenizer(list(texts), return_tensors="pt", **self.tokenizer_args)
- inputs.to(self.device)
+ inputs = self._tokenizer(
+ list(texts), return_tensors="pt", **self._tokenizer_args
+ )
+ inputs.to(self._device)
with torch.no_grad():
- return self.model(**inputs).pooler_output.detach().cpu().numpy()
+ return self._model(**inputs).pooler_output.detach().cpu().numpy()


class LambdaEncoder(Encoder):
Expand All @@ -79,17 +81,17 @@ class TCTColBERTQueryEncoder(TransformerEncoder):

def _encode(self, texts: "Sequence[str]") -> np.ndarray:
max_length = 36
- inputs = self.tokenizer(
+ inputs = self._tokenizer(
["[CLS] [Q] " + q + "[MASK]" * max_length for q in texts],
max_length=max_length,
truncation=True,
add_special_tokens=False,
return_tensors="pt",
- **self.tokenizer_args,
+ **self._tokenizer_args,
)
- inputs.to(self.device)
+ inputs.to(self._device)
with torch.no_grad():
- embeddings = self.model(**inputs).last_hidden_state.detach().cpu().numpy()
+ embeddings = self._model(**inputs).last_hidden_state.detach().cpu().numpy()
return np.average(embeddings[:, 4:, :], axis=-2)


Expand All @@ -102,18 +104,18 @@ class TCTColBERTDocumentEncoder(TransformerEncoder):

def _encode(self, texts: "Sequence[str]") -> np.ndarray:
max_length = 512
- inputs = self.tokenizer(
+ inputs = self._tokenizer(
["[CLS] [D] " + text for text in texts],
max_length=max_length,
padding=True,
truncation=True,
add_special_tokens=False,
return_tensors="pt",
- **self.tokenizer_args,
+ **self._tokenizer_args,
)
- inputs.to(self.device)
+ inputs.to(self._device)
with torch.no_grad():
- outputs = self.model(**inputs)
+ outputs = self._model(**inputs)
token_embeddings = outputs["last_hidden_state"][:, 4:, :]
input_mask_expanded = (
inputs.attention_mask[:, 4:]
Expand Down

0 comments on commit e0bff48

Please sign in to comment.