From 4ab9e92b7b983c580c4ea70de2b828da8adf5339 Mon Sep 17 00:00:00 2001
From: Joe Richey
Date: Mon, 17 Jul 2023 21:18:07 -0700
Subject: [PATCH] optax: Use specific return-type annotations in loss.py

Instead of always returning the union type chex.Array, return the more
specific type jnp.ndarray. This makes it easier for users of optax to
write type-correct functions returning jax Arrays.

PiperOrigin-RevId: 548880116
---
 optax/_src/loss.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/optax/_src/loss.py b/optax/_src/loss.py
index e1062bde..bebd3e1b 100644
--- a/optax/_src/loss.py
+++ b/optax/_src/loss.py
@@ -86,7 +86,7 @@ def l2_loss(
 def huber_loss(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
-    delta: float = 1.) -> chex.Array:
+    delta: float = 1.) -> jnp.ndarray:
   """Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
 
   If gradient descent is applied to the `huber loss`, it is equivalent to
@@ -118,7 +118,7 @@ def huber_loss(
 def smooth_labels(
     labels: chex.Array,
     alpha: float,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Apply label smoothing.
 
   Label smoothing is often used in combination with a cross-entropy loss.
@@ -140,7 +140,10 @@ def smooth_labels(
   return (1.0 - alpha) * labels + alpha / num_categories
 
 
-def sigmoid_binary_cross_entropy(logits, labels):
+def sigmoid_binary_cross_entropy(
+    logits: chex.Array,
+    labels: chex.Array,
+) -> jnp.ndarray:
   """Computes element-wise sigmoid cross entropy given logits and labels.
 
   This function can be used for binary or multiclass classification (where each
@@ -178,7 +181,7 @@ class is an independent binary prediction and different classes are not
 def softmax_cross_entropy(
     logits: chex.Array,
     labels: chex.Array,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes the softmax cross entropy between sets of logits and labels.
 
   Measures the probability error in discrete classification tasks in which
@@ -206,7 +209,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes softmax cross entropy between sets of logits and integer labels.
 
   Measures the probability error in discrete classification tasks in which
@@ -242,7 +245,7 @@ def cosine_similarity(
     predictions: chex.Array,
     targets: chex.Array,
     epsilon: float = 0.,
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes the cosine similarity between targets and predictions.
 
   The cosine **similarity** is a measure of similarity between vectors defined
@@ -277,7 +280,7 @@ def cosine_distance(
     predictions: chex.Array,
     targets: chex.Array,
     epsilon: float = 0.,
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes the cosine distance between targets and predictions.
 
   The cosine **distance**, implemented here, measures the **dissimilarity**
@@ -302,7 +305,7 @@ def log_cosh(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Calculates the log-cosh loss for a set of predictions.
 
   log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
@@ -331,7 +334,7 @@ def ctc_loss_with_forward_probs(
     labels: chex.Array,
     label_paddings: chex.Array,
     blank_id: int = 0,
-    log_epsilon: float = -1e5) -> Tuple[chex.Array, chex.Array, chex.Array]:
+    log_epsilon: float = -1e5) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
   r"""Computes CTC loss and CTC forward-probabilities.
 
   The CTC loss is a loss function based on log-likelihoods of the model that
@@ -459,7 +462,7 @@ def ctc_loss(logits: chex.Array,
              labels: chex.Array,
              label_paddings: chex.Array,
              blank_id: int = 0,
-             log_epsilon: float = -1e5) -> chex.Array:
+             log_epsilon: float = -1e5) -> jnp.ndarray:
   """Computes CTC loss.
 
   See docstring for ``ctc_loss_with_forward_probs`` for details.
@@ -494,7 +497,7 @@ def ctc_loss(logits: chex.Array,
 
 def convex_kl_divergence(
     log_predictions: chex.Array, targets: chex.Array
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes a convex version of the Kullback-Leibler divergence loss.
 
   Measures the information gain achieved if target probability distribution
@@ -521,7 +524,7 @@ def convex_kl_divergence(
 
 def kl_divergence(
     log_predictions: chex.Array, targets: chex.Array
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes the Kullback-Leibler divergence (relative entropy) loss.
 
   Measures the information gain achieved if target probability distribution
@@ -548,7 +551,7 @@ def kl_divergence(
 
 
 def kl_divergence_with_log_targets(log_predictions: chex.Array,
-                                   log_targets: chex.Array) -> chex.Array:
+                                   log_targets: chex.Array) -> jnp.ndarray:
   """Computes the Kullback-Leibler divergence (relative entropy) loss.
 
   Version of kl_div_loss where targets are given in log-space.
@@ -569,7 +572,7 @@ def kl_divergence_with_log_targets(log_predictions: chex.Array,
 
 
 def hinge_loss(predictor_outputs: chex.Array,
-               targets: chex.Array) -> chex.Array:
+               targets: chex.Array) -> jnp.ndarray:
   """Computes the hinge loss for binary classification.
 
   Args:
@@ -584,7 +587,7 @@ def hinge_loss(predictor_outputs: chex.Array,
 
 def poly_loss_cross_entropy(
     logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes PolyLoss between logits and labels.
 
   The PolyLoss is a loss function that decomposes commonly
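
Note (not part of the patch): a minimal user-side sketch of what the narrower
annotations enable, assuming optax and jax are installed. The wrapper name
mean_xent is hypothetical; with the losses annotated as returning jnp.ndarray
instead of the chex.Array union, a wrapper annotated with jnp.ndarray
type-checks without a cast.

import jax.numpy as jnp
import optax


def mean_xent(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  # optax.softmax_cross_entropy returns a per-example loss array; after this
  # change it is annotated as jnp.ndarray, so the reduced result below
  # matches the declared return type directly.
  return jnp.mean(optax.softmax_cross_entropy(logits, labels))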