From ec8d1875f4bf39387ddd2e9a7a6eba1904c70d8a Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Mon, 7 Oct 2024 08:28:01 -0400
Subject: [PATCH 1/7] advancing to scalars gen2

---
 .../modular_tokenizer/inject_utils.py         | 150 ++++++++++--------
 fuse/data/tokenizers/modular_tokenizer/op.py  |  44 ++---
 2 files changed, 91 insertions(+), 103 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 540e3cff..15568819 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -28,17 +28,15 @@ class InjectorToModularTokenizerLib:
     supported syntax/format:

     for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
-    ',' separated float values and/or <MASK> tokens -
-        for example: "2.7,3.99,-12.9" or "<MASK>" or "2.19,<MASK>,3.19,<MASK>"
+    ',' separated float values. For example: "2.7,3.99,-12.9"

     for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
         for example: "blah.boo.banana" or "data.input.encoder_input"
-    note: in SCALARS_FROM_DICT you can't describe masked scalars (outputs) you can only describe inputs

     example usage:

     encoder_input:
-    <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY
+    <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY
     labels:
     <@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY
@@ -47,7 +45,7 @@ class InjectorToModularTokenizerLib:
     @staticmethod
     def build_placeholder_meta_tokenization(
         *,
-        sequence: Union[str, list, tuple],
+        sequence: str,
         sample_dict: Optional[NDict] = None,
     ) -> Tuple[str, List[str]]:
         """
@@ -67,16 +65,15 @@ def build_placeholder_meta_tokenization(
             )
         if len(sequence) > 0:
             if isinstance(sequence[0], TypedInput):
-                sequence_str = list_to_tokenizer_string(
+                sequence = list_to_tokenizer_string(
                     sequence
                 )  # currently supporting it in this simple way. Consider optimizing if it causes a bottleneck.
             else:
                 raise Exception(
                    f"Expected sequence to be either string or a list of TypedInput elements. Got a list, but the first element is of type {type(sequence[0])}"
                )
-        else:
-            sequence_str = sequence
-        hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence_str)[
+
+        hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
             1:
         ]  # the first element is blank - removing it
         assert (
@@ -91,19 +88,18 @@ def build_placeholder_meta_tokenization(
             if tokenizer_type.startswith("SCALARS_"):
                 with_placeholders.append(
                     "<@TOKENIZER-TYPE=AA>"
-                )  # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>
+                )  # AA tokenizer selection is arbitrary, we only take the special token from it

-                if (
-                    tokenizer_type == "SCALARS_LITERALS"
-                ):  # note: masking is only supported in literals (not in "from dict")
+                if tokenizer_type == "SCALARS_LITERALS":
                    values = subseq.split(",")
-                    # seq = "<SCALAR>" * len(values)
-                    seq = "".join(
-                        [
-                            "<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
-                            for x in values
-                        ]
-                    )
+                    # validate that all values can be converted to fload
+                    try:
+                        [float(x) for x in values]
+                    except:
+                        raise ValueError(
+                            f'expected a string with "," separated values that can each be converted to float. Got {subseq}'
+                        )
+                    seq = "<SCALAR>" * len(values)
                elif tokenizer_type == "SCALARS_FROM_DICT":
                    if sample_dict is None:
                        raise Exception(
@@ -130,6 +126,7 @@ def prepare_info_for_model_step(
         *,
         per_meta_tokenizer_data: List[str],
         per_meta_encoding_including_placeholders: List[Encoding],
+        token_ids: List[int],
         sample_dict: Optional[NDict] = None,
     ) -> Dict:
         """
@@ -147,10 +144,12 @@ def prepare_info_for_model_step(
         """

-        scalars_indices = []
-        scalars_values = []
-        scalars_masked_indices = []
-        prev_index_end = -1
+        ## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatanated in the end of this function
+        all_scalars_values = (
+            []
+        )  # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
+        all_scalars_valid_mask = []  # for each element, whether it's a scalar or not
+        scalar_default_unfound_value = -1000.0

         for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
             per_meta_tokenizer_data[::2],
@@ -165,42 +164,30 @@ def prepare_info_for_model_step(
                        f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
                    )

-                curr_indices = []
-                curr_data = []
-
-                for i, val in enumerate(curr_str_data):
-                    if val != "<MASK>":
-                        curr_indices.append(i + prev_index_end + 1)
-                        curr_data.append(float(val))
-                    else:
-                        scalars_masked_indices.append(i + prev_index_end + 1)
-
-                if len(curr_indices) > 0:
-                    curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
-                    curr_data = torch.tensor(curr_data, dtype=torch.float32)
-
-                    scalars_indices.append(curr_indices)
-                    scalars_values.append(curr_data)
-
-                    assert len(curr_data.shape) == 1
-
-                prev_index_end += len(curr_str_data)
+                curr_scalar_values = [float(val) for val in curr_str_data]
+                curr_scalar_values = torch.tensor(
+                    curr_scalar_values, dtype=torch.float32
+                )
+                all_scalars_values.append(curr_scalar_values)
+                all_scalars_valid_mask.append(
+                    torch.full_like(
+                        curr_scalar_values, fill_value=True, dtype=torch.bool
+                    )
+                )
            elif "SCALARS_FROM_DICT" == tokenizer_name:
                if sample_dict is None:
                    raise Exception(
                        "SCALARS_FROM_DICT used but the provided sample_dict is None"
                    )
-                curr_data = sample_dict[curr_str_data]
-                assert len(curr_data.shape) == 1
-                curr_indices = torch.arange(
-                    prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
+                curr_scalar_values = sample_dict[curr_str_data]
+                assert len(curr_scalar_values.shape) == 1
+                all_scalars_values.append(curr_scalar_values)
+                all_scalars_valid_mask.append(
+                    torch.full_like(
+                        curr_scalar_values, fill_value=True, dtype=torch.bool
+                    )
                )
-                scalars_indices.append(curr_indices)
-                scalars_values.append(curr_data)
-
-                prev_index_end += curr_data.shape[0]
-
            else:
                raise Exception(
                    "Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
@@ -209,24 +196,47 @@ def prepare_info_for_model_step(
        elif tokenizer_name.startswith("VECTORS_"):
            raise NotImplementedError
        else:
-            prev_index_end += len(curr_placeholder_encoding.ids)
-
-        if len(scalars_indices) > 0:
-            scalars_indices = torch.concat(scalars_indices)
-            scalars_values = torch.concat(scalars_values)
-        else:
-            scalars_indices = None
-            scalars_values = None
-
-        if len(scalars_masked_indices) > 0:
-            scalars_masked_indices = torch.tensor(
-                scalars_masked_indices, dtype=torch.int64
+            # prev_index_end += len(curr_placeholder_encoding.ids)
+            curr_scalar_values = torch.full(
+                (len(curr_placeholder_encoding.ids),),
+                fill_value=scalar_default_unfound_value,
+            )
+            all_scalars_values.append(curr_scalar_values)
+            all_scalars_valid_mask.append(
+                torch.full_like(
+                    curr_scalar_values, fill_value=False, dtype=torch.bool
+                )
+            )
+
+        all_scalars_values = torch.concat(all_scalars_values)
+        all_scalars_valid_mask = torch.concat(all_scalars_valid_mask)
+
+        assert all_scalars_values.shape == all_scalars_valid_mask.shape
+
+        # pad if needed
+        full_query_len = len(token_ids)
+        if full_query_len > all_scalars_values.shape[0]:
+            pad_len = full_query_len - all_scalars_values.shape[0]
+            all_scalars_values = torch.concat(
+                [
+                    all_scalars_values,
+                    torch.full(
+                        (pad_len,),
+                        fill_value=scalar_default_unfound_value,
+                        dtype=all_scalars_values.dtype,
+                    ),
+                ]
+            )
+            all_scalars_valid_mask = torch.concat(
+                [
+                    all_scalars_valid_mask,
+                    torch.full(
+                        (pad_len,), fill_value=False, dtype=all_scalars_valid_mask.dtype
+                    ),
+                ]
            )
-        else:
-            scalars_masked_indices = None

        return {
-            "scalars_indices": scalars_indices,  # 1d - its length is the number of actual scalars (provided) found
-            "scalars_values": scalars_values,  # 1d - values of provided scalars
-            "scalars_masked_indices": scalars_masked_indices,  # 1d - indices of masked scalars
+            "scalars_values": all_scalars_values,  # 1d - its length is the number of actual scalars (provided) found
+            "scalars_valid_mask": all_scalars_valid_mask,  # 1d - values of provided scalars
        }

diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py
index 9ccf6650..63e314a4 100644
--- a/fuse/data/tokenizers/modular_tokenizer/op.py
+++ b/fuse/data/tokenizers/modular_tokenizer/op.py
@@ -372,8 +372,7 @@ class ModularTokenizerOp(ModularTokenizerWithoutInjectOp):
     supported syntax/format:

     for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
-    ',' separated float values and/or <MASK> tokens -
-        for example: "2.7,3.99,-12.9" or "<MASK>" or "2.19,<MASK>,3.19,<MASK>"
+    ',' separated float values

     for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
         for example: "blah.boo.banana" or "data.input.encoder_input"
@@ -437,9 +436,7 @@ def __call__(
        on_unknown: Optional[str] = "warn",
        verbose: Optional[int] = 1,
        validate_ends_with_eos: Optional[bool] = None,
-        key_out_scalars_indices: Optional[str] = None,
-        key_out_scalars_values: Optional[str] = None,
-        key_out_masked_scalars_indices: Optional[str] = None,
+        key_out_scalars: Optional[str] = None,
    ) -> NDict:
        """_summary_

@@ -458,10 +455,10 @@ def __call__(
            verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning with full data.
                Defaults to 1.
            validate_ends_with_eos (Optional[bool], optional): if not None, overrides self._validate_ends_with_eos
-            key_out_scalars_inputs_indices:str optional
-                if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar elements.
-            key_out_scalars_inputs_values:str optional
-                if provided, will write to sample_dict in this key a 1D torch tensor with indices of all inputs scalar values.
+            key_out_scalars:str optional
+                if provided, will write to:
+                `sample_dict[f'{key_out_scalars}.values]` - a 1D torch tensor with all the scalars values
+                `sample_dict[f'{key_out_scalars}.valid_mask]` - a 1D torch boolean tensor representing which elements have scalar values

        Returns:
            NDict: _description_
@@ -495,34 +492,15 @@ def __call__(
            per_meta_encoding_including_placeholders=sample_dict[
                key_in + ".per_meta_part_encoding"
            ],
+            token_ids=sample_dict[key_out_tokens_ids],
            sample_dict=sample_dict,
        )

-        if key_out_scalars_indices is not None:
-            sample_dict[key_out_scalars_indices] = prepared_data["scalars_indices"]
-        else:
-            if prepared_data["scalars_indices"] is not None:
-                raise Exception(
-                    "non None scalars_indices found but no key_out_scalars_indices found"
-                )
-
-        if key_out_scalars_values is not None:
-            sample_dict[key_out_scalars_values] = prepared_data["scalars_values"]
-        else:
-            if prepared_data["scalars_values"] is not None:
-                raise Exception(
-                    "non None scalars_value found but no key_out_scalars_values found"
-                )
-
-        if key_out_masked_scalars_indices is not None:
-            sample_dict[key_out_masked_scalars_indices] = prepared_data[
-                "scalars_masked_indices"
+        if key_out_scalars is not None:
+            sample_dict[key_out_scalars + ".values"] = prepared_data["scalars_values"]
+            sample_dict[key_out_scalars + ".valid_mask"] = prepared_data[
+                "scalars_valid_mask"
            ]
-        else:
-            if prepared_data["scalars_masked_indices"] is not None:
-                raise Exception(
-                    "non None scalars_masked_indices found but no key_out_masked_scalars_indices found"
-                )

        return sample_dict

From 6010b3461ae5f9b7a8a43161dd2e2240053576e7 Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Mon, 7 Oct 2024 08:29:41 -0400
Subject: [PATCH 2/7] fix static code checks

---
 fuse/data/tokenizers/modular_tokenizer/inject_utils.py | 2 +-
 fuse/data/tokenizers/modular_tokenizer/op.py            | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 15568819..b491e89f 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Tuple, Dict, Union
+from typing import Optional, List, Tuple, Dict
 from tokenizers import Encoding
 import torch
 import re

diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py
index 63e314a4..be98f52b 100644
--- a/fuse/data/tokenizers/modular_tokenizer/op.py
+++ b/fuse/data/tokenizers/modular_tokenizer/op.py
@@ -372,7 +372,7 @@ class ModularTokenizerOp(ModularTokenizerWithoutInjectOp):
     supported syntax/format:

     for text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
-    ',' separated float values 
+    ',' separated float values

     for text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key to the sample NDict
         for example: "blah.boo.banana" or "data.input.encoder_input"

From e19cb2c18eab3dae0f751ef943d3cd327bd704a9 Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Wed, 9 Oct 2024 09:40:59 -0400
Subject: [PATCH 3/7] PR comments

---
 .../tokenizers/modular_tokenizer/inject_utils.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index b491e89f..164404ca 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Union
 from tokenizers import Encoding
 import torch
 import re
@@ -45,7 +45,7 @@ class InjectorToModularTokenizerLib:
     @staticmethod
     def build_placeholder_meta_tokenization(
         *,
-        sequence: str,
+        sequence: Union[str, list, tuple],
         sample_dict: Optional[NDict] = None,
     ) -> Tuple[str, List[str]]:
         """
@@ -92,7 +92,7 @@ def build_placeholder_meta_tokenization(

                if tokenizer_type == "SCALARS_LITERALS":
                    values = subseq.split(",")
-                    # validate that all values can be converted to fload
+                    # validate that all values can be converted to float
                    try:
                        [float(x) for x in values]
                    except:
@@ -145,10 +145,11 @@ def prepare_info_for_model_step(
        """

        ## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatanated in the end of this function
-        all_scalars_values = (
-            []
-        )  # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
-        all_scalars_valid_mask = []  # for each element, whether it's a scalar or not
+
+        # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
+        all_scalars_values = []
+        # for each element, whether it's a scalar or not
+        all_scalars_valid_mask = []
        scalar_default_unfound_value = -1000.0

From d2a7b836464f7ea1680f9fbf9c2343ea6db66772 Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Wed, 9 Oct 2024 09:42:31 -0400
Subject: [PATCH 4/7] PR comments

---
 fuse/data/tokenizers/modular_tokenizer/inject_utils.py | 2 +-
 fuse/data/tokenizers/modular_tokenizer/op.py            | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 164404ca..2f46218f 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -122,7 +122,7 @@ def build_placeholder_meta_tokenization(
        return "".join(with_placeholders), hints_and_subseq

    @staticmethod
-    def prepare_info_for_model_step(
+    def build_scalars(
        *,
        per_meta_tokenizer_data: List[str],
        per_meta_encoding_including_placeholders: List[Encoding],

diff --git a/fuse/data/tokenizers/modular_tokenizer/op.py b/fuse/data/tokenizers/modular_tokenizer/op.py
index be98f52b..261dcceb 100644
--- a/fuse/data/tokenizers/modular_tokenizer/op.py
+++ b/fuse/data/tokenizers/modular_tokenizer/op.py
@@ -487,7 +487,7 @@ def __call__(
                + ".per_meta_part_encoding",  # using the key_in as base for the name because key_out_* are optional
            )

-        prepared_data = InjectorToModularTokenizerLib.prepare_info_for_model_step(
+        prepared_data = InjectorToModularTokenizerLib.build_scalars(
            per_meta_tokenizer_data=per_meta_orig,
            per_meta_encoding_including_placeholders=sample_dict[
                key_in + ".per_meta_part_encoding"
            ],

From dff2a359b314e934f070454635f898a2ff6da533 Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Wed, 9 Oct 2024 10:07:45 -0400
Subject: [PATCH 5/7] handling shorter seq in scalars due to crop

---
 fuse/data/tokenizers/modular_tokenizer/inject_utils.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 2f46218f..9dd211f3 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -236,6 +236,11 @@ def build_scalars(
                    ),
                ]
            )
+        elif full_query_len > all_scalars_values.shape[0]:
+            print('warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}')
+            all_scalars_values = all_scalars_values[:full_query_len]
+            all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]
+

        return {
            "scalars_values": all_scalars_values,  # 1d - its length is the number of actual scalars (provided) found

From 0ac3c6dfb89f33c948502697c37a4efceb94ac1f Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Wed, 9 Oct 2024 11:33:19 -0400
Subject: [PATCH 6/7] ...

---
 fuse/data/tokenizers/modular_tokenizer/inject_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index 9dd211f3..bf62119a 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -236,11 +236,12 @@ def build_scalars(
                    ),
                ]
            )
-        elif full_query_len > all_scalars_values.shape[0]:
-            print('warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}')
+        elif full_query_len < all_scalars_values.shape[0]:
+            print(
+                "warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}"
+            )
            all_scalars_values = all_scalars_values[:full_query_len]
            all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]
-

        return {
            "scalars_values": all_scalars_values,  # 1d - its length is the number of actual scalars (provided) found

From 2ae251056755e720eb707e15b8a486ea611dea0d Mon Sep 17 00:00:00 2001
From: yoel shoshan
Date: Wed, 9 Oct 2024 11:53:52 -0400
Subject: [PATCH 7/7] ...

---
 .../modular_tokenizer/inject_utils.py         | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
index bf62119a..c7a9a2dd 100644
--- a/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
+++ b/fuse/data/tokenizers/modular_tokenizer/inject_utils.py
@@ -7,6 +7,7 @@
     TypedInput,
     list_to_tokenizer_string,
 )
+from warnings import warn


 class InjectorToModularTokenizerLib:
@@ -128,6 +129,7 @@ def build_scalars(
        per_meta_encoding_including_placeholders: List[Encoding],
        token_ids: List[int],
        sample_dict: Optional[NDict] = None,
+        crop_report: str = "warn",
    ) -> Dict:
        """
        since we:
@@ -141,9 +143,11 @@ def build_scalars(
            per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract per tokenizer final tokens num (after all of the padding and cropping logic was already done)
            sample_dict: a fuse sample_dict - optional.
                needed only if the meta tokenizer instruction uses a syntax of lookup from the dictionary
-
+            crop_report: one of None (no action), 'warn' - print a warning, 'raise' - raise an exception
+                will be triggered if cropping happened
        """
+        assert crop_report in ["warn", "raise", None]
        ## both `all_scalars_values` and `all_scalars_valid_mask` will contain torch tensors, which will be concatanated in the end of this function

        # one scalar for every element, `scalar_default_unfound_value` is used for elements that aren't scalars
@@ -237,9 +241,14 @@ def build_scalars(
                ]
            )
        elif full_query_len < all_scalars_values.shape[0]:
-            print(
-                "warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}"
-            )
+            if crop_report in ["warn", "raise"]:
+                _msg = f"warning: scalars sequence had to be cropped. The full (including all subtokenizers) length was {all_scalars_values.shape[0]} after cropping it is {full_query_len}"
+                if crop_report == "warn":
+                    warn(_msg)
+                elif crop_report == "raise":
+                    raise Exception(_msg)
+                else:
+                    assert False, "should not get here"
            all_scalars_values = all_scalars_values[:full_query_len]
            all_scalars_valid_mask = all_scalars_valid_mask[:full_query_len]
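
Reviewer note (not part of the patch series): a minimal usage sketch of the reworked scalars output introduced above. The `key_in`, `key_out_tokens_ids` and `key_out_scalars` argument names, the `<@TOKENIZER-TYPE=SCALARS_LITERALS>` syntax and the derived `.values`/`.valid_mask` keys are taken from the diffs; the tokenizer constructor argument, the config path and the remaining sample-dict key names are assumptions for illustration only.

    from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
    from fuse.utils.ndict import NDict

    # hypothetical path and key names - adjust to your pipeline
    tokenizer_op = ModularTokenizerOp(tokenizer_path="/path/to/modular_tokenizer")  # assumed ctor argument

    sample = NDict()
    # one scalar literal (0.3) injected between two AA-tokenized segments
    sample["data.encoder_input"] = (
        "<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3"
        "<@TOKENIZER-TYPE=AA>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATT"
    )

    sample = tokenizer_op(
        sample_dict=sample,
        key_in="data.encoder_input",
        key_out_tokens_ids="data.encoder_input_token_ids",
        key_out_scalars="data.encoder_scalars",  # new in this PR
    )

    token_ids = sample["data.encoder_input_token_ids"]
    scalars = sample["data.encoder_scalars.values"]         # float tensor, -1000.0 at non-scalar positions
    valid_mask = sample["data.encoder_scalars.valid_mask"]  # bool tensor, True only at scalar positions
    assert scalars.shape[0] == valid_mask.shape[0] == len(token_ids)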
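Because `scalars_values` and `scalars_valid_mask` are padded/cropped to the token-id length, they line up position-for-position with the encoded sequence. One plausible way a model could consume them (this module is illustrative only and is not part of this PR or of fuse-med-ml) is to project each injected scalar to the embedding dimension and substitute it for the placeholder-token embedding wherever the mask is True:

    import torch
    import torch.nn as nn

    class ScalarEmbeddingInjector(nn.Module):
        """Illustrative only: replace placeholder-token embeddings with a learned
        projection of the injected scalar values."""

        def __init__(self, d_model: int):
            super().__init__()
            self.proj = nn.Linear(1, d_model)

        def forward(
            self,
            token_embeddings: torch.Tensor,    # [seq_len, d_model]
            scalars_values: torch.Tensor,      # [seq_len], -1000.0 where no scalar was injected
            scalars_valid_mask: torch.Tensor,  # [seq_len], bool
        ) -> torch.Tensor:
            scalar_embeddings = self.proj(scalars_values.unsqueeze(-1))  # [seq_len, d_model]
            # keep the original token embedding wherever the position is not a scalar
            return torch.where(
                scalars_valid_mask.unsqueeze(-1), scalar_embeddings, token_embeddings
            )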