Skip to content

Commit

Permalink
Support CaliFree and Unweighted NE in TorchRec (pytorch#2540)
Browse files Browse the repository at this point in the history
Summary:

### Overview
Git pull request TBD after approvals.

This diff implements CaliFree and Unweighted NE metrics. The new metrics will not be attached to existing NE metric to avoid cluttering the additional options.

### Implementation
CaliFree: 
  raw_ne / (
        -pos_labels * torch.log2(weighted_sum_predictions / weighted_num_samples)
        - (weighted_num_samples - pos_labels)
        * torch.log2(1 - (weighted_sum_predictions / weighted_num_samples))
    )
Unweighted:
  weights = 1

Differential Revision: D65311797
  • Loading branch information
monofb authored and facebook-github-bot committed Nov 5, 2024
1 parent 0512183 commit 97c2f66
Show file tree
Hide file tree
Showing 7 changed files with 775 additions and 0 deletions.
226 changes: 226 additions & 0 deletions torchrec/metrics/cali_free_ne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Type

import torch
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)
from torchrec.pt2.utils import pt2_compile_callable


def compute_cross_entropy(
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
eta: float,
) -> torch.Tensor:
predictions = predictions.double()
predictions.clamp_(min=eta, max=1 - eta)
cross_entropy = -weights * labels * torch.log2(predictions) - weights * (
1.0 - labels
) * torch.log2(1.0 - predictions)
return cross_entropy


def _compute_cross_entropy_norm(
mean_label: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
eta: float,
) -> torch.Tensor:
mean_label = mean_label.double()
mean_label.clamp_(min=eta, max=1 - eta)
return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2(
1.0 - mean_label
)


@torch.fx.wrap
def _compute_ne(
ce_sum: torch.Tensor,
weighted_num_samples: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
eta: float,
) -> torch.Tensor:
# Goes into this block if all elements in weighted_num_samples > 0
weighted_num_samples = weighted_num_samples.double().clamp(min=eta)
mean_label = pos_labels / weighted_num_samples
ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta)
return ce_sum / ce_norm


def compute_cali_free_ne(
ce_sum: torch.Tensor,
weighted_num_samples: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
weighted_sum_predictions: torch.Tensor,
eta: float,
allow_missing_label_with_zero_weight: bool = False,
) -> torch.Tensor:
if allow_missing_label_with_zero_weight and not weighted_num_samples.all():
# If nan were to occur, return a dummy value instead of nan if
# allow_missing_label_with_zero_weight is True
return torch.tensor([eta])
raw_ne = _compute_ne(
ce_sum=ce_sum,
weighted_num_samples=weighted_num_samples,
pos_labels=pos_labels,
neg_labels=neg_labels,
eta=eta,
)
return raw_ne / (
-pos_labels * torch.log2(weighted_sum_predictions / weighted_num_samples)
- (weighted_num_samples - pos_labels)
* torch.log2(1 - (weighted_sum_predictions / weighted_num_samples))
)


def get_cali_free_ne_states(
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
eta: float,
) -> Dict[str, torch.Tensor]:
cross_entropy = compute_cross_entropy(
labels,
predictions,
weights,
eta,
)
return {
"cross_entropy_sum": torch.sum(cross_entropy, dim=-1),
"weighted_num_samples": torch.sum(weights, dim=-1),
"pos_labels": torch.sum(weights * labels, dim=-1),
"neg_labels": torch.sum(weights * (1.0 - labels), dim=-1),
"weighted_sum_predictions": torch.sum(weights * predictions, dim=-1),
}


class CaliFreeNEMetricComputation(RecMetricComputation):
r"""
This class implements the RecMetricComputation for CaliFree NE, i.e. Normalized Entropy.
The constructor arguments are defined in RecMetricComputation.
See the docstring of RecMetricComputation for more detail.
Args:
allow_missing_label_with_zero_weight (bool): allow missing label to have weight 0, instead of throwing exception.
"""

def __init__(
self,
*args: Any,
allow_missing_label_with_zero_weight: bool = False,
**kwargs: Any,
) -> None:
self._allow_missing_label_with_zero_weight: bool = (
allow_missing_label_with_zero_weight
)
super().__init__(*args, **kwargs)
self._add_state(
"cross_entropy_sum",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"weighted_num_samples",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"pos_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"neg_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"weighted_sum_predictions",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self.eta = 1e-12

@pt2_compile_callable
def update(
self,
*,
predictions: Optional[torch.Tensor],
labels: torch.Tensor,
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' should not be None for CaliFreeNEMetricComputation update"
)
states = get_cali_free_ne_states(labels, predictions, weights, self.eta)
num_samples = predictions.shape[-1]

for state_name, state_value in states.items():
state = getattr(self, state_name)
state += state_value
self._aggregate_window_state(state_name, state_value, num_samples)

def _compute(self) -> List[MetricComputationReport]:
reports = [
MetricComputationReport(
name=MetricName.CALI_FREE_NE,
metric_prefix=MetricPrefix.LIFETIME,
value=compute_cali_free_ne(
cast(torch.Tensor, self.cross_entropy_sum),
cast(torch.Tensor, self.weighted_num_samples),
cast(torch.Tensor, self.pos_labels),
cast(torch.Tensor, self.neg_labels),
cast(torch.Tensor, self.weighted_sum_predictions),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
MetricComputationReport(
name=MetricName.CALI_FREE_NE,
metric_prefix=MetricPrefix.WINDOW,
value=compute_cali_free_ne(
self.get_window_state("cross_entropy_sum"),
self.get_window_state("weighted_num_samples"),
self.get_window_state("pos_labels"),
self.get_window_state("neg_labels"),
self.get_window_state("weighted_sum_predictions"),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
]
return reports


class CaliFreeNEMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.CALI_FREE_NE
_computation_class: Type[RecMetricComputation] = CaliFreeNEMetricComputation
4 changes: 4 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torchrec.metrics.accuracy import AccuracyMetric
from torchrec.metrics.auc import AUCMetric
from torchrec.metrics.auprc import AUPRCMetric
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
from torchrec.metrics.calibration import CalibrationMetric
from torchrec.metrics.ctr import CTRMetric
from torchrec.metrics.mae import MAEMetric
Expand Down Expand Up @@ -57,6 +58,7 @@
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric
from torchrec.metrics.throughput import ThroughputMetric
from torchrec.metrics.tower_qps import TowerQPSMetric
from torchrec.metrics.unweighted_ne import UnweightedNEMetric
from torchrec.metrics.weighted_avg import WeightedAvgMetric
from torchrec.metrics.xauc import XAUCMetric

Expand Down Expand Up @@ -88,6 +90,8 @@
RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric,
RecMetricEnum.OUTPUT: OutputMetric,
RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric,
RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric,
RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric,
}


Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class RecMetricEnum(RecMetricEnumBase):
SERVING_CALIBRATION = "serving_calibration"
OUTPUT = "output"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"
CALI_FREE_NE = "cali_free_ne"
UNWEIGHTED_NE = "unweighted_ne"


@dataclass(unsafe_hash=True, eq=True)
Expand Down
6 changes: 6 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class MetricName(MetricNameBase):
SERVING_CALIBRATION = "serving_calibration"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"

CALI_FREE_NE = "cali_free_ne"
UNWEIGHTED_NE = "unweighted_ne"


class MetricNamespaceBase(StrValueMixin, Enum):
pass
Expand Down Expand Up @@ -120,6 +123,9 @@ class MetricNamespace(MetricNamespaceBase):
OUTPUT = "output"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"

CALI_FREE_NE = "cali_free_ne"
UNWEIGHTED_NE = "unweighted_ne"


class MetricPrefix(StrValueMixin, Enum):
DEFAULT = ""
Expand Down
Loading

0 comments on commit 97c2f66

Please sign in to comment.