MetricCountSeqAndTokens counts tokens in label (if it exists) + improved epoch metrics print #352

Merged
Merged 8 commits on May 23, 2024
2 changes: 2 additions & 0 deletions fuse/data/ops/op_base.py
@@ -22,6 +22,7 @@
from fuse.utils.ndict import NDict
from fuse.data.ops.hashable_class import HashableClass
import inspect
+import traceback


class OpBase(HashableClass):
@@ -134,6 +135,7 @@ def op_call(
+ f"error in __call__ method of op={op}, op_id={op_id}, sample_id={get_sample_id(sample_dict)} - more details below"
+ "*************************************************************************************************************************************\n"
)
+        traceback.print_exc()
raise

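For context, a minimal sketch of the resulting pattern (a simplified, hypothetical op_call; the real method takes more arguments): the traceback of the original failure is printed before re-raising, so the root cause stays visible even if an outer layer catches and re-reports the exception.

import traceback

def op_call(op, sample_dict):
    try:
        return op(sample_dict)
    except Exception:
        # print a banner identifying the failing op, then the full traceback
        print(f"error in __call__ method of op={op} - more details below")
        traceback.print_exc()
        raise  # re-raise so callers still see the failure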

40 changes: 28 additions & 12 deletions fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -33,12 +33,14 @@ class MetricCountSeqAndTokens(MetricPerBatchDefault):
def __init__(
self,
encoder_input: str,
+        decoder_input: Optional[str] = None,
ignore_index: Optional[int] = None,
state: Optional[dict] = None,
**kwargs: dict,
) -> None:
"""
:param encoder_input: key to the encoder_input
+        :param decoder_input: key to the decoder_input
:param ignore_index: token_id to ignore (not to count), typically pad token id
:param state: the sequence count and token count to continue from. Should be restored when resuming training.
use get_state() to get the state and save it upon checkpointing,
@@ -52,6 +54,7 @@ def __init__(
_count_seq_and_tokens_update,
ignore_index=ignore_index,
encoder_input_key=encoder_input,
+            decoder_input_key=decoder_input,
),
result_aggregate_func=self._count_seq_and_tokens_compute,
**kwargs,
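A construction sketch (the key names below are hypothetical): with the new argument, the label/decoder tokens are counted on top of the encoder tokens.

from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricCountSeqAndTokens

metric = MetricCountSeqAndTokens(
    encoder_input="data.encoder_input_token_ids",
    decoder_input="data.labels_token_ids",
    ignore_index=0,  # pad token id, excluded from the token count
)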
@@ -82,12 +85,15 @@ def get_state(self) -> dict:
def _count_seq_and_tokens_update(
batch_dict: dict,
encoder_input_key: str,
+    decoder_input_key: Optional[str] = None,
ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Count number of sequences and tokens
Args:
encoder_input_key:
key to encoder_input
+        decoder_input_key:
+            key to decoder_input
ignore_index:
Token not to count, typically padding
Returns:
@@ -98,19 +104,29 @@ def _count_seq_and_tokens_update(
# to save GPU memory
encoder_input = encoder_input.detach()

-    if ignore_index is not None:
-        mask = encoder_input.ne(ignore_index)
-    else:
-        mask = torch.ones_like(encoder_input, dtype=torch.bool)
+    mask = get_mask(encoder_input, ignore_index)

seq_num = torch.tensor(
mask.shape[0], dtype=torch.int64, device=encoder_input.device
)
token_num = mask.sum().to(dtype=torch.int64)

+    if decoder_input_key is not None and decoder_input_key in batch_dict:
+        decoder_input = batch_dict[decoder_input_key].detach()
+        mask2 = get_mask(decoder_input, ignore_index)
+        assert mask2.shape[0] == mask.shape[0]
+        token_num += mask2.sum().to(dtype=torch.int64)
Comment on lines +114 to +118
Collaborator:
Just making sure: is this the support for counting the label tokens as well?

If so, best to double-check with @mosheraboh :)

Collaborator (Author):
@SagiPolaczek I updated @mosheraboh, but better confirm the change with him.

return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}


+def get_mask(input: torch.Tensor, ignore_index: Optional[int]) -> torch.Tensor:
+    if ignore_index is not None:
+        mask = input.ne(ignore_index)
+    else:
+        mask = torch.ones_like(input, dtype=torch.bool)
+    return mask
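A minimal sketch of the new counting behavior, assuming the elided top of the function fetches batch_dict[encoder_input_key] and that a plain dict with flat keys is accepted (the key names and pad id 0 are illustrative):

import torch

from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import _count_seq_and_tokens_update

batch_dict = {
    "encoder_input": torch.tensor([[5, 7, 0, 0], [3, 0, 0, 0]]),
    "decoder_input": torch.tensor([[9, 4, 0, 0], [6, 2, 8, 0]]),
}
result = _count_seq_and_tokens_update(
    batch_dict,
    encoder_input_key="encoder_input",
    decoder_input_key="decoder_input",
    ignore_index=0,
)
print(result)  # seq_num == 2; token_num == 3 (encoder) + 5 (decoder) == 8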


class MetricPerplexity(MetricPerBatchDefault):
def __init__(
self,
@@ -137,6 +153,7 @@ def __init__(

# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied rather than imported, so as not to be affected by internal interface modifications.
+# modifications: (1) reshape => view (2) apply mask at the beginning of computation (3) use torch.gather
Collaborator:
Nice solution.

def _perplexity_update(
batch_dict: dict,
preds_key: str,
@@ -177,24 +194,23 @@ def _perplexity_update(
preds = preds.detach()
target = target.detach()

-    preds = preds.reshape(-1, preds.shape[-1])
-    target = target.reshape(-1)
+    preds = preds.view(-1, preds.shape[-1])
Collaborator:
Why? To save memory?

Collaborator (Author):
@mosheraboh From the documentation of reshape: "When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior."

The question is:
(1) do we assume view will always work, and otherwise alert the user, or
(2) do we want to be permissive and copy the data (despite the additional GPU memory) when a view is not possible (i.e. the memory is not contiguous)?

Collaborator:
Is it really that expensive in memory?

If it is, then I would warn the user conditionally on the Tensor.is_contiguous() result :)
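For reference, a quick sketch of the trade-off discussed above: view fails on non-contiguous memory while reshape silently copies, and Tensor.is_contiguous() is the condition one would warn on.

import torch

t = torch.arange(6).reshape(2, 3).t()  # transposed => non-contiguous
print(t.is_contiguous())  # False
print(t.reshape(-1))      # works, but silently copies the data
try:
    t.view(-1)            # raises: view requires compatible strides
except RuntimeError as e:
    print(e)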

+    target = target.view(-1)

if ignore_index is not None:
mask = target.ne(ignore_index)
Comment on lines 202 to 203
Collaborator:
I guess we can use get_mask() here.
-        target = target.where(
-            target != ignore_index, torch.tensor(0, device=target.device)
-        )
+        target = target[mask]
Comment on lines -185 to +204
Collaborator:
Note that we change the shape of target here:

With Tensor.where() we keep the same shape, but with target[mask] we take only the values where the mask is True.

Collaborator:
I quickly ran an example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t
tensor([[1, 2],
        [3, 4]])
>>> t.ne(1)
tensor([[False,  True],
        [ True,  True]])
>>> m = t.ne(1)
>>> t[m]
tensor([2, 3, 4])
>>> t.where(t != 1, 0)
tensor([[0, 2],
        [3, 4]])

Collaborator:
I'm not sure it's critical (maybe it's even better, saving memory).

Collaborator (Author):
Yes, I have reduced the size of preds and target, since at the end the entries corresponding to the mask are dropped anyway. So it seems to produce the same results, but with less memory/time (not having to process the masked entries).
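To make the equivalence concrete, a toy sketch (assuming ignore_index = -100): dropping the masked entries before the log-prob lookup selects exactly the same values as keeping the full shape and masking at the end.

import torch

probs = torch.tensor([[0.7, 0.3], [0.1, 0.9], [0.5, 0.5]])
target = torch.tensor([0, 1, -100])  # last position is padding
mask = target.ne(-100)

# old path: keep the shape, overwrite ignored targets with 0, mask at the end
safe_target = target.where(mask, torch.tensor(0))
old = probs[:, safe_target].diagonal()[mask]

# new path: drop the ignored rows first, then gather per row
new = torch.gather(probs[mask], 1, target[mask].view(-1, 1)).squeeze(1)

assert torch.equal(old, new)  # tensor([0.7000, 0.9000]) both ways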

+        preds = preds[mask]
+        count = mask.sum()
    else:
        mask = torch.ones_like(target, dtype=torch.bool)
+        count = torch.tensor(target.shape[0])

-    preds = preds[:, target].diagonal()[mask]
+    preds = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)
    # avoid overflow
    if preds.dtype == torch.float16:
        preds = preds.to(torch.float32)
    preds = torch.clamp(preds, min=1e-10)
    total_log_probs = -preds.log().sum()
-    count = mask.sum()

return {"log_probs": total_log_probs.unsqueeze(0), "token_num": count.unsqueeze(0)}

26 changes: 13 additions & 13 deletions fuse/utils/misc/misc.py
@@ -158,22 +158,22 @@ def time_display(seconds: int, granularity: int = 3) -> str:

def get_pretty_dataframe(df: pd.DataFrame, col_width: int = 25) -> str:
# check if col_width needs to be widened (so that the dashes fit on one line)
-    max_val_width = np.vectorize(len)(
-        df.values.astype(str)
-    ).max()  # get maximum length of all values
-    max_col_width = max(
-        [len(x) for x in df.columns]
-    )  # get the maximum lengths of all column names
-    col_width = max(max_col_width, max_val_width, col_width)
-
-    dashes = (col_width + 2) * len(df.columns.values)
+    max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(
Collaborator:
Here you've modified the code to set the width per column instead of a single width for all columns?

Collaborator (Author):
Yes, to make the overall width smaller, so that it better fits the screen.

+        axis=0
+    )  # get maximum length of all values, per column
+    col_name_widths = [len(x) for x in df.columns]  # get column name lengths
+    col_widths = [
+        max(max(x), col_width) for x in zip(max_val_width_per_col, col_name_widths)
+    ]

+    dashes = sum(col_widths) + 2 * len(df.columns.values)
    df_as_string = f"\n{'-' * dashes}\n"
-    for col in df.columns.values:
-        df_as_string += f"| {col:{col_width}}"
+    for col, col_w in zip(df.columns.values, col_widths):
+        df_as_string += f"| {col:{col_w}}"
    df_as_string += f"|\n{'-' * dashes}\n"
    for idx, row in df.iterrows():
-        for col in df.columns.values:
-            df_as_string += f"| {row[col]:{col_width}}"
+        for i_col, col in enumerate(df.columns.values):
+            df_as_string += f"| {row[col]:{col_widths[i_col]}}"

        df_as_string += f"|\n{'-' * dashes}\n"
    return df_as_string
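A quick usage sketch (hypothetical data): with per-column widths, one long metric name no longer forces every column to be wide.

import pandas as pd

from fuse.utils.misc.misc import get_pretty_dataframe

df = pd.DataFrame(
    {"metric": ["validation.perplexity", "train.loss"], "value": [12.3456, 0.789]}
)
print(get_pretty_dataframe(df, col_width=5))
# the "metric" column is padded to 21 characters, "value" only to 7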