MetricCountSeqAndTokens counts tokens in label (if it exists) + improve epoch metrics print #352
Conversation
michalozeryflato
commented
May 7, 2024
- add optional decoder_input to MetricCountSeqAndTokens
- print traceback before throw - helps in debugging
- modify epoch stats printing - each column has its own (fixed) width
fuse/data/ops/op_base.py
@@ -134,6 +135,7 @@ def op_call(
+ f"error in __call__ method of op={op}, op_id={op_id}, sample_id={get_sample_id(sample_dict)} - more details below"
+ "*************************************************************************************************************************************\n"
)
print(traceback.print_exc())
Are you sure you need to wrap it with print()? And I think we still get the traceback, so I'm not sure why it's needed.
Removed the print().
Adding traceback.print_exc() helps me debug - it makes it easier to locate the source of the problem.
If you don't find this useful, I will revert the change and keep it only in my local copy of the file.
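For illustration, a minimal sketch of the pattern under discussion (the wrapper name and message are hypothetical, not the actual `op_call` from the diff): print the traceback before re-raising, so the failing op is easy to locate even if an outer framework re-wraps the exception.

```python
import traceback

def run_op(op, sample):
    """Hypothetical wrapper: log the traceback before re-raising."""
    try:
        return op(sample)
    except Exception:
        print(f"error in __call__ method of op={op!r}")
        traceback.print_exc()  # writes the traceback to stderr, returns None
        raise                  # re-raise unchanged
```

Note that `traceback.print_exc()` returns `None`, which is why wrapping it in `print()` only adds a stray `None` line to the output.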
ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
    """Count number of sequences and tokens
    Args:
        encoder_input_key:
            key to encoder_input
        decoder_input_key:
            key to encoder_input
key to *decoder_input :)
☝️
if decoder_input_key is not None and decoder_input_key in batch_dict:
    decoder_input = batch_dict[decoder_input_key].detach()
    mask2 = get_mask(decoder_input, ignore_index)
    assert mask2.shape[0] == mask.shape[0]
    token_num += mask2.sum().to(dtype=torch.int64)
Just making sure: is this the support for counting the label tokens as well?
If so, best to double check with @mosheraboh :)
@SagiPolaczek I updated @mosheraboh, but better confirm the change with him.
LGTM
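The snippet above can be sketched end to end with toy tensors (all values hypothetical; `get_mask` is replaced by an inline `ne()` on an assumed padding id so the sketch is self-contained):

```python
import torch

ignore_index = 0  # assumed padding token id for this toy example
encoder_input = torch.tensor([[5, 6, 0],
                              [7, 0, 0]])
decoder_input = torch.tensor([[1, 2],
                              [3, 0]])

# inline stand-in for get_mask(): True where the token should be counted
mask = encoder_input.ne(ignore_index)
token_num = mask.sum().to(dtype=torch.int64)    # 3 encoder tokens

mask2 = decoder_input.ne(ignore_index)
assert mask2.shape[0] == mask.shape[0]          # same batch size
token_num += mask2.sum().to(dtype=torch.int64)  # + 3 label tokens

seq_num = encoder_input.shape[0]                # 2 sequences
```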
col_width = max(max_col_width, max_val_width, col_width)

dashes = (col_width + 2) * len(df.columns.values)
max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(
Here you've modified the code to set the width per column instead of a single width for all columns?
Yes, to make the overall width smaller, so it better fits the screen.
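A toy sketch (hypothetical data and names, not the PR's actual code) of deriving a fixed width per column from both the header and the widest value in that column:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"epoch": [1, 10], "train_loss": [0.123456, 0.098]})

# widest value in each column, measured on the string representation
max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(axis=0)

# final width per column: at least as wide as the column header
col_widths = np.maximum(max_val_width_per_col,
                        [len(str(c)) for c in df.columns])

header = " ".join(c.rjust(w) for c, w in zip(df.columns, col_widths))
```

Compared with a single global width, narrow columns (like `epoch`) no longer pay for the widest column in the table.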
Note that the previous command preds[:, target] created a matrix of size [n,n]. Another simple alternative is preds[torch.arange(preds.shape[0]), target]
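To illustrate the comment above with toy shapes (not the PR's actual tensors): `preds[:, target]` materializes an `[n, n]` matrix of which only the diagonal is needed, while fancy indexing or `torch.gather` pick one entry per row directly.

```python
import torch

n, vocab = 4, 7
preds = torch.randn(n, vocab)
target = torch.randint(0, vocab, (n,))

full = preds[:, target]      # shape [n, n]; only its diagonal is wanted
diag = full.diagonal()       # the n values actually needed

fancy = preds[torch.arange(n), target]                      # shape [n]
gathered = preds.gather(1, target.unsqueeze(1)).squeeze(1)  # shape [n]
```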
Looks great
@@ -177,18 +194,18 @@ def _perplexity_update(
    preds = preds.detach()
    target = target.detach()

    preds = preds.reshape(-1, preds.shape[-1])
    target = target.reshape(-1)
    preds = preds.view(-1, preds.shape[-1])
why? to save memory?
@mosheraboh From the documentation of reshape: "When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior."
The question is:
(1) do we assume view will always work, and otherwise alert the user, or
(2) do we want to be permissive and copy the data (despite the additional GPU memory) when a view is not possible (i.e. the memory is not contiguous)?
Is it really that expensive in memory?
If it is, then I would warn the user conditionally on the Tensor.is_contiguous() result :)
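The trade-off being discussed can be shown with a small sketch: `view()` fails loudly on non-contiguous memory, while `reshape()` silently falls back to a copy.

```python
import torch

t = torch.arange(6).reshape(2, 3)
tt = t.t()                      # transposed view -> non-contiguous
assert not tt.is_contiguous()

try:
    tt.view(-1)                 # view cannot work on non-contiguous memory
    view_failed = False
except RuntimeError:
    view_failed = True          # option (1): the user is alerted

flat = tt.reshape(-1)           # option (2): silently copies the data
```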
@@ -137,6 +153,7 @@ def __init__(

# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied and not imported to not be affected by internal interface modifications.
# modifications: (1) reshape => view (2) apply mask at the beginning of computation (3) use torch.gather
Nice solution.
THANK YOU!
Two comments inline
target = target.where(
    target != ignore_index, torch.tensor(0, device=target.device)
)
target = target[mask]
Note that we change the shape of target here: with Tensor.where() we keep the same shape, but with target[mask] we take only the values where the mask is True.
I quickly ran an example:
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t
tensor([[1, 2],
[3, 4]])
>>> t.ne(1)
tensor([[False, True],
[ True, True]])
>>> m = t.ne(1)
>>> t[m]
tensor([2, 3, 4])
>>> t.where( t != 1, 0)
tensor([[0, 2],
[3, 4]])
I'm not sure it's critical (maybe it's even better, saving memory)
Yes, I have reduced the size of preds and target, since at the end the entries corresponding to the mask are dropped anyway. So it seems to produce the same results, but with less memory/time (not having to process the masked entries).
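A toy check (hypothetical numbers, not the actual metric code) that dropping the masked entries first yields the same sum of loss terms as keeping the full shape and masking at the end:

```python
import torch

ignore_index = -100
target = torch.tensor([2, ignore_index, 1])
probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.3, 0.3, 0.4],
                      [0.2, 0.8, 0.0]])   # row 1 will be ignored

mask = target.ne(ignore_index)

# original style: keep the full shape, substitute a safe label, mask last
safe_target = target.where(mask, torch.tensor(0))
picked = probs[torch.arange(3), safe_target]
total_old = -picked[mask].log().sum()

# PR style: drop the ignored rows first, then index the smaller tensors
t2, p2 = target[mask], probs[mask]
total_new = -p2[torch.arange(t2.shape[0]), t2].log().sum()
```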
@@ -177,18 +194,18 @@ def _perplexity_update(
    preds = preds.detach()
    target = target.detach()

    preds = preds.reshape(-1, preds.shape[-1])
    target = target.reshape(-1)
    preds = preds.view(-1, preds.shape[-1])
LGTM! Thanks!
if ignore_index is not None:
    mask = target.ne(ignore_index)
I guess we can use get_mask() here.
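For reference, a minimal sketch of what such a get_mask() helper could look like - this is an assumption about its behavior, not the actual fuse implementation:

```python
from typing import Optional
import torch

def get_mask(tokens: torch.Tensor, ignore_index: Optional[int]) -> torch.Tensor:
    """Hypothetical get_mask(): True for every token that should be counted."""
    if ignore_index is None:
        return torch.ones_like(tokens, dtype=torch.bool)
    return tokens.ne(ignore_index)
```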