Skip to content

Commit

Permalink
Allow Missing Labels for MTML models
Browse files Browse the repository at this point in the history
Differential Revision: D50273422
  • Loading branch information
Paul Zhang authored and facebook-github-bot committed Nov 6, 2023
1 parent 5fce39c commit 628a8c0
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 2 deletions.
19 changes: 18 additions & 1 deletion torchrec/metrics/ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,21 @@ def _compute_cross_entropy_norm(
)


@torch.fx.wrap
def compute_ne(
ce_sum: torch.Tensor,
weighted_num_samples: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
eta: float,
allow_missing_label_with_zero_weight: bool,
) -> torch.Tensor:
if allow_missing_label_with_zero_weight and not weighted_num_samples.all():
# If nan were to occur, return a dummy value instead of nan if
# allow_missing_label_with_zero_weight is True
return torch.tensor([eta])

# Goes into this block if all elements in weighted_num_samples > 0
mean_label = pos_labels / weighted_num_samples
ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta)
return ce_sum / ce_norm
Expand Down Expand Up @@ -96,9 +104,16 @@ class NEMetricComputation(RecMetricComputation):
"""

def __init__(
self, *args: Any, include_logloss: bool = False, **kwargs: Any
self,
*args: Any,
include_logloss: bool = False,
allow_missing_label_with_zero_weight: bool = False,
**kwargs: Any,
) -> None:
self._include_logloss: bool = include_logloss
self._allow_missing_label_with_zero_weight: bool = (
allow_missing_label_with_zero_weight
)
super().__init__(*args, **kwargs)
self._add_state(
"cross_entropy_sum",
Expand Down Expand Up @@ -161,6 +176,7 @@ def _compute(self) -> List[MetricComputationReport]:
cast(torch.Tensor, self.pos_labels),
cast(torch.Tensor, self.neg_labels),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
MetricComputationReport(
Expand All @@ -172,6 +188,7 @@ def _compute(self) -> List[MetricComputationReport]:
self.get_window_state("pos_labels"),
self.get_window_state("neg_labels"),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
]
Expand Down
1 change: 1 addition & 0 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
should_validate_update: bool = False,
process_group: Optional[dist.ProcessGroup] = None,
fused_update_limit: int = 0,
allow_missing_label_with_zero_weight: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
Expand Down
17 changes: 16 additions & 1 deletion torchrec/metrics/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,24 @@ def rec_metric_value_test_helper(
is_time_dependent: bool = False,
time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None,
n_classes: Optional[int] = None,
zero_weights: bool = False,
**kwargs: Any,
) -> Tuple[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], ...]]:
tasks = gen_test_tasks(task_names)
model_outs = []

for _ in range(nsteps):
weight_value: Optional[torch.Tensor] = None
if zero_weights:
weight_value = torch.zeros(batch_size)

_model_outs = [
gen_test_batch(
label_name=task.label_name,
prediction_name=task.prediction_name,
weight_name=task.weight_name,
batch_size=batch_size,
n_classes=n_classes,
weight_value=weight_value,
)
for task in tasks
]
Expand All @@ -254,6 +259,9 @@ def get_target_rec_metric_value(
window_size = world_size * batch_size * batch_window_size
if n_classes:
kwargs["number_of_classes"] = n_classes
if zero_weights:
kwargs["allow_missing_label_with_zero_weight"] = True

target_metric_obj = target_clazz(
world_size=world_size,
my_rank=my_rank,
Expand Down Expand Up @@ -345,6 +353,7 @@ def rec_metric_value_test_launcher(
batch_window_size: int = BATCH_WINDOW_SIZE,
test_nsteps: int = 1,
n_classes: Optional[int] = None,
zero_weights: bool = False,
**kwargs: Any,
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -368,8 +377,10 @@ def rec_metric_value_test_launcher(
nsteps=test_nsteps,
batch_window_size=1,
n_classes=n_classes,
zero_weights=zero_weights,
**kwargs,
)

pet.elastic_launch(lc, entrypoint=entry_point)(
target_clazz,
target_compute_mode,
Expand All @@ -382,6 +393,7 @@ def rec_metric_value_test_launcher(
batch_window_size,
n_classes,
test_nsteps,
zero_weights,
)


Expand All @@ -407,6 +419,7 @@ def metric_test_helper(
batch_window_size: int = BATCH_WINDOW_SIZE,
n_classes: Optional[int] = None,
nsteps: int = 1,
zero_weights: bool = False,
is_time_dependent: bool = False,
time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None,
**kwargs: Any,
Expand All @@ -418,6 +431,7 @@ def metric_test_helper(
world_size=world_size,
rank=rank,
)

target_metrics, test_metrics = rec_metric_value_test_helper(
target_clazz=target_clazz,
target_compute_mode=target_compute_mode,
Expand All @@ -433,6 +447,7 @@ def metric_test_helper(
nsteps=nsteps,
is_time_dependent=is_time_dependent,
time_dependent_metric=time_dependent_metric,
zero_weights=zero_weights,
**kwargs,
)

Expand Down
20 changes: 20 additions & 0 deletions torchrec/metrics/tests/test_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,17 @@ def _get_states(

@staticmethod
def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor:
allow_missing_label_with_zero_weight = False
if not states["weighted_num_samples"].all():
allow_missing_label_with_zero_weight = True

return compute_ne(
states["cross_entropy_sum"],
states["weighted_num_samples"],
pos_labels=states["pos_labels"],
neg_labels=states["neg_labels"],
eta=TestNEMetric.eta,
allow_missing_label_with_zero_weight=allow_missing_label_with_zero_weight,
)


Expand Down Expand Up @@ -176,6 +181,21 @@ def test_ne_update_fused(self) -> None:
# entry_point=self._test_ne_large_window_size,
# )

def test_ne_zero_weights(self) -> None:
rec_metric_value_test_launcher(
target_clazz=NEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestNEMetric,
metric_name=NEMetricTest.task_name,
task_names=["t1", "t2", "t3"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
zero_weights=True,
)

_logloss_metric_test_helper: Callable[..., None] = partial(
metric_test_helper, include_logloss=True
)
Expand Down

0 comments on commit 628a8c0

Please sign in to comment.