Add support for __all__ and potentially deleting functions (huggingface#33859)

* Add support for __all__ and potentially deleting functions

* updates

* update

* nits

* remove dummies

* fix warning

* fixup

* style

* update

* fixup

* skip copied from when # skip

* remove log

* bring dummies back

* fixup

* remove copied from

* fixup

* remove warnings from `make fix-copies`

* fix doc issues

* nits

* Better error message !

* add support for more flexible naming!

* style

* breaking style?

* fix super() renaming issues

* del not needed when you don't call super().__init__()

* style

* no more fmt on :)

* properly remove `self`

* fixup

* fix

* doc nits

* add some doc 🫡
ArthurZucker authored Oct 8, 2024
1 parent bead0fa commit a3add29
Showing 15 changed files with 477 additions and 149 deletions.
58 changes: 57 additions & 1 deletion docs/source/en/modular_transformers.md
@@ -118,4 +118,60 @@ Additionally, you may find a list of examples here:
## What it is not
It is not a replacement for the modeling code (yet?), and if your model is not based on anything else that ever existed, then you can add a `modeling` file as usual.
## Advanced usage
### Removing attributes and functions
To remove attributes that are not used in your modular model and that you don't want to see in the unravelled modeling code:

```python
class GemmaModel(LlamaModel):                 | class GemmaModel(PreTrainedModel):
    def __init__(self, config):              |     def __init__(self, config):
        super().__init__(config)             |         super().__init__(config)
        del self.embed_tokens                |         self.padding_idx = config.pad_token_id
                                             |         self.vocab_size = config.vocab_size
                                             |
                                             |         self.layers = nn.ModuleList(
                                             |             [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
                                             |         )
                                             |         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
                                             |         self.rotary_emb = LlamaRotaryEmbedding(config=config)
                                             |         self.gradient_checkpointing = False
                                             |
                                             |         # Initialize weights and apply final processing
                                             |         self.post_init()
```
If you check the original `LlamaModel`, it has an `embed_tokens` attribute, which was removed here (as you would expect!).
Removing a function is pretty similar: you just need to write it with a `raise AttributeError("")` body to mimic the behaviour you actually want when you remove a parent function in Python.
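For intuition, here is a minimal standalone sketch of what `del self.embed_tokens` expresses in plain PyTorch (the `Parent`/`Child` names are hypothetical and not part of the library):
```python
import torch.nn as nn


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)
        self.norm = nn.LayerNorm(4)


class Child(Parent):
    def __init__(self):
        super().__init__()
        # Drop the submodule the parent registered; the generated class simply never defines it.
        del self.embed_tokens


print(hasattr(Child(), "embed_tokens"))  # False
print(hasattr(Child(), "norm"))          # True
```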
```python
class GemmaTokenizer(LlamaTokenizer):
    ...

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")
```
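As a rough standalone sketch of the behaviour such a stub mimics (hypothetical `Slow`/`Trimmed` classes, not library code), calling the "removed" method fails with `AttributeError`, just as if the parent had never provided it:
```python
class Slow:
    def get_spm_processor(self):
        return "spm processor"


class Trimmed(Slow):
    def get_spm_processor(self):
        # Overriding with a raise is the closest Python gets to deleting an inherited method.
        raise AttributeError("Not needed for Trimmed")


try:
    Trimmed().get_spm_processor()
except AttributeError as err:
    print(err)  # Not needed for Trimmed
```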
### Calling `super()`
We recently shipped a few features that allow you to go from:
```python
class GemmaTokenizer(LlamaTokenizer, PretrainedTokenizerFast):     | class GemmaTokenizer(PretrainedTokenizerFast):
    def __init__(self, eos_token="</s>"):                          |     def __init__(self, eos_token="</s>"):
        eos_token = AddedToken(eos_token)                          |         eos_token = AddedToken(eos_token)
        PretrainedTokenizerFast.__init__(self, eos_token)          |         super().__init__(eos_token)
```
This is useful when you **don't** want to unravel the call to `super()` and you want to differentiate which super init call you are making!
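For background, a minimal sketch of the plain-Python pattern this relies on (hypothetical `Slow`/`Fast`/`GemmaLike` classes): naming the parent class explicitly picks one specific initialiser instead of the MRO-resolved `super()`:
```python
class Slow:
    def __init__(self, token):
        self.backend = "slow"
        self.token = token


class Fast:
    def __init__(self, token):
        self.backend = "fast"
        self.token = token


class GemmaLike(Slow, Fast):
    def __init__(self, token="</s>"):
        # Explicitly pick Fast's initialiser; a bare super().__init__() would resolve to Slow first.
        Fast.__init__(self, token)


print(GemmaLike().backend)  # fast
```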
### Special naming
We now also support special cases like
```python
class GemmaVisionModel(CLIPModel):
    pass
```
where the name of your class `GemmaVision` is not the same as the modular prefix `Gemma`. This is super useful for composite models.
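As an illustration only (a hypothetical composite-model layout, not part of this commit), the two prefixes can even borrow from different model families:
```python
from transformers import CLIPModel
from transformers.models.llama.modeling_llama import LlamaModel


class GemmaTextModel(LlamaModel):
    """Text side: reuses the Llama architecture under the `Gemma` prefix."""


class GemmaVisionModel(CLIPModel):
    """Vision side: reuses CLIP under the distinct `GemmaVision` prefix."""
```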
1 change: 0 additions & 1 deletion examples/modular-transformers/modeling_dummy.py
@@ -4,7 +4,6 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_xxx.py file directly. One of our CI enforces this
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

import math
from typing import List, Optional, Tuple, Union

1 change: 0 additions & 1 deletion src/transformers/__init__.py
@@ -115,7 +115,6 @@
"data.metrics": [],
"data.processors": [],
"debug_utils": [],
"deepspeed": [],
"dependency_versions_check": [],
"dependency_versions_table": [],
"dynamic_module_utils": [],
41 changes: 0 additions & 41 deletions src/transformers/deepspeed.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -143,3 +143,6 @@ def __init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["GemmaConfig"]
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -1314,3 +1314,6 @@ def forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]
179 changes: 178 additions & 1 deletion src/transformers/models/gemma/modular_gemma.py
@@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import sentencepiece as spm
import torch
import torch.utils.checkpoint
from torch import nn
@@ -27,6 +28,7 @@
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
@@ -38,6 +40,15 @@
    apply_rotary_pos_emb,
    repeat_kv,
)
from ..llama.tokenization_llama import LlamaTokenizer


if TYPE_CHECKING:
    from ...tokenization_utils_base import TextInput

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"


logger = logging.get_logger(__name__)
@@ -164,6 +175,162 @@ def __init__(
        )


class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add an `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Gemma should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        PreTrainedTokenizer.__init__(
            self,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
        """
        Args:
            text: TextInput
        Simply calls PreTrainedTokenizer's method
        """
        return PreTrainedTokenizer.tokenize(self, text, **kwargs)

    def _tokenize(self, text, **kwargs):
        """
        Args:
            text: TextInput
        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
        """
        return self.sp_model.encode(text, out_type=str)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if skip_special_tokens and ids in self.all_special_ids:
                continue
            if ids in self._added_tokens_decoder:
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        if spaces_between_special_tokens:
            sub_texts = " ".join(sub_texts)
        else:
            sub_texts = "".join(sub_texts)

        return sub_texts.replace(SPIECE_UNDERLINE, " ")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self._added_tokens_encoder:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string


class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
@@ -874,3 +1041,13 @@ def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()


__all__ = [
"GemmaConfig",
"GemmaTokenizer",
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]