optax: Use specific return-type annotations in loss.py #549

Closed · wants to merge 1 commit
33 changes: 18 additions & 15 deletions optax/_src/loss.py
@@ -86,7 +86,7 @@ def l2_loss(
def huber_loss(
predictions: chex.Array,
targets: Optional[chex.Array] = None,
-    delta: float = 1.) -> chex.Array:
+    delta: float = 1.) -> jnp.ndarray:
"""Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
If gradient descent is applied to the `huber loss`, it is equivalent to
@@ -118,7 +118,7 @@ def huber_loss(
def smooth_labels(
labels: chex.Array,
alpha: float,
- ) -> jnp.ndarray:
+ ) -> chex.Array:
"""Apply label smoothing.
Label smoothing is often used in combination with a cross-entropy loss.
@@ -140,7 +140,10 @@ def smooth_labels(
return (1.0 - alpha) * labels + alpha / num_categories


- def sigmoid_binary_cross_entropy(logits, labels):
+ def sigmoid_binary_cross_entropy(
+     logits: chex.Array,
+     labels: chex.Array,
+ ) -> jnp.ndarray:
"""Computes element-wise sigmoid cross entropy given logits and labels.
This function can be used for binary or multiclass classification (where each
@@ -178,7 +181,7 @@ class is an independent binary prediction and different classes are not
def softmax_cross_entropy(
logits: chex.Array,
labels: chex.Array,
- ) -> chex.Array:
+ ) -> jnp.ndarray:
"""Computes the softmax cross entropy between sets of logits and labels.
Measures the probability error in discrete classification tasks in which
@@ -206,7 +209,7 @@ def softmax_cross_entropy(
def softmax_cross_entropy_with_integer_labels(
logits: chex.Array,
labels: chex.Array,
- ) -> chex.Array:
+ ) -> jnp.ndarray:
"""Computes softmax cross entropy between sets of logits and integer labels.
Measures the probability error in discrete classification tasks in which
@@ -242,7 +245,7 @@ def cosine_similarity(
predictions: chex.Array,
targets: chex.Array,
epsilon: float = 0.,
- ) -> chex.Array:
+ ) -> jnp.ndarray:
r"""Computes the cosine similarity between targets and predictions.
The cosine **similarity** is a measure of similarity between vectors defined
@@ -277,7 +280,7 @@ def cosine_distance(
predictions: chex.Array,
targets: chex.Array,
epsilon: float = 0.,
- ) -> chex.Array:
+ ) -> jnp.ndarray:
r"""Computes the cosine distance between targets and predictions.
The cosine **distance**, implemented here, measures the **dissimilarity**
@@ -302,7 +305,7 @@ def cosine_distance(
def log_cosh(
predictions: chex.Array,
targets: Optional[chex.Array] = None,
- ) -> chex.Array:
+ ) -> jnp.ndarray:
"""Calculates the log-cosh loss for a set of predictions.
log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)`
@@ -331,7 +334,7 @@ def ctc_loss_with_forward_probs(
labels: chex.Array,
label_paddings: chex.Array,
blank_id: int = 0,
-     log_epsilon: float = -1e5) -> Tuple[chex.Array, chex.Array, chex.Array]:
+     log_epsilon: float = -1e5) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
r"""Computes CTC loss and CTC forward-probabilities.
The CTC loss is a loss function based on log-likelihoods of the model that
@@ -459,7 +462,7 @@ def ctc_loss(logits: chex.Array,
labels: chex.Array,
label_paddings: chex.Array,
blank_id: int = 0,
-     log_epsilon: float = -1e5) -> chex.Array:
+     log_epsilon: float = -1e5) -> jnp.ndarray:
"""Computes CTC loss.
See docstring for ``ctc_loss_with_forward_probs`` for details.
@@ -494,7 +497,7 @@ def ctc_loss(logits: chex.Array,

def convex_kl_divergence(
log_predictions: chex.Array, targets: chex.Array
- ) -> chex.Array:
+ ) -> jnp.ndarray:
"""Computes a convex version of the Kullback-Leibler divergence loss.
Measures the information gain achieved if target probability distribution
@@ -521,7 +524,7 @@ def convex_kl_divergence(

def kl_divergence(
log_predictions: chex.Array, targets: chex.Array
- ) -> chex.Array:
+ ) -> jnp.ndarray:
"""Computes the Kullback-Leibler divergence (relative entropy) loss.
Measures the information gain achieved if target probability distribution
@@ -548,7 +551,7 @@ def kl_divergence(


def kl_divergence_with_log_targets(log_predictions: chex.Array,
-     log_targets: chex.Array) -> chex.Array:
+     log_targets: chex.Array) -> jnp.ndarray:
"""Computes the Kullback-Leibler divergence (relative entropy) loss.
Version of kl_div_loss where targets are given in log-space.
@@ -569,7 +572,7 @@ def kl_divergence_with_log_targets(log_predictions: chex.Array,


def hinge_loss(predictor_outputs: chex.Array,
-     targets: chex.Array) -> chex.Array:
+     targets: chex.Array) -> jnp.ndarray:
"""Computes the hinge loss for binary classification.
Args:
@@ -584,7 +587,7 @@ def hinge_loss(predictor_outputs: chex.Array,

def poly_loss_cross_entropy(
logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
- ) -> chex.Array:
+ ) -> jnp.ndarray:
r"""Computes PolyLoss between logits and labels.
The PolyLoss is a loss function that decomposes commonly
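For context, here is a minimal usage sketch (not part of the PR diff) of two of the touched loss functions, assuming the package-level wrappers optax.huber_loss and optax.softmax_cross_entropy. With the return annotations narrowed from the broad chex.Array alias to jnp.ndarray, a static type checker can resolve array attributes such as .mean() and .shape on the results without extra casts.

# Illustrative only; assumes an optax version exporting huber_loss and
# softmax_cross_entropy at the package level.
import jax.numpy as jnp
import optax

predictions = jnp.array([0.5, -0.3, 1.2])
targets = jnp.array([0.0, 0.0, 1.0])

# Element-wise Huber loss; under this PR the annotated return type is jnp.ndarray.
per_example = optax.huber_loss(predictions, targets, delta=1.0)
mean_loss = per_example.mean()  # .mean() is a known attribute of jnp.ndarray

logits = jnp.array([[2.0, 0.5], [0.1, 1.5]])
labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # one-hot target distributions
xent = optax.softmax_cross_entropy(logits, labels)  # per-example loss, shape (2,)

print(mean_loss, xent.shape)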