Skip to content

Commit

Permalink
Fix GAUC in train_pipeline (pytorch#2672)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2672

fix gauc in train_pipeline

Thanks Chenyu for the collaboration.

Reviewed By: Yonezcy

Differential Revision:
D67935693

Privacy Context Container: L1194039

fbshipit-source-id: 2e641259ffdbe9b3f8f674aa952a586b18fa47f4
  • Loading branch information
yunjiangster authored and facebook-github-bot committed Jan 9, 2025
1 parent 905612b commit 27e8101
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchrec/metrics/gauc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def compute_gauc_3d(


def to_3d(
tensor_2d: torch.Tensor, seq_lengths: torch.Tensor, max_length: torch.Tensor
tensor_2d: torch.Tensor, seq_lengths: torch.Tensor, max_length: int
) -> torch.Tensor:
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths)
return torch.ops.fbgemm.jagged_2d_to_dense(tensor_2d, offsets, max_length)
Expand All @@ -108,7 +108,7 @@ def get_auc_states(
) -> Dict[str, torch.Tensor]:

# predictions, labels: [n_task, n_sample]
max_length = num_candidates.max()
max_length = int(num_candidates.max().item())
predictions_perm = predictions.permute(1, 0)
labels_perm = labels.permute(1, 0)
predictions_3d = to_3d(predictions_perm, num_candidates, max_length).permute(
Expand Down

0 comments on commit 27e8101

Please sign in to comment.