From 1810465a2c921f989f667c1cdd0ffa9ec09bbbcc Mon Sep 17 00:00:00 2001
From: Michal Ozery-Flato <104420142+michalozeryflato@users.noreply.github.com>
Date: Thu, 23 May 2024 20:26:18 +0300
Subject: [PATCH] MetricCountSeqAndTokens counts tokens in the label (if it
 exists) + improve epoch metrics printing (#352)

* 1. add an optional decoder_input to MetricCountSeqAndTokens
  2. print the traceback before re-raising - helps with debugging
  3. modify epoch stats printing - each column gets its own (fixed) width

* addressing Sagi's comments

* improve _perplexity_update to consume less GPU memory. Note that the
  previous command preds[:, target] created a matrix of size [n, n].
  Another simple alternative is preds[torch.arange(preds.shape[0]), target]

* fix a typo in the documentation

* fix exceptions in _perplexity_update - caught by the test code!

* revert to reshape
---
 fuse/data/ops/op_base.py                      |  2 +
 .../sequence_gen/metrics_seq_gen_common.py    | 38 ++++++++++++++-----
 fuse/utils/misc/misc.py                       | 26 ++++++-------
 3 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/fuse/data/ops/op_base.py b/fuse/data/ops/op_base.py
index 04dc9c98..ccbdbf08 100644
--- a/fuse/data/ops/op_base.py
+++ b/fuse/data/ops/op_base.py
@@ -22,6 +22,7 @@
 from fuse.utils.ndict import NDict
 from fuse.data.ops.hashable_class import HashableClass
 import inspect
+import traceback
 
 
 class OpBase(HashableClass):
@@ -134,6 +135,7 @@ def op_call(
                 + f"error in __call__ method of op={op}, op_id={op_id}, sample_id={get_sample_id(sample_dict)} - more details below"
                 + "*************************************************************************************************************************************\n"
             )
+        traceback.print_exc()
         raise
 
diff --git a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
index c5e2947c..6da0ff1b 100644
--- a/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
+++ b/fuse/eval/metrics/sequence_gen/metrics_seq_gen_common.py
@@ -33,12 +33,14 @@ class MetricCountSeqAndTokens(MetricPerBatchDefault):
     def __init__(
         self,
         encoder_input: str,
+        decoder_input: Optional[str] = None,
         ignore_index: Optional[int] = None,
         state: Optional[dict] = None,
         **kwargs: dict,
     ) -> None:
         """
         :param encoder_input: key to the encoder_input
+        :param decoder_input: key to the decoder_input
         :param ignore_index: token_id to ignore (not to count), typically the pad token id
         :param state: the sequence count and token count to continue from. Should be restored when we continue training.
           use get_state() to get the state and save it upon checkpointing,
@@ -52,6 +54,7 @@ def __init__(
                 _count_seq_and_tokens_update,
                 ignore_index=ignore_index,
                 encoder_input_key=encoder_input,
+                decoder_input_key=decoder_input,
             ),
             result_aggregate_func=self._count_seq_and_tokens_compute,
             **kwargs,
         )
@@ -82,12 +85,15 @@ def get_state(self) -> dict:
 def _count_seq_and_tokens_update(
     batch_dict: dict,
     encoder_input_key: str,
+    decoder_input_key: Optional[str] = None,
     ignore_index: Optional[int] = None,
 ) -> Dict[str, torch.Tensor]:
     """Count number of sequences and tokens
 
     Args:
         encoder_input_key: key to encoder_input
+        decoder_input_key:
+            key to decoder_input
         ignore_index: Token not to count, typically padding
 
     Returns:
@@ -98,19 +104,29 @@
     # to save GPU memory
     encoder_input = encoder_input.detach()
 
-    if ignore_index is not None:
-        mask = encoder_input.ne(ignore_index)
-    else:
-        mask = torch.ones_like(encoder_input, dtype=torch.bool)
+    mask = get_mask(encoder_input, ignore_index)
     seq_num = torch.tensor(
         mask.shape[0], dtype=torch.int64, device=encoder_input.device
     )
     token_num = mask.sum().to(dtype=torch.int64)
+    if decoder_input_key is not None and decoder_input_key in batch_dict:
+        decoder_input = batch_dict[decoder_input_key].detach()
+        mask2 = get_mask(decoder_input, ignore_index)
+        assert mask2.shape[0] == mask.shape[0]
+        token_num += mask2.sum().to(dtype=torch.int64)
     return {"seq_num": seq_num.unsqueeze(0), "token_num": token_num.unsqueeze(0)}
 
 
+def get_mask(input: torch.Tensor, ignore_index: Optional[int]) -> torch.Tensor:
+    if ignore_index is not None:
+        mask = input.ne(ignore_index)
+    else:
+        mask = torch.ones_like(input, dtype=torch.bool)
+    return mask
+
+
 class MetricPerplexity(MetricPerBatchDefault):
     def __init__(
         self,
@@ -137,6 +153,7 @@ def __init__(
 # Copied internal function https://github.com/Lightning-AI/metrics/blob/825d17f32ee0b9a2a8024c89d4a09863d7eb45c3/src/torchmetrics/functional/text/perplexity.py#L68
 # copied and not imported so as not to be affected by internal interface modifications.
+# modifications: (1) apply the mask at the beginning of the computation (2) use torch.gather
 def _perplexity_update(
     batch_dict: dict,
     preds_key: str,
@@ -177,24 +194,25 @@
     preds = preds.detach()
     target = target.detach()
 
+    # reshape attempts to create a view, and COPIES the data if that fails.
+    # an issue to consider: use view and alert the user if it fails?
     preds = preds.reshape(-1, preds.shape[-1])
     target = target.reshape(-1)
 
     if ignore_index is not None:
         mask = target.ne(ignore_index)
-        target = target.where(
-            target != ignore_index, torch.tensor(0, device=target.device)
-        )
+        target = target[mask]
+        preds = preds[mask]
+        count = mask.sum()
     else:
-        mask = torch.ones_like(target, dtype=torch.bool)
+        count = torch.tensor(target.shape[0])
 
-    preds = preds[:, target].diagonal()[mask]
+    preds = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)
     # avoid overflow
     if preds.dtype == torch.float16:
         preds = preds.to(torch.float32)
     preds = torch.clamp(preds, min=1e-10)
     total_log_probs = -preds.log().sum()
-    count = mask.sum()
 
     return {"log_probs": total_log_probs.unsqueeze(0), "token_num": count.unsqueeze(0)}

diff --git a/fuse/utils/misc/misc.py b/fuse/utils/misc/misc.py
index fa89c3b9..f0fbf614 100644
--- a/fuse/utils/misc/misc.py
+++ b/fuse/utils/misc/misc.py
@@ -158,22 +158,22 @@ def time_display(seconds: int, granularity: int = 3) -> str:
 def get_pretty_dataframe(df: pd.DataFrame, col_width: int = 25) -> str:
     # check if col_width needs to be widened (so that the dashes are in one line)
-    max_val_width = np.vectorize(len)(
-        df.values.astype(str)
-    ).max()  # get maximum length of all values
-    max_col_width = max(
-        [len(x) for x in df.columns]
-    )  # get the maximum lengths of all column names
-    col_width = max(max_col_width, max_val_width, col_width)
-
-    dashes = (col_width + 2) * len(df.columns.values)
+    max_val_width_per_col = np.vectorize(len)(df.values.astype(str)).max(
+        axis=0
+    )  # get the maximum value length per column
+    col_name_widths = [len(x) for x in df.columns]  # get column name lengths
+    col_widths = [
+        max(max(x), col_width) for x in zip(max_val_width_per_col, col_name_widths)
+    ]
+
+    dashes = sum(col_widths) + 2 * len(df.columns.values)
     df_as_string = f"\n{'-' * dashes}\n"
-    for col in df.columns.values:
-        df_as_string += f"| {col:{col_width}}"
+    for col, col_w in zip(df.columns.values, col_widths):
+        df_as_string += f"| {col:{col_w}}"
     df_as_string += f"|\n{'-' * dashes}\n"
     for idx, row in df.iterrows():
-        for col in df.columns.values:
-            df_as_string += f"| {row[col]:{col_width}}"
+        for i_col, col in enumerate(df.columns.values):
+            df_as_string += f"| {row[col]:{col_widths[i_col]}}"
         df_as_string += f"|\n{'-' * dashes}\n"
     return df_as_string
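
A quick sanity sketch of the new decoder_input counting path (not part of the
patch): the key names, token ids, and pad id below are invented, and it assumes
_count_seq_and_tokens_update looks up both keys in batch_dict as the hunks
above suggest.

    import torch

    from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import (
        _count_seq_and_tokens_update,
    )

    # Invented toy batch: two sequences, padded with pad id 0.
    batch_dict = {
        "encoder_input": torch.tensor([[5, 7, 0, 0], [9, 2, 3, 0]]),  # 5 real tokens
        "decoder_input": torch.tensor([[4, 0], [6, 8]]),  # 3 real tokens
    }

    result = _count_seq_and_tokens_update(
        batch_dict,
        encoder_input_key="encoder_input",
        decoder_input_key="decoder_input",
        ignore_index=0,
    )
    # Expect 2 sequences and 5 + 3 = 8 non-pad tokens:
    # {"seq_num": tensor([2]), "token_num": tensor([8])}
    print(result)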
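
On the GPU-memory note in the commit message, a minimal self-contained
comparison of the three ways to pick one log-probability per target token
(sizes and values invented for illustration). preds[:, target] materializes an
[n, n] matrix before the diagonal is taken, while arange indexing and
torch.gather stay O(n):

    import torch

    n, vocab = 4, 10  # invented sizes
    preds = torch.log_softmax(torch.randn(n, vocab), dim=-1)  # [n, vocab] log-probs
    target = torch.randint(vocab, (n,))  # [n] target token ids

    # Old approach: builds an intermediate [n, n] matrix, O(n^2) memory.
    picked_old = preds[:, target].diagonal()

    # Alternative mentioned in the commit message: O(n) memory.
    picked_arange = preds[torch.arange(n), target]

    # Approach taken by the patch: also O(n) memory.
    picked_gather = torch.gather(preds, 1, target.view(-1, 1)).squeeze(1)

    assert torch.equal(picked_old, picked_arange)
    assert torch.equal(picked_old, picked_gather)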
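
And a small usage sketch of the reworked get_pretty_dataframe (the metric
names and values are invented): each column is now padded to its own width,
floored at col_width, instead of every column inheriting the single widest
cell in the table.

    import pandas as pd

    from fuse.utils.misc.misc import get_pretty_dataframe

    # Invented epoch-stats table.
    df = pd.DataFrame(
        {
            "metric": ["train.losses.total_loss", "validation.losses.total_loss"],
            "value": [0.4213, 0.3871],
        }
    )

    # "metric" is padded to 28 chars (its widest value); the narrow "value"
    # column only to the col_width floor of 10, not to 28 as before the patch.
    print(get_pretty_dataframe(df, col_width=10))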