optax: Use specific return-type annotations in loss.py
Instead of always returning the union type chex.Array, return the more
specific type jnp.ndarray when appropriate. This makes it easier for
code that uses optax to write type-correct functions that return jax
Arrays.
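
For example, a downstream helper can now be annotated as returning
jnp.ndarray and type-check against these signatures without a cast.
A minimal sketch; weighted_huber is a hypothetical user function, not
part of optax:

import jax.numpy as jnp
import optax

def weighted_huber(predictions: jnp.ndarray,
                   targets: jnp.ndarray,
                   weight: float = 0.5) -> jnp.ndarray:
  # optax.huber_loss is annotated as returning jnp.ndarray after this
  # change, so the declared return type checks cleanly. Illustrative only.
  return weight * optax.huber_loss(predictions, targets).mean()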

PiperOrigin-RevId: 548880116
josephlr authored and OptaxDev committed Jul 18, 2023
1 parent 815da81 commit fb9ffc4
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions optax/_src/loss.py
@@ -86,7 +86,7 @@ def l2_loss(
 def huber_loss(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
-    delta: float = 1.) -> chex.Array:
+    delta: float = 1.) -> jnp.ndarray:
   """Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
   If gradient descent is applied to the `huber loss`, it is equivalent to
@@ -118,7 +118,7 @@ def huber_loss(
 def smooth_labels(
     labels: chex.Array,
     alpha: float,
-) -> jnp.ndarray:
+) -> chex.Array:
   """Apply label smoothing.
   Label smoothing is often used in combination with a cross-entropy loss.
@@ -140,7 +140,10 @@ def smooth_labels(
   return (1.0 - alpha) * labels + alpha / num_categories


-def sigmoid_binary_cross_entropy(logits, labels):
+def sigmoid_binary_cross_entropy(
+    logits: chex.Array,
+    labels: chex.Array,
+) -> jnp.ndarray:
   """Computes element-wise sigmoid cross entropy given logits and labels.
   This function can be used for binary or multiclass classification (where each
@@ -178,7 +181,7 @@ class is an independent binary prediction and different classes are not
 def softmax_cross_entropy(
     logits: chex.Array,
     labels: chex.Array,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes the softmax cross entropy between sets of logits and labels.
   Measures the probability error in discrete classification tasks in which
@@ -206,7 +209,7 @@ def softmax_cross_entropy(
 def softmax_cross_entropy_with_integer_labels(
     logits: chex.Array,
     labels: chex.Array,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes softmax cross entropy between sets of logits and integer labels.
   Measures the probability error in discrete classification tasks in which
@@ -242,7 +245,7 @@ def cosine_similarity(
     predictions: chex.Array,
     targets: chex.Array,
     epsilon: float = 0.,
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes the cosine similarity between targets and predictions.
   The cosine **similarity** is a measure of similarity between vectors defined
@@ -277,7 +280,7 @@ def cosine_distance(
     predictions: chex.Array,
     targets: chex.Array,
     epsilon: float = 0.,
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes the cosine distance between targets and predictions.
   The cosine **distance**, implemented here, measures the **dissimilarity**

 def log_cosh(
     predictions: chex.Array,
     targets: Optional[chex.Array] = None,
-) -> chex.Array:
+) -> jnp.ndarray:
   """Calculates the log-cosh loss for a set of predictions.
   log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
@@ -331,7 +334,7 @@ def ctc_loss_with_forward_probs(
     labels: chex.Array,
     label_paddings: chex.Array,
     blank_id: int = 0,
-    log_epsilon: float = -1e5) -> Tuple[chex.Array, chex.Array, chex.Array]:
+    log_epsilon: float = -1e5) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
   r"""Computes CTC loss and CTC forward-probabilities.
   The CTC loss is a loss function based on log-likelihoods of the model that
@@ -459,7 +462,7 @@ def ctc_loss(logits: chex.Array,
              labels: chex.Array,
              label_paddings: chex.Array,
              blank_id: int = 0,
-             log_epsilon: float = -1e5) -> chex.Array:
+             log_epsilon: float = -1e5) -> jnp.ndarray:
   """Computes CTC loss.
   See docstring for ``ctc_loss_with_forward_probs`` for details.
@@ -494,7 +497,7 @@ def ctc_loss(logits: chex.Array,

 def convex_kl_divergence(
     log_predictions: chex.Array, targets: chex.Array
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes a convex version of the Kullback-Leibler divergence loss.
   Measures the information gain achieved if target probability distribution

 def kl_divergence(
     log_predictions: chex.Array, targets: chex.Array
-) -> chex.Array:
+) -> jnp.ndarray:
   """Computes the Kullback-Leibler divergence (relative entropy) loss.
   Measures the information gain achieved if target probability distribution


 def kl_divergence_with_log_targets(log_predictions: chex.Array,
-                                   log_targets: chex.Array) -> chex.Array:
+                                   log_targets: chex.Array) -> jnp.ndarray:
   """Computes the Kullback-Leibler divergence (relative entropy) loss.
   Version of kl_div_loss where targets are given in log-space.


 def hinge_loss(predictor_outputs: chex.Array,
-               targets: chex.Array) -> chex.Array:
+               targets: chex.Array) -> jnp.ndarray:
   """Computes the hinge loss for binary classification.
   Args:

 def poly_loss_cross_entropy(
     logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
-) -> chex.Array:
+) -> jnp.ndarray:
   r"""Computes PolyLoss between logits and labels.
   The PolyLoss is a loss function that decomposes commonly