diff --git a/xtuner/utils/__init__.py b/xtuner/utils/__init__.py
index 6bc9a1173..fda80ac1c 100644
--- a/xtuner/utils/__init__.py
+++ b/xtuner/utils/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .attention_context import AttentionContext
 from .constants import (DEFAULT_IMAGE_TOKEN, DEFAULT_PAD_TOKEN_INDEX,
                         IGNORE_INDEX, IMAGE_TOKEN_INDEX)
 from .stop_criteria import StopWordStoppingCriteria
@@ -7,5 +8,5 @@
 __all__ = [
     'IGNORE_INDEX', 'DEFAULT_PAD_TOKEN_INDEX', 'PROMPT_TEMPLATE',
     'DEFAULT_IMAGE_TOKEN', 'SYSTEM_TEMPLATE', 'StopWordStoppingCriteria',
-    'IMAGE_TOKEN_INDEX'
+    'IMAGE_TOKEN_INDEX', 'AttentionContext'
 ]
diff --git a/xtuner/utils/attention_context.py b/xtuner/utils/attention_context.py
new file mode 100644
index 000000000..d6057753c
--- /dev/null
+++ b/xtuner/utils/attention_context.py
@@ -0,0 +1,51 @@
+from mmengine.utils import ManagerMixin
+
+
+class MessageHub(ManagerMixin):
+
+    def __init__(self, name: str = '', **kwargs):
+        super().__init__(name, **kwargs)
+        self._cumulative_len = None
+        self._max_seqlen = None
+
+    def update(self, seqlen_list):
+        cumulative_len = [0]
+        max_seqlen = 0
+        for seqlen in seqlen_list:
+            cumulative_len.append(cumulative_len[-1] + seqlen)
+            max_seqlen = max(max_seqlen, seqlen)
+        self._cumulative_len = cumulative_len
+        self._max_seqlen = max_seqlen
+
+    def clear(self):
+        self._cumulative_len = None
+        self._max_seqlen = None
+
+    @property
+    def cumulative_len(self):
+        return self._cumulative_len
+
+    @property
+    def max_seqlen(self):
+        return self._max_seqlen
+
+
+class AttentionContext:
+
+    message_hub = MessageHub.get_instance('attention_context')
+
+    @classmethod
+    def update(cls, seqlen_list):
+        cls.message_hub.update(seqlen_list)
+
+    @classmethod
+    def clear(cls):
+        cls.message_hub.clear()
+
+    @classmethod
+    def get_max_seqlen(cls):
+        return cls.message_hub.max_seqlen
+
+    @classmethod
+    def get_cumulative_len(cls):
+        return cls.message_hub.cumulative_len
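Usage sketch (not part of the patch): a minimal illustration of how the new AttentionContext could be fed per-batch sequence lengths before a variable-length attention forward pass. The example lengths and the call site are assumptions for illustration, not code from this diff.

    from xtuner.utils import AttentionContext

    # Hypothetical packed batch: three sequences of lengths 3, 5 and 2.
    AttentionContext.update([3, 5, 2])

    # The attention implementation can then read the packed metadata.
    cumulative_len = AttentionContext.get_cumulative_len()  # [0, 3, 8, 10]
    max_seqlen = AttentionContext.get_max_seqlen()          # 5

    # ... run the varlen attention forward using cumulative_len / max_seqlen ...

    # Reset the shared state so the next batch does not see stale lengths.
    AttentionContext.clear()

Because MessageHub.get_instance('attention_context') returns a process-wide singleton, update() and clear() make the packed-sequence metadata available to attention modules without threading it through every forward() signature.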