From e0bff48acaeab739e8dbc7418be07dde3672abf3 Mon Sep 17 00:00:00 2001
From: mrjleo
Date: Thu, 12 Dec 2024 15:10:47 +0100
Subject: [PATCH] make encoder attributes private

---
 src/fast_forward/encoder.py | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/fast_forward/encoder.py b/src/fast_forward/encoder.py
index b1b487b..eb26b3c 100644
--- a/src/fast_forward/encoder.py
+++ b/src/fast_forward/encoder.py
@@ -41,18 +41,20 @@ def __init__(
         :param **tokenizer_args: Additional tokenizer arguments.
         """
         super().__init__()
-        self.model = AutoModel.from_pretrained(model)
-        self.model.to(device)
-        self.model.eval()
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
-        self.device = device
-        self.tokenizer_args = tokenizer_args
+        self._model = AutoModel.from_pretrained(model)
+        self._model.to(device)
+        self._model.eval()
+        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._device = device
+        self._tokenizer_args = tokenizer_args
 
     def _encode(self, texts: "Sequence[str]") -> np.ndarray:
-        inputs = self.tokenizer(list(texts), return_tensors="pt", **self.tokenizer_args)
-        inputs.to(self.device)
+        inputs = self._tokenizer(
+            list(texts), return_tensors="pt", **self._tokenizer_args
+        )
+        inputs.to(self._device)
         with torch.no_grad():
-            return self.model(**inputs).pooler_output.detach().cpu().numpy()
+            return self._model(**inputs).pooler_output.detach().cpu().numpy()
 
 
 class LambdaEncoder(Encoder):
@@ -79,17 +81,17 @@ class TCTColBERTQueryEncoder(TransformerEncoder):
 
     def _encode(self, texts: "Sequence[str]") -> np.ndarray:
         max_length = 36
-        inputs = self.tokenizer(
+        inputs = self._tokenizer(
             ["[CLS] [Q] " + q + "[MASK]" * max_length for q in texts],
             max_length=max_length,
             truncation=True,
             add_special_tokens=False,
             return_tensors="pt",
-            **self.tokenizer_args,
+            **self._tokenizer_args,
         )
-        inputs.to(self.device)
+        inputs.to(self._device)
         with torch.no_grad():
-            embeddings = self.model(**inputs).last_hidden_state.detach().cpu().numpy()
+            embeddings = self._model(**inputs).last_hidden_state.detach().cpu().numpy()
         return np.average(embeddings[:, 4:, :], axis=-2)
 
 
@@ -102,18 +104,18 @@ class TCTColBERTDocumentEncoder(TransformerEncoder):
 
     def _encode(self, texts: "Sequence[str]") -> np.ndarray:
         max_length = 512
-        inputs = self.tokenizer(
+        inputs = self._tokenizer(
             ["[CLS] [D] " + text for text in texts],
             max_length=max_length,
             padding=True,
             truncation=True,
             add_special_tokens=False,
             return_tensors="pt",
-            **self.tokenizer_args,
+            **self._tokenizer_args,
         )
-        inputs.to(self.device)
+        inputs.to(self._device)
         with torch.no_grad():
-            outputs = self.model(**inputs)
+            outputs = self._model(**inputs)
             token_embeddings = outputs["last_hidden_state"][:, 4:, :]
             input_mask_expanded = (
                 inputs.attention_mask[:, 4:]
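
For downstream code that subclasses `TransformerEncoder` and overrides `_encode`, this rename is a breaking change: such subclasses must switch from `self.model`/`self.tokenizer`/`self.device`/`self.tokenizer_args` to the underscore-prefixed names, mirroring the pattern used by the in-tree TCT-ColBERT encoders above. A minimal sketch of an affected subclass (the class itself and its CLS-pooling strategy are hypothetical, chosen only to illustrate the renamed attributes):

```python
import numpy as np
import torch

from fast_forward.encoder import TransformerEncoder


class CLSEncoder(TransformerEncoder):
    """Hypothetical subclass illustrating the renamed private attributes."""

    def _encode(self, texts) -> np.ndarray:
        # After this patch, subclasses access the tokenizer, device, and
        # model through the underscore-prefixed attributes.
        inputs = self._tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            return_tensors="pt",
            **self._tokenizer_args,
        )
        inputs.to(self._device)
        with torch.no_grad():
            # Pool the [CLS] token embedding instead of pooler_output.
            hidden = self._model(**inputs).last_hidden_state
        return hidden[:, 0, :].cpu().numpy()
```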
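
Callers that only encode texts are unaffected, since the public encoding entry point does not touch these attributes. A usage sketch, assuming the base `Encoder` dispatches to the private `_encode` hook via `__call__` (the checkpoint name below is only an example of a compatible Hugging Face model):

```python
from fast_forward.encoder import TCTColBERTQueryEncoder

# Example checkpoint; any compatible Hugging Face model name works here.
encoder = TCTColBERTQueryEncoder("castorini/tct_colbert-msmarco", device="cpu")

# Callers never read _model, _tokenizer, or _device directly.
vectors = encoder(["what is dense retrieval?"])
print(vectors.shape)  # (1, hidden_size)
```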