advancing to scalars gen2 #374

Merged: 7 commits, Oct 10, 2024
Changes from 5 commits
156 changes: 86 additions & 70 deletions fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -28,17 +28,15 @@ class InjectorToModularTokenizerLib:
supported syntax/format:

for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
',' separated float values and/or <MASK> tokens -
for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
',' separated float values. For example: "2.7,3.99,-12.9"

the text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
note: with SCALARS_FROM_DICT you can't describe masked scalars (outputs); you can only describe inputs

example usage:

encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
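
To make the hint syntax concrete, here is a short editor-added sketch (not part of the diff) showing how such a query decomposes into alternating tokenizer hints and sub-sequences; it mirrors the re.split call used in build_placeholder_meta_tokenization below:

import re

# Illustrative query (shortened AA sequence), following the format documented above.
query = (
    "<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT>"
    "<@TOKENIZER-TYPE=SCALARS_LITERALS>0.3"
    "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGG<SEQUENCE_NATURAL_END>"
)
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", query)[1:]  # drop the leading empty string
print(hints_and_subseq)
# ['AA', '<MOLECULAR_WEIGHT_IN_SOME_UNIT>',
#  'SCALARS_LITERALS', '0.3',
#  'AA', '<SEQUENCE_NATURAL_START>ISGG<SEQUENCE_NATURAL_END>']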

@@ -67,16 +65,15 @@ def build_placeholder_meta_tokenization(
)
if len(sequence) > 0:
if isinstance(sequence[0], TypedInput):
sequence_str = list_to_tokenizer_string(
sequence = list_to_tokenizer_string(
sequence
) # currently supporting it in this simple way. Consider optimizing if it causes a bottleneck.
else:
raise Exception(
f"Expected sequence to be either string or a list of TypedInput elements. Got a list, but the first element is of type {type(sequence[0])}"
)
else:
sequence_str = sequence
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence_str)[

hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
1:
] # the first element is blank - removing it
assert (
@@ -91,19 +88,18 @@
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
Collaborator:

The general solution can't assume there is an AA sub-tokenizer.
Maybe we need a default empty sub-tokenizer? Maybe SCALARS can be an empty sub-tokenizer?

Collaborator Author:

Interesting point.
SCALARS is currently fully programmatic and does not rely on any dictionary, so I would rather not mix it.
Probably better to have a "base" sub-tokenizer that gets automatically generated and supported, as the modular tokenizer already knows how to handle special tokens.

Maybe "Base" or "SpecialTokensBase" or something.

) # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>
) # AA tokenizer selection is arbitrary; we only take the special token <SCALAR> from it

if (
tokenizer_type == "SCALARS_LITERALS"
): # note: masking is only supported in literals (not in "from dict")
if tokenizer_type == "SCALARS_LITERALS":
Collaborator:

Remind me, can we put a mask in scalar literals?

Collaborator Author:

I stopped supporting this option intentionally.
Now scalars only support actual scalar values.

values = subseq.split(",")
# seq = "<SCALAR>" * len(values)
seq = "".join(
[
"<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
for x in values
]
)
# validate that all values can be converted to float
try:
[float(x) for x in values]
except:
raise ValueError(
f'expected a string with "," separated values that can each be converted to float. Got {subseq}'
)
seq = "<SCALAR>" * len(values)
elif tokenizer_type == "SCALARS_FROM_DICT":
if sample_dict is None:
raise Exception(
@@ -126,10 +122,11 @@ def build_placeholder_meta_tokenization(
return "".join(with_placeholders), hints_and_subseq

@staticmethod
def prepare_info_for_model_step(
def build_scalars(
*,
Collaborator:

Maybe build_scalars would be a better name here?

Collaborator Author:

No problem, renamed.

per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
token_ids: List[int],
sample_dict: Optional[NDict] = None,
) -> Dict:
"""
@@ -147,10 +144,13 @@ def prepare_info_for_model_step(


"""
scalars_indices = []
scalars_values = []
scalars_masked_indices = []
prev_index_end = -1
## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatenated at the end of this function

# one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
all_scalars_values = []
# for each element, whether it's a scalar or not
all_scalars_valid_mask = []
scalar_default_unfound_value = -1000.0
Collaborator:

What does this mean?

Collaborator Author:

Since we now keep the scalar values and mask at the size of the entire sequence, this is the default value for positions that don't actually have a scalar value.
I chose -1000 and not something like 0 to make sure it pops up easily if there are mistakes down the road.
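
A toy illustration (editor-added; length and numbers invented) of the layout described in this comment:

import torch

# 6-element query where only positions 2 and 3 carry scalars (0.3 and 12.4);
# every other position gets the default value and a False mask entry.
scalar_default_unfound_value = -1000.0
scalars_values = torch.tensor(
    [-1000.0, -1000.0, 0.3, 12.4, -1000.0, -1000.0], dtype=torch.float32
)
scalars_valid_mask = torch.tensor([False, False, True, True, False, False])

# Downstream code can recover just the real scalars with boolean indexing:
print(scalars_values[scalars_valid_mask])  # tensor([ 0.3000, 12.4000])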


for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
@@ -165,42 +165,30 @@ def prepare_info_for_model_step(
f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
)

curr_indices = []
curr_data = []

for i, val in enumerate(curr_str_data):
if val != "<MASK>":
curr_indices.append(i + prev_index_end + 1)
curr_data.append(float(val))
else:
scalars_masked_indices.append(i + prev_index_end + 1)

if len(curr_indices) > 0:
curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
curr_data = torch.tensor(curr_data, dtype=torch.float32)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

assert len(curr_data.shape) == 1

prev_index_end += len(curr_str_data)
curr_scalar_values = [float(val) for val in curr_str_data]
curr_scalar_values = torch.tensor(
curr_scalar_values, dtype=torch.float32
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_data = sample_dict[curr_str_data]
assert len(curr_data.shape) == 1
curr_indices = torch.arange(
prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
curr_scalar_values = sample_dict[curr_str_data]
assert len(curr_scalar_values.shape) == 1
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=True, dtype=torch.bool
)
)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

prev_index_end += curr_data.shape[0]

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
@@ -209,24 +197,52 @@
elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
else:
prev_index_end += len(curr_placeholder_encoding.ids)

if len(scalars_indices) > 0:
scalars_indices = torch.concat(scalars_indices)
scalars_values = torch.concat(scalars_values)
else:
scalars_indices = None
scalars_values = None

if len(scalars_masked_indices) > 0:
scalars_masked_indices = torch.tensor(
scalars_masked_indices, dtype=torch.int64
# prev_index_end += len(curr_placeholder_encoding.ids)
curr_scalar_values = torch.full(
(len(curr_placeholder_encoding.ids),),
fill_value=scalar_default_unfound_value,
)
all_scalars_values.append(curr_scalar_values)
all_scalars_valid_mask.append(
torch.full_like(
curr_scalar_values, fill_value=False, dtype=torch.bool
)
)

all_scalars_values = torch.concat(all_scalars_values)
all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)

assert all_scalars_values.shape == all_scalars_valid_mask.shape

# pad if needed
full_query_len = len(token_ids)
if full_query_len > all_scalars_values.shape[0]:
Collaborator:

Can't it happen?

Collaborator Author:

Do you mean "can it happen"?
If so, then yes, almost always.

The main code logic (before this point) iterates over each sub-part (with its specific sub-tokenizer), so the result does not contain the padding.

I can explain more if it isn't clear
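
A hypothetical size breakdown (editor-added; numbers invented) that makes the comment concrete:

full_query_len = 128             # len(token_ids): the tokenizer has already padded the query
concatenated_scalars_len = 87    # concatenation over the sub-parts, which never includes padding
pad_len = full_query_len - concatenated_scalars_len  # 41 positions filled with the default value / False mask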

pad_len = full_query_len - all_scalars_values.shape[0]
all_scalars_values = torch.concat(
[
all_scalars_values,
torch.full(
(pad_len,),
fill_value=scalar_default_unfound_value,
dtype=all_scalars_values.dtype,
),
]
)
all_scalars_valid_mask = torch.concat(
[
all_scalars_valid_mask,
torch.full(
(pad_len,), fill_value=False, dtype=all_scalars_valid_mask.dtype
),
]
)
else:
scalars_masked_indices = None
elif full_query_len < all_scalars_values.shape[0]:
Collaborator Author:

@mosheraboh see here, related to what we talked about.

I'll try to add more unit tests with interesting cases by the end of this week

print(f'warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]}; after cropping it is {full_query_len}')
all_scalars_values = all_scalars_values[:full_query_len]
all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]


return {
"scalars_indices": scalars_indices, # 1d - its length is the number of actual scalars (provided) found
"scalars_values": scalars_values, # 1d - values of provided scalars
"scalars_masked_indices": scalars_masked_indices, # 1d - indices of masked scalars
"scalars_values": all_scalars_values, # 1d - its length is the number of actual scalars (provided) found
"scalars_valid_mask": all_scalars_valid_mask, # 1d - values of provided scalars
}
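
For context, a hedged sketch (editor-added, not from this repository) of how a model might consume the two returned tensors, e.g. by overwriting the embeddings of scalar placeholder tokens with a projection of the scalar values; the projection layer, shapes, and numbers are assumptions for illustration only:

import torch
import torch.nn as nn

hidden_dim = 8
scalar_proj = nn.Linear(1, hidden_dim)          # hypothetical projection of a scalar to the hidden size
token_embeddings = torch.randn(5, hidden_dim)   # [seq_len, hidden], as if produced by an embedding layer

scalars_values = torch.tensor([-1000.0, 0.3, -1000.0, 12.4, -1000.0])
scalars_valid_mask = torch.tensor([False, True, False, True, False])

# Project only the valid scalars and write them into the corresponding positions.
scalar_embeds = scalar_proj(scalars_values[scalars_valid_mask].unsqueeze(-1))  # [num_scalars, hidden]
token_embeddings = token_embeddings.clone()
token_embeddings[scalars_valid_mask] = scalar_embeds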
46 changes: 12 additions & 34 deletions fuse/data/tokenizers/modular_tokenizer/op.py
@@ -372,8 +372,7 @@ class ModularTokenizerOp(ModularTokenizerWithoutInjectOp):
supported syntax/format:

for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
',' separated float values and/or <MASK> tokens -
for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
',' separated float values

the text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict
for example: "blah.boo.banana" or "data.input.encoder_input"
@@ -437,9 +436,7 @@ def __call__(
on_unknown: Optional[str] = "warn",
verbose: Optional[int] = 1,
validate_ends_with_eos: Optional[bool] = None,
key_out_scalars_indices: Optional[str] = None,
key_out_scalars_values: Optional[str] = None,
key_out_masked_scalars_indices: Optional[str] = None,
key_out_scalars: Optional[str] = None,
) -> NDict:
"""_summary_

@@ -458,10 +455,10 @@
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
key_out_scalars_inputs_indices:str optional
if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar elements.
key_out_scalars_inputs_values:str optional
if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar values.
key_out_scalars:str optional
if provided, will write to:
`sample_dict[f'{key_out_scalars}.values']` - a 1D torch tensor with all the scalar values
`sample_dict[f'{key_out_scalars}.valid_mask']` - a 1D torch boolean tensor representing which elements have scalar values

Returns:
NDict: _description_
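
A minimal usage sketch of the key_out_scalars behavior documented above (editor-added): it assumes tokenizer_op is an already-constructed ModularTokenizerOp, the key names are illustrative only, and any arguments not shown in this diff are omitted:

# Hypothetical key names; only arguments visible in this diff are used.
sample_dict = tokenizer_op(
    sample_dict,
    key_in="data.query.encoder_input",
    key_out_tokens_ids="data.encoder_input_token_ids",
    key_out_scalars="data.encoder_scalars",
)
scalars_values = sample_dict["data.encoder_scalars.values"]          # 1D float tensor, length == number of token ids
scalars_valid_mask = sample_dict["data.encoder_scalars.valid_mask"]  # 1D bool tensor, same length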
@@ -490,39 +487,20 @@ def __call__(
+ ".per_meta_part_encoding", # using the key_in as base for the name because key_out_* are optional
)

prepared_data = InjectorToModularTokenizerLib.prepare_info_for_model_step(
prepared_data = InjectorToModularTokenizerLib.build_scalars(
per_meta_tokenizer_data=per_meta_orig,
per_meta_encoding_including_placeholders=sample_dict[
key_in + ".per_meta_part_encoding"
],
token_ids=sample_dict[key_out_tokens_ids],
sample_dict=sample_dict,
)

if key_out_scalars_indices is not None:
sample_dict[key_out_scalars_indices] = prepared_data["scalars_indices"]
else:
if prepared_data["scalars_indices"] is not None:
raise Exception(
"non None scalars_indices found but no key_out_scalars_indices found"
)

if key_out_scalars_values is not None:
sample_dict[key_out_scalars_values] = prepared_data["scalars_values"]
else:
if prepared_data["scalars_values"] is not None:
raise Exception(
"non None scalars_value found but no key_out_scalars_values found"
)

if key_out_masked_scalars_indices is not None:
sample_dict[key_out_masked_scalars_indices] = prepared_data[
"scalars_masked_indices"
if key_out_scalars is not None:
sample_dict[key_out_scalars + ".values"] = prepared_data["scalars_values"]
sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
"scalars_valid_mask"
]
else:
if prepared_data["scalars_masked_indices"] is not None:
raise Exception(
"non None scalars_masked_indices found but no key_out_masked_scalars_indices found"
)

return sample_dict
