Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568468680
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Sep 26, 2023
1 parent 8697f51 commit 7e7b66e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions clrs/_src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def output_loss_chunked(truth: _DataPoint, pred: _Array,
else:
mask = _expand_and_broadcast_to(is_last, loss)
total_mask = jnp.maximum(jnp.sum(mask), EPS)
return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask
return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask # pytype: disable=bad-return-type # jnp-type


def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
Expand Down Expand Up @@ -112,7 +112,7 @@ def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
# Compute the cross entropy between doubly stochastic pred and truth_data
total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))

return total_loss
return total_loss # pytype: disable=bad-return-type # jnp-type


def hint_loss_chunked(
Expand Down

0 comments on commit 7e7b66e

Please sign in to comment.