Skip to content

Commit

Permalink
setup new tensor weighted avg metric (pytorch#2413)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2413

# context

Original diff D63150195,

Broke some ads tests. Have fixed the automatic dependencies and retested such tests.

Context copied below

Reviewed By: iamzainhuda

Differential Revision: D63269804
  • Loading branch information
Bill Yang authored and facebook-github-bot committed Sep 23, 2024
1 parent a9f4b72 commit 2f238d9
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from torchrec.metrics.segmented_ne import SegmentedNEMetric
from torchrec.metrics.serving_calibration import ServingCalibrationMetric
from torchrec.metrics.serving_ne import ServingNEMetric
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric
from torchrec.metrics.throughput import ThroughputMetric
from torchrec.metrics.tower_qps import TowerQPSMetric
from torchrec.metrics.weighted_avg import WeightedAvgMetric
Expand Down Expand Up @@ -84,6 +85,7 @@
RecMetricEnum.SERVING_NE: ServingNEMetric,
RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric,
RecMetricEnum.OUTPUT: OutputMetric,
RecMetricEnum.TENSOR_WEIGHTED_AVG: TensorWeightedAvgMetric,
}


Expand Down
3 changes: 3 additions & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class RecMetricEnum(RecMetricEnumBase):
SERVING_NE = "serving_ne"
SERVING_CALIBRATION = "serving_calibration"
OUTPUT = "output"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"


@dataclass(unsafe_hash=True, eq=True)
Expand All @@ -65,6 +66,8 @@ class RecTaskInfo:
None # used for session level metrics
)
is_negative_task: bool = False
tensor_name: Optional[str] = None
weighted: bool = True


class RecComputeMode(Enum):
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class MetricName(MetricNameBase):

SERVING_NE = "serving_ne"
SERVING_CALIBRATION = "serving_calibration"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"


class MetricNamespaceBase(StrValueMixin, Enum):
Expand Down Expand Up @@ -114,6 +115,7 @@ class MetricNamespace(MetricNamespaceBase):
SERVING_CALIBRATION = "serving_calibration"

OUTPUT = "output"
TENSOR_WEIGHTED_AVG = "tensor_weighted_avg"


class MetricPrefix(StrValueMixin, Enum):
Expand Down
146 changes: 146 additions & 0 deletions torchrec/metrics/tensor_weighted_avg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Set, Type, Union

import torch
from torchrec.metrics.metrics_config import RecTaskInfo
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)


def get_mean(value_sum: torch.Tensor, num_samples: torch.Tensor) -> torch.Tensor:
return value_sum / num_samples


class TensorWeightedAvgMetricComputation(RecMetricComputation):
def __init__(
self,
*args: Any,
tensor_name: Optional[str] = None,
weighted: bool = True,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if tensor_name is None:
raise RecMetricException(
f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}"
)
self.tensor_name: str = tensor_name
self.weighted: bool = weighted
self._add_state(
"weighted_sum",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"weighted_num_samples",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)

def update(
self,
*,
predictions: Optional[torch.Tensor],
labels: torch.Tensor,
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if (
"required_inputs" not in kwargs
or self.tensor_name not in kwargs["required_inputs"]
):
raise RecMetricException(
f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs"
)
num_samples = labels.shape[0]
target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name])
weights = cast(torch.Tensor, weights)
states = {
"weighted_sum": (
target_tensor * weights if self.weighted else target_tensor
).sum(dim=-1),
"weighted_num_samples": (
weights.sum(dim=-1)
if self.weighted
else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device)
),
}
for state_name, state_value in states.items():
state = getattr(self, state_name)
state += state_value
self._aggregate_window_state(state_name, state_value, num_samples)

def _compute(self) -> List[MetricComputationReport]:
return [
MetricComputationReport(
name=MetricName.WEIGHTED_AVG,
metric_prefix=MetricPrefix.LIFETIME,
value=get_mean(
cast(torch.Tensor, self.weighted_sum),
cast(torch.Tensor, self.weighted_num_samples),
),
),
MetricComputationReport(
name=MetricName.WEIGHTED_AVG,
metric_prefix=MetricPrefix.WINDOW,
value=get_mean(
self.get_window_state("weighted_sum"),
self.get_window_state("weighted_num_samples"),
),
),
]


class TensorWeightedAvgMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.WEIGHTED_AVG
_computation_class: Type[RecMetricComputation] = TensorWeightedAvgMetricComputation

def __init__(
self,
# pyre-ignore Missing parameter annotation [2]
*args,
**kwargs: Dict[str, Any],
) -> None:

super().__init__(*args, **kwargs)

def _get_task_kwargs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Dict[str, Any]:
if not isinstance(task_config, RecTaskInfo):
raise RecMetricException(
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
)
return {
"tensor_name": task_config.tensor_name,
"weighted": task_config.weighted,
}

def _get_task_required_inputs(
self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
) -> Set[str]:
if not isinstance(task_config, RecTaskInfo):
raise RecMetricException(
f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
)
required_inputs = set()
if task_config.tensor_name is not None:
required_inputs.add(task_config.tensor_name)
return required_inputs

0 comments on commit 2f238d9

Please sign in to comment.