MetricCountSeqAndTokens counts tokens in label (if it exists) + improve epoch metrics print #352

Merged: 8 commits, May 23, 2024

Conversation

michalozeryflato (Collaborator):

  1. Add an optional decoder_input to MetricCountSeqAndTokens.
  2. Print the traceback before the throw; helps in debugging.
  3. Modify epoch stats printing: each column gets its own (fixed) width.

@@ -134,6 +135,7 @@ def op_call(
+ f"error in __call__ method of op={op}, op_id={op_id}, sample_id={get_sample_id(sample_dict)} - more details below"
+ "*************************************************************************************************************************************\n"
)
print(traceback.print_exc())
Collaborator:

Are you sure you need to wrap it with print()? And I think we still get the traceback, so I'm not sure why it's needed.

Collaborator Author:

Removed the print().
Adding traceback.print_exc() helps me debug: it makes it easier to locate the source of the problem.
If you don't find this useful, I will revert the change and keep it only in my local copy of the file.
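For context, a standalone sketch of the behavior in question: traceback.print_exc() writes the formatted traceback to stderr itself and returns None, so wrapping it in print() only adds a stray "None" line to stdout.

import traceback

try:
    1 / 0
except ZeroDivisionError:
    traceback.print_exc()  # prints the traceback to stderr, returns None
    # print(traceback.print_exc())  # would print the traceback, then "None"
    raise  # re-raise so the error still propagates, as op_call does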

ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Count number of sequences and tokens
Args:
encoder_input_key:
key to encoder_input
decoder_input_key:
key to encoder_input
Collaborator:

key to *decoder_input :)

Collaborator:

☝️

Comment on lines +114 to +118
if decoder_input_key is not None and decoder_input_key in batch_dict:
decoder_input = batch_dict[decoder_input_key].detach()
mask2 = get_mask(decoder_input, ignore_index)
assert mask2.shape[0] == mask.shape[0]
token_num += mask2.sum().to(dtype=torch.int64)
Collaborator:

Just making sure: is this the support for counting the label tokens as well?

If so, best to double-check with @mosheraboh :)

Collaborator Author:

@SagiPolaczek I updated @mosheraboh, but better to confirm the change with him.
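For reference, a minimal sketch of what the added branch computes. The get_mask helper here is an assumption about its behavior (True for every countable token), not the repository's exact implementation:

from typing import Optional

import torch

def get_mask(tokens: torch.Tensor, ignore_index: Optional[int]) -> torch.Tensor:
    # Assumed behavior: True wherever the token should be counted.
    if ignore_index is None:
        return torch.ones_like(tokens, dtype=torch.bool)
    return tokens.ne(ignore_index)

# Toy batch with pad id -100; decoder_input holds the label tokens.
batch_dict = {
    "encoder_input": torch.tensor([[5, 7, -100], [3, -100, -100]]),
    "decoder_input": torch.tensor([[9, -100, -100], [2, 4, -100]]),
}
mask = get_mask(batch_dict["encoder_input"].detach(), -100)
token_num = mask.sum().to(dtype=torch.int64)    # 3 encoder tokens
decoder_input = batch_dict["decoder_input"].detach()
mask2 = get_mask(decoder_input, -100)
assert mask2.shape[0] == mask.shape[0]          # same batch size
token_num += mask2.sum().to(dtype=torch.int64)  # + 3 label tokens -> 6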

mosheraboh previously approved these changes May 16, 2024

@mosheraboh left a comment:

LGTM

-    col_width = max(max_col_width, max_val_width, col_width)
-    dashes = (col_width + 2) * len(df.columns.values)
+    max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(
Collaborator:

Here you've modified the code to set the width per column instead of a single width for all columns?

Collaborator Author:

Yes, to make the overall width smaller so it better fits the screen.
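A minimal sketch of the per-column idea (the DataFrame and the formatting loop here are illustrative, not the PR's exact code):

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {"metric": ["train.loss", "validation.accuracy"], "value": [0.123456, 0.98]}
)

# One width per column: the longest of the header and any value in that column.
max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(axis=0)
col_widths = [
    max(len(str(c)), int(w)) for c, w in zip(df.columns, max_val_width_per_col)
]

header = " | ".join(str(c).ljust(w) for c, w in zip(df.columns, col_widths))
print(header)
print("-" * len(header))
for row in df.values.astype(str):
    print(" | ".join(v.ljust(w) for v, w in zip(row, col_widths)))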


Note that the previous command preds[:, target] created a matrix of size [n,n].
Another simple alternative is
preds[torch.arange(preds.shape[0]), target]
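A standalone illustration of the shapes involved (toy values, not the PR's code; torch.gather is the variant the final version mentions):

import torch

n, vocab = 4, 10
preds = torch.randn(n, vocab)
target = torch.randint(vocab, (n,))

# Plain column indexing broadcasts to an [n, n] matrix:
assert preds[:, target].shape == (n, n)

# Both of these pick one logit per row, giving shape [n]:
by_index = preds[torch.arange(n), target]
by_gather = torch.gather(preds, dim=1, index=target.unsqueeze(1)).squeeze(1)
assert torch.equal(by_index, by_gather)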
mosheraboh previously approved these changes May 23, 2024

@mosheraboh left a comment:

Looks great

@@ -177,18 +194,18 @@ def _perplexity_update(
preds = preds.detach()
target = target.detach()

-    preds = preds.reshape(-1, preds.shape[-1])
-    target = target.reshape(-1)
+    preds = preds.view(-1, preds.shape[-1])
Collaborator:

Why? To save memory?

Collaborator Author:

@mosheraboh From the documentation of reshape: "When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior."

The question is:
(1) do we assume view will always work, and otherwise alert the user, or
(2) do we want to be permissive and copy the data (despite the additional GPU memory) when a view is not possible (i.e., the memory is not contiguous)?

Collaborator:

Is it really that expensive in memory?

If it is, then I would warn the user conditionally on the Tensor.is_contiguous() result :)
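A sketch of that suggestion, with an illustrative helper name (not code from the PR):

import warnings

import torch

def flatten_logits(preds: torch.Tensor) -> torch.Tensor:
    # view() never copies and raises on inputs it cannot view;
    # reshape() silently copies those inputs instead, costing memory.
    if not preds.is_contiguous():
        warnings.warn("preds is not contiguous; flattening will copy the tensor")
        return preds.reshape(-1, preds.shape[-1])
    return preds.view(-1, preds.shape[-1])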

@@ -137,6 +153,7 @@ def __init__(

# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied and not imported to not be affected by internal interface modifications.
# modifications: (1) reshape => view (2) apply mask at the beginning of computation (3) use torch.gather
Collaborator:

Nice solution.

SagiPolaczek previously approved these changes May 23, 2024

@SagiPolaczek left a comment:

THANK YOU!

Two comments inline

Comment on lines -185 to +202
-    target = target.where(
-        target != ignore_index, torch.tensor(0, device=target.device)
-    )
+    target = target[mask]
Collaborator:

Note that we change the shape of target here: with Tensor.where() we keep the same shape, but with target[mask] we take only the values where the mask is True.

Collaborator:

I quickly ran an example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t
tensor([[1, 2],
        [3, 4]])
>>> t.ne(1)
tensor([[False,  True],
        [ True,  True]])
>>> m = t.ne(1)
>>> t[m]
tensor([2, 3, 4])
>>> t.where( t != 1, 0)
tensor([[0, 2],
        [3, 4]])

Collaborator:

I'm not sure it's critical (maybe it's even better, saving memory)

Collaborator Author:

Yes, I have reduced the size of preds and target, since at the end the entries corresponding to the mask are dropped anyway. So it produces the same results, but with less memory/time (not having to process the masked entries).
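A small sketch of the effect with toy values: masking first shrinks both tensors, and the dropped entries never enter the softmax.

import torch

vocab, ignore_index = 5, -100
target = torch.tensor([2, ignore_index, 4])
preds = torch.randn(3, vocab)

mask = target.ne(ignore_index)
target = target[mask]  # shape [2] instead of [3]
preds = preds[mask]    # shape [2, vocab]; masked rows are never processed
probs = preds.softmax(dim=-1)
nll = -torch.log(probs[torch.arange(target.numel()), target])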

@SagiPolaczek left a comment:

LGTM! Thanks!

Comment on lines 202 to 203
if ignore_index is not None:
mask = target.ne(ignore_index)
Collaborator:

I guess we can use get_mask() here

@michalozeryflato merged commit 1810465 into master May 23, 2024
5 checks passed