Skip to content

Commit

Permalink
done
Browse files Browse the repository at this point in the history
  • Loading branch information
Sagi Polaczek committed Oct 10, 2024
1 parent 1626ae1 commit 5b237b4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
11 changes: 0 additions & 11 deletions fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,12 @@ def add_single_tokenizer(
# we update the special tokens but do not save here. remember to save yourself.
self.update_special_tokens(
special_tokens=new_tokenize_special_tokens,
# save_tokenizer_path=self.cfg_raw["data"]["tokenizer"]["out_path"],
)

def add_tokenizers(
self,
) -> None:
raise Exception("Not implemented")
# self.build_inner_decoder()
# if self._max_possible_token_id is not None:
# if self._get_max_mapped_id() > self._max_possible_token_id:
# raise Exception(
# f"tokenizer remapping resulted in IDs greater (max_id={self._get_max_mapped_id()}) than max_possible_id ({self._max_possible_token_id}). Reinitialize the modular tokenizer with larger max_possible_id"
# )

def _encode_single_type(
self,
Expand Down Expand Up @@ -1059,10 +1052,6 @@ def encode_list(
merged_encoding = Encoding.merge(encoded_list)

max_len = self.get_expected_max_len(override_max_len=max_len)
# if max_len is None:
# if self.max_len is not None:
# max_len = self.max_len

if max_len is not None:
if len(merged_encoding) > max_len:
overflow_info += f"OVERALL:{len(merged_encoding)}=>{max_len}|"
Expand Down
12 changes: 9 additions & 3 deletions fuse/data/tokenizers/modular_tokenizer/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
validate_ends_with_eos: Optional[bool] = True,
eos: Optional[str] = "<EOS>",
verbose: Optional[bool] = False,
on_unknown_default_value: str = "warn",
**kwargs: Any,
) -> None:
"""
Expand All @@ -41,6 +42,7 @@ def __init__(
validate_ends_with_eos: during encoder request (a _call_ to the op) will make sure that it ends with the provided eos token, and raise exception otherwise.
having an eos (end of sentence) token in the end is useful for multiple scenarios, for example in a generative transformer (like T5 encoder-decoder)
verbose:
on_unknown_default_value: default behavior when unknown tokens are encountered ('raise' or 'warn'). It is used whenever __call__ is invoked with on_unknown=None, and can be overridden per call through the on_unknown argument of __call__.
"""
super().__init__(**kwargs)

Expand All @@ -60,6 +62,7 @@ def __init__(

self._validate_ends_with_eos = validate_ends_with_eos
self._eos = eos
self._on_unknown_default_value = on_unknown_default_value

if self._validate_ends_with_eos:
eos_id = self._tokenizer.token_to_id(self._eos)
Expand Down Expand Up @@ -211,7 +214,7 @@ def __call__(
key_out_attention_mask: Optional[str] = None,
convert_attention_mask_to_bool: Optional[bool] = True,
max_seq_len: Optional[int] = None,
on_unknown: Optional[str] = "warn",
on_unknown: Optional[str] = None,
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
additional_caller_info_text: Optional[str] = "",
Expand All @@ -230,7 +233,7 @@ def __call__(
key_out_attention_mask (Optional[str], optional): _description_. Defaults to None.
convert_attention_mask_to_bool (Optional[bool], optional): _description_. Defaults to True.
max_seq_len (Optional[int], optional): set maximum sequence len dynamically, used for both padding and truncation. Defaults to None.
on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to "warn".
on_unknown (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'. Defaults to None, in which case the on_unknown_default_value given to the constructor (itself defaulting to "warn") is used.
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
Expand All @@ -243,7 +246,6 @@ def __call__(
Returns:
NDict: _description_
"""

data = sample_dict[key_in]
if not isinstance(data, (list, str)):
# data is a list of named tuples of type collections.namedtuple("TypedInput", ["input_type", "input_string", "max_len"])
Expand All @@ -263,6 +265,10 @@ def __call__(
f"validate_ends_with_eos was set to {validate_ends_with_eos}, but about to encode a string that does not end with {self._eos}. The str end was: {last_seq}"
)

if on_unknown is None:
# Use the tokenizer instance's default value
on_unknown = self._on_unknown_default_value

if isinstance(data, str):
_ans = self._tokenizer.encode(
data,
Expand Down

0 comments on commit 5b237b4

Please sign in to comment.