Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568283334
  • Loading branch information
Jake VanderPlas authored and t5-copybara committed Sep 27, 2023
1 parent ace831e commit 00dc974
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions t5x/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def __hash__(self) -> int:
def _decay_rate_pow(i: int, exponent: float = 0.8) -> float:
"""Default Adafactor second-moment decay schedule."""
t = jnp.array(i, jnp.float32) + 1.0
return 1.0 - t**(-exponent)
return 1.0 - t**(-exponent) # pytype: disable=bad-return-type # jnp-type

@staticmethod
def _parse_rule(
Expand Down Expand Up @@ -412,7 +412,7 @@ def init_param_state(self, param, path):
state['v'] = jnp.zeros(param.shape, dtype=jnp.float32)
if self.hyper_params.beta1 is not None:
state['m'] = jnp.zeros(param.shape, dtype=self.dtype_momentum)
return _AdafactorParamState(**state)
return _AdafactorParamState(**state) # pytype: disable=wrong-arg-types # jnp-type

def init_state(self, params):
params_flat = utils.flatten_dict_string_keys(params)
Expand Down
6 changes: 3 additions & 3 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ class Checkpointer(object):
oldest ones will be automatically deleted to save space.
"""

def __init__(
def __init__( # pytype: disable=annotation-type-mismatch # jnp-type
self,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
Expand Down Expand Up @@ -1302,7 +1302,7 @@ class CheckpointerConstructor(typing_extensions.Protocol):
of Checkpointer subclasses without triggering type errors.
"""

def __call__(self,
def __call__(self, # pytype: disable=annotation-type-mismatch # jnp-type
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
checkpoints_dir: str,
Expand Down Expand Up @@ -1456,7 +1456,7 @@ class SaveBestCheckpointer(Checkpointer):
oldest ones will be automatically deleted to save space.
"""

def __init__(self,
def __init__(self, # pytype: disable=annotation-type-mismatch # jnp-type
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
checkpoints_dir: str,
Expand Down
2 changes: 1 addition & 1 deletion t5x/contrib/calm/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def sampling_loop_cond_fn(state: SamplingLoopState) -> bool:
# Different elements in the batch can be at different loop indices, if any
# of our examples are not at the end, keep going.
all_sequences_ended = jnp.all(state.ended)
return ~all_sequences_ended
return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type

def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState:
"""Sampling loop state update."""
Expand Down
2 changes: 1 addition & 1 deletion t5x/contrib/calm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def loss_fn_meta_cls(
for meta_logits, meta_labels in zip(all_meta_logits[:-1],
all_meta_labels[:-1]):
# Balance across the positive/ negative labels.
balanced_weights = weights.copy().astype(float)
balanced_weights = weights.copy().astype(float) # pytype: disable=attribute-error # jnp-type

pos_num = (meta_labels * weights == 1).sum()
neg_num = ((1 - meta_labels) * weights == 1).sum()
Expand Down
2 changes: 1 addition & 1 deletion t5x/contrib/moe/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class UpcycleCheckpointer(checkpoints.Checkpointer):
for more details.
"""

def __init__(
def __init__( # pytype: disable=annotation-type-mismatch # jnp-type
self,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
Expand Down
2 changes: 1 addition & 1 deletion t5x/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def sampling_loop_cond_fn(state: SamplingLoopState) -> bool:
# Different elements in the batch can be at different loop indices, if any
# of our examples are not at the end, keep going.
all_sequences_ended = jnp.all(state.ended)
return ~all_sequences_ended
return ~all_sequences_ended # pytype: disable=bad-return-type # jnp-type

def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState:
"""Sampling loop state update."""
Expand Down
12 changes: 6 additions & 6 deletions t5x/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def compute_weighted_accuracy(
if weights is not None:
accuracy = accuracy * weights

return jnp.sum(accuracy)
return jnp.sum(accuracy) # pytype: disable=bad-return-type # jnp-type


# TODO(cpgaffney) remove when users rely on compute_base_metrics
Expand Down Expand Up @@ -1369,7 +1369,7 @@ def count_packed_examples(segment_ids: jnp.ndarray) -> int:
# Get the first discrete different along axis=1.
first_diff = jnp.diff(segment_ids, n=1, axis=1)
# count = #(non-0 diff) + #(row) - #(padded ex).
return jnp.sum(first_diff != 0) + segment_ids.shape[0] - num_padded_examples
return jnp.sum(first_diff != 0) + segment_ids.shape[0] - num_padded_examples # pytype: disable=bad-return-type # jnp-type


def compute_base_metrics(
Expand Down Expand Up @@ -1414,20 +1414,20 @@ def compute_base_metrics(
'loss_per_all_target_tokens': clu_metrics.Average(
total=loss, count=num_tokens
),
'timing/seqs_per_second': metrics_lib.TimeRate.from_model_output(
'timing/seqs_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type
numerator=num_examples
),
'timing/steps_per_second': metrics_lib.StepsPerTime.from_model_output(),
'timing/seconds': metrics_lib.Time(),
'timing/seqs': metrics_lib.Sum(num_examples),
'timing/seqs_per_second_per_core': metrics_lib.TimeRate.from_model_output(
'timing/seqs_per_second_per_core': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type
numerator=num_examples / num_devices
),
'timing/target_tokens_per_second': metrics_lib.TimeRate.from_model_output(
'timing/target_tokens_per_second': metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type
numerator=num_tokens
),
'timing/target_tokens_per_second_per_core': (
metrics_lib.TimeRate.from_model_output(
metrics_lib.TimeRate.from_model_output( # pytype: disable=wrong-arg-types # jnp-type
numerator=num_tokens / num_devices
)
),
Expand Down
4 changes: 2 additions & 2 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def __init__(
if save_cfg is not None:
if save_cfg.save_dataset:
assert ds_iter is not None
save_checkpointer = save_cfg.checkpointer_cls(
save_checkpointer = save_cfg.checkpointer_cls( # pytype: disable=wrong-arg-types # jnp-type
train_state=train_state_shape,
partitioner=partitioner,
checkpoints_dir=model_dir,
Expand Down Expand Up @@ -749,7 +749,7 @@ def multihost_assert_equal(input_tree, fail_message: str = ''):
# ------------------------------------------------------------------------------
# Fast *nondeterministic* hardware RNG for faster Dropout
# ------------------------------------------------------------------------------
def _hardware_uniform(
def _hardware_uniform( # pytype: disable=annotation-type-mismatch # jnp-type
rng_key: Array,
shape: Shape,
dtype: jnp.dtype = np.float32,
Expand Down

0 comments on commit 00dc974

Please sign in to comment.