Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Aug 28, 2023
1 parent 158f889 commit e23ebb9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/adapters/configuration/adapter_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ class BnConfig(AdapterConfigBase):
Place a trainable gating module beside the added parameter module to control module activation. This is
e.g. used for UniPELT. Defaults to False.
residual_before_ln (:obj:`bool` or :obj:`str`, optional):
If True, take the residual connection around the adapter bottleneck before the layer normalization.
If set to "post_add", take the residual connection around the adapter bottleneck after the previous residual connection.
Only applicable if :obj:`original_ln_before` is True.
If True, take the residual connection around the adapter bottleneck before the layer normalization. If set
to "post_add", take the residual connection around the adapter bottleneck after the previous residual
connection. Only applicable if :obj:`original_ln_before` is True.
adapter_residual_before_ln (:obj:`bool`, optional):
If True, apply the residual connection around the adapter modules before the new layer normalization within
the adapter. Only applicable if :obj:`ln_after` is True and :obj:`is_parallel` is False.
Expand Down
11 changes: 8 additions & 3 deletions src/adapters/models/xmod/mixin_xmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel_
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...context import ForwardContext
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -44,15 +44,20 @@ def hook_after_embeddings(self, hook_fn: Callable):
@ForwardContext.wrap
def forward(self, *args, **kwargs):
if "lang_ids" in kwargs and kwargs["lang_ids"] is not None:
raise ValueError("XmodModel with adapters does not support `lang_ids` as an argument. Use `set_active_adapters` instead.")
raise ValueError(
"XmodModel with adapters does not support `lang_ids` as an argument. Use `set_active_adapters`"
" instead."
)
else:
kwargs["lang_ids"] = 1
return super().forward(*args, **kwargs)

# Override adapter-specific methods in original implementation

def set_default_language(self, language: str):
raise ValueError("`set_default_language` is not implemented for models using `adapters`. Use `set_active_adapters` instead.")
raise ValueError(
"`set_default_language` is not implemented for models using `adapters`. Use `set_active_adapters` instead."
)

def freeze_embeddings_and_language_adapters(self):
"""
Expand Down

0 comments on commit e23ebb9

Please sign in to comment.