From 00dc9747240ac13ae00ce671af59278ed8cab38a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 25 Sep 2023 11:50:38 -0700 Subject: [PATCH] [LSC] Ignore incorrect type annotations related to jax.numpy APIs PiperOrigin-RevId: 568283334 --- t5x/adafactor.py | 4 ++-- t5x/checkpoints.py | 6 +++--- t5x/contrib/calm/decoding.py | 2 +- t5x/contrib/calm/models.py | 2 +- t5x/contrib/moe/checkpoints.py | 2 +- t5x/decoding.py | 2 +- t5x/models.py | 12 ++++++------ t5x/utils.py | 4 ++-- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/t5x/adafactor.py b/t5x/adafactor.py index c61c7ef31..9c046f6a0 100644 --- a/t5x/adafactor.py +++ b/t5x/adafactor.py @@ -291,7 +291,7 @@ def __hash__(self) -> int: def _decay_rate_pow(i: int, exponent: float = 0.8) -> float: """Default Adafactor second-moment decay schedule.""" t = jnp.array(i, jnp.float32) + 1.0 - return 1.0 - t**(-exponent) + return 1.0 - t**(-exponent) # pytype: disable=bad-return-type # jnp-type @staticmethod def _parse_rule( @@ -412,7 +412,7 @@ def init_param_state(self, param, path): state['v'] = jnp.zeros(param.shape, dtype=jnp.float32) if self.hyper_params.beta1 is not None: state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum) - return _AdafactorParamState(**state) + return _AdafactorParamState(**state) # pytype: disable=wrong-arg-types # jnp-type def init_state(self, params): params_flat = utils.flatten_dict_string_keys(params) diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 1e0405f10..58fcb4af3 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -557,7 +557,7 @@ class Checkpointer(object): oldest ones will be automatically deleted to save space. """ - def __init__( + def __init__( # pytype: disable=annotation-type-mismatch # jnp-type self, train_state: train_state_lib.TrainState, partitioner: partitioning.BasePartitioner, @@ -1302,7 +1302,7 @@ class CheckpointerConstructor(typing_extensions.Protocol): of Checkpointer subclasses without triggering type errors. """ - def __call__(self, + def __call__(self, # pytype: disable=annotation-type-mismatch # jnp-type train_state: train_state_lib.TrainState, partitioner: partitioning.BasePartitioner, checkpoints_dir: str, @@ -1456,7 +1456,7 @@ class SaveBestCheckpointer(Checkpointer): oldest ones will be automatically deleted to save space. """ - def __init__(self, + def __init__(self, # pytype: disable=annotation-type-mismatch # jnp-type train_state: train_state_lib.TrainState, partitioner: partitioning.BasePartitioner, checkpoints_dir: str, diff --git a/t5x/contrib/calm/decoding.py b/t5x/contrib/calm/decoding.py index dac8f45c8..bf2dd879b 100644 --- a/t5x/contrib/calm/decoding.py +++ b/t5x/contrib/calm/decoding.py @@ -473,7 +473,7 @@ def sampling_loop_cond_fn(state: SamplingLoopState) -> bool: # Different elements in the batch can be at different loop indices, if any # of our examples are not at the end, keep going. all_sequences_ended = jnp.all(state.ended) - return ~all_sequences_ended + return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: """Sampling loop state update.""" diff --git a/t5x/contrib/calm/models.py b/t5x/contrib/calm/models.py index d87922ece..6c2801336 100644 --- a/t5x/contrib/calm/models.py +++ b/t5x/contrib/calm/models.py @@ -269,7 +269,7 @@ def loss_fn_meta_cls( for meta_logits, meta_labels in zip(all_meta_logits[:-1], all_meta_labels[:-1]): # Balance across the positive/ negative labels. - balanced_weights = weights.copy().astype(float) + balanced_weights = weights.copy().astype(float) # pytype: disable=attribute-error # jnp-type pos_num = (meta_labels * weights == 1).sum() neg_num = ((1 - meta_labels) * weights == 1).sum() diff --git a/t5x/contrib/moe/checkpoints.py b/t5x/contrib/moe/checkpoints.py index 04a5886f7..bb91e9d31 100644 --- a/t5x/contrib/moe/checkpoints.py +++ b/t5x/contrib/moe/checkpoints.py @@ -45,7 +45,7 @@ class UpcycleCheckpointer(checkpoints.Checkpointer): for more details. """ - def __init__( + def __init__( # pytype: disable=annotation-type-mismatch # jnp-type self, train_state: train_state_lib.TrainState, partitioner: partitioning.BasePartitioner, diff --git a/t5x/decoding.py b/t5x/decoding.py index c445c8837..63e2a075d 100644 --- a/t5x/decoding.py +++ b/t5x/decoding.py @@ -500,7 +500,7 @@ def sampling_loop_cond_fn(state: SamplingLoopState) -> bool: # Different elements in the batch can be at different loop indices, if any # of our examples are not at the end, keep going. all_sequences_ended = jnp.all(state.ended) - return ~all_sequences_ended + return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: """Sampling loop state update.""" diff --git a/t5x/models.py b/t5x/models.py index 09e86a858..499dbeea5 100644 --- a/t5x/models.py +++ b/t5x/models.py @@ -1312,7 +1312,7 @@ def compute_weighted_accuracy( if weights is not None: accuracy = accuracy * weights - return jnp.sum(accuracy) + return jnp.sum(accuracy) # pytype: disable=bad-return-type # jnp-type # TODO(cpgaffney) remove when users rely on compute_base_metrics @@ -1369,7 +1369,7 @@ def count_packed_examples(segment_ids: jnp.ndarray) -> int: # Get the first discrete different along axis=1. first_diff = jnp.diff(segment_ids, n=1, axis=1) # count = #(non-0 diff) + #(row) - #(padded ex). - return jnp.sum(first_diff != 0) + segment_ids.shape[0] - num_padded_examples + return jnp.sum(first_diff != 0) + segment_ids.shape[0] - num_padded_examples # pytype: disable=bad-return-type # jnp-type def compute_base_metrics( @@ -1414,20 +1414,20 @@ def compute_base_metrics( 'loss_per_all_target_tokens': clu_metrics.Average( total=loss, count=num_tokens ), - 'timing/seqs_per_second': metrics_lib.TimeRate.from_model_output( + 'timing/seqs_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type numerator=num_examples ), 'timing/steps_per_second': metrics_lib.StepsPerTime.from_model_output(), 'timing/seconds': metrics_lib.Time(), 'timing/seqs': metrics_lib.Sum(num_examples), - 'timing/seqs_per_second_per_core': metrics_lib.TimeRate.from_model_output( + 'timing/seqs_per_second_per_core': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type numerator=num_examples / num_devices ), - 'timing/target_tokens_per_second': metrics_lib.TimeRate.from_model_output( + 'timing/target_tokens_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type numerator=num_tokens ), 'timing/target_tokens_per_second_per_core': ( - metrics_lib.TimeRate.from_model_output( + metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type numerator=num_tokens / num_devices ) ), diff --git a/t5x/utils.py b/t5x/utils.py index d868c1c52..d56a35274 100644 --- a/t5x/utils.py +++ b/t5x/utils.py @@ -355,7 +355,7 @@ def __init__( if save_cfg is not None: if save_cfg.save_dataset: assert ds_iter is not None - save_checkpointer = save_cfg.checkpointer_cls( + save_checkpointer = save_cfg.checkpointer_cls( # pytype: disable=wrong-arg-types # jnp-type train_state=train_state_shape, partitioner=partitioner, checkpoints_dir=model_dir, @@ -749,7 +749,7 @@ def multihost_assert_equal(input_tree, fail_message: str = ''): # ------------------------------------------------------------------------------ # Fast *nondeterministic* hardware RNG for faster Dropout # ------------------------------------------------------------------------------ -def _hardware_uniform( +def _hardware_uniform( # pytype: disable=annotation-type-mismatch # jnp-type rng_key: Array, shape: Shape, dtype: jnp.dtype = np.float32,