Skip to content

Commit

Permalink
MetricCountSeqAndTokens counts tokens in label (if it exists) + improving…
Browse files Browse the repository at this point in the history
… epoch metrics print (#352)

* 1. add optional decoder_input to MetricCountSeqAndTokens
2. print traceback before re-raising - helps in debugging
3. modify epoch stats printing - each column has its own (fixed) width

* addressing Sagi's comments

* improve _perplexity_update to consume less GPU memory.
Note that the previous command preds[:, target] created a matrix of size [n,n].
Another simple alternative is
preds[torch.arange(preds.shape[0]), target]

* fix a typo in the documentation

* fix exceptions in _perplexity_update  - caught by the test code!!

* revert back to reshape
  • Loading branch information
michalozeryflato authored May 23, 2024
1 parent d0a7250 commit 1810465
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
2 changes: 2 additions & 0 deletions fuse/data/ops/op_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from fuse.utils.ndict import NDict
from fuse.data.ops.hashable_class import HashableClass
import inspect
import traceback


class OpBase(HashableClass):
Expand Down Expand Up @@ -134,6 +135,7 @@ def op_call(
+ f"error in __call__ method of op={op}, op_id={op_id}, sample_id={get_sample_id(sample_dict)} - more details below"
+ "*************************************************************************************************************************************\n"
)
traceback.print_exc()
raise


Expand Down
38 changes: 28 additions & 10 deletions fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class MetricCountSeqAndTokens(MetricPerBatchDefault):
def __init__(
self,
encoder_input: str,
decoder_input: Optional[str] = None,
ignore_index: Optional[int] = None,
state: Optional[dict] = None,
**kwargs: dict,
) -> None:
"""
:param encoder_input: key to the encoder_input
:param decoder_input: key to the decoder_input
:param ignore_index: token_id to ignore (not to count), typically pad token id
:param state: the sequence count and token count to continue for. Should be restored when we continue training.
use get_state() to get the state and save it upon checkpointing,
Expand All @@ -52,6 +54,7 @@ def __init__(
_count_seq_and_tokens_update,
ignore_index=ignore_index,
encoder_input_key=encoder_input,
decoder_input_key=decoder_input,
),
result_aggregate_func=self._count_seq_and_tokens_compute,
**kwargs,
Expand Down Expand Up @@ -82,12 +85,15 @@ def get_state(self) -> dict:
def _count_seq_and_tokens_update(
batch_dict: dict,
encoder_input_key: str,
decoder_input_key: Optional[str] = None,
ignore_index: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
"""Count number of sequences and tokens
Args:
encoder_input_key:
key to encoder_input
decoder_input_key:
key to decoder_input
ignore_index:
Token not to count, typically padding
Returns:
Expand All @@ -98,19 +104,29 @@ def _count_seq_and_tokens_update(
# to save GPU memory
encoder_input = encoder_input.detach()

if ignore_index is not None:
mask = encoder_input.ne(ignore_index)
else:
mask = torch.ones_like(encoder_input, dtype=torch.bool)
mask = get_mask(encoder_input, ignore_index)

seq_num = torch.tensor(
mask.shape[0], dtype=torch.int64, device=encoder_input.device
)
token_num = mask.sum().to(dtype=torch.int64)

if decoder_input_key is not None and decoder_input_key in batch_dict:
decoder_input = batch_dict[decoder_input_key].detach()
mask2 = get_mask(decoder_input, ignore_index)
assert mask2.shape[0] == mask.shape[0]
token_num += mask2.sum().to(dtype=torch.int64)
return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}


def get_mask(input: torch.Tensor, ignore_index: Optional[int]) -> torch.Tensor:
    """Build a boolean mask marking the elements of ``input`` that should be counted.

    :param input: tensor of token ids
    :param ignore_index: token id to exclude (typically the pad token id);
        when None, every element is kept
    :return: boolean tensor with the same shape as ``input`` - True where the
        element differs from ``ignore_index`` (all True when ``ignore_index`` is None)
    """
    if ignore_index is None:
        # nothing to ignore - every position counts
        return torch.ones_like(input, dtype=torch.bool)
    return input.ne(ignore_index)


class MetricPerplexity(MetricPerBatchDefault):
def __init__(
self,
Expand All @@ -137,6 +153,7 @@ def __init__(

# Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
# copied and not imported to not be affected by internal interface modifications.
# modifications: (1) apply mask at the beginning of computation (2) use torch.gather instead of preds[:, target].diagonal()
def _perplexity_update(
batch_dict: dict,
preds_key: str,
Expand Down Expand Up @@ -177,24 +194,25 @@ def _perplexity_update(
preds = preds.detach()
target = target.detach()

# reshape attempts to create a view, and COPIES the data if fails.
# an issue to consider: use view and alert the user if it fails?
preds = preds.reshape(-1, preds.shape[-1])
target = target.reshape(-1)

if ignore_index is not None:
mask = target.ne(ignore_index)
target = target.where(
target != ignore_index, torch.tensor(0, device=target.device)
)
target = target[mask]
preds = preds[mask]
count = mask.sum()
else:
mask = torch.ones_like(target, dtype=torch.bool)
count = torch.tensor(target.shape[0])

preds = preds[:, target].diagonal()[mask]
preds = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)
# avoid from overflow
if preds.dtype == torch.float16:
preds = preds.to(torch.float32)
preds = torch.clamp(preds, min=1e-10)
total_log_probs = -preds.log().sum()
count = mask.sum()

return {"log_probs": total_log_probs.unsqueeze(0), "token_num": count.unsqueeze(0)}

Expand Down
26 changes: 13 additions & 13 deletions fuse/utils/misc/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,22 @@ def time_display(seconds: int, granularity: int = 3) -> str:

def get_pretty_dataframe(df: pd.DataFrame, col_width: int = 25) -> str:
    """Render a DataFrame as a plain-text table in which each column has its own fixed width.

    The width of a column is the largest of: the length of its name, the length
    of the longest string representation of its values, and ``col_width``.
    The header and every row are followed by a line of dashes.

    :param df: the dataframe to render
    :param col_width: minimal width of every column
    :return: the rendered table as a multi-line string (starts and ends with a newline)
    """
    # column names are not guaranteed to be strings
    col_names = [str(col) for col in df.columns]

    # longest string representation of the values in each column.
    # np.vectorize (and max) cannot run on a size-0 array, so a frame
    # without any cell falls back to zero width and col_width/name decide.
    values_as_str = df.values.astype(str)
    if values_as_str.size > 0:
        max_val_width_per_col = np.vectorize(len)(values_as_str).max(axis=0)
    else:
        max_val_width_per_col = [0] * len(col_names)

    col_widths = [
        max(int(val_width), len(name), col_width)
        for val_width, name in zip(max_val_width_per_col, col_names)
    ]

    def _format_cell(value: object, width: int) -> str:
        # keep the native formatting (e.g. numbers are right-aligned) when the
        # value supports a width spec; otherwise (e.g. None) format its str()
        try:
            return format(value, str(width))
        except (TypeError, ValueError):
            return format(str(value), str(width))

    dashes = "-" * (sum(col_widths) + 2 * len(col_names))

    header = "".join(
        f"| {name:{width}}" for name, width in zip(col_names, col_widths)
    )
    lines = ["", dashes, header + "|", dashes]

    for _, row in df.iterrows():
        # positional access so that duplicate column names do not break the lookup
        cells = "".join(
            f"| {_format_cell(row.iloc[i_col], width)}"
            for i_col, width in enumerate(col_widths)
        )
        lines.append(cells + "|")
        lines.append(dashes)

    return "\n".join(lines) + "\n"
Expand Down

0 comments on commit 1810465

Please sign in to comment.