class AttentionContext(ManagerMixin):
    """Globally accessible store for variable-length attention metadata.

    Holds the cumulative sequence-length offsets and the maximum
    sub-sequence length of a packed batch, so attention implementations
    (e.g. varlen flash-attention kernels) can fetch them without the
    values being threaded through every call site.

    NOTE(review): assumes callers obtain a shared instance via
    ``ManagerMixin``'s ``get_instance(name)`` registry — confirm against
    call sites outside this file.
    """

    def __init__(self, name: str = '', **kwargs) -> None:
        # Bug fix: the original override never called
        # ``super().__init__()``.  ``ManagerMixin`` keys instances by
        # ``name`` in its registry; skipping its initializer leaves the
        # instance unregistered and breaks ``get_instance`` lookups.
        super().__init__(name, **kwargs)
        self._cumulative_len = None
        self._max_seqlen = None

    def update(self, seqlen_list):
        """Record metadata for a new packed batch.

        Args:
            seqlen_list: lengths of each sub-sequence packed into the
                current batch.  Produces ``cumulative_len`` =
                ``[0, s0, s0+s1, ...]`` and ``max_seqlen`` = ``max(seqlens)``
                (0 for an empty list).
        """
        cumulative_len = [0]
        max_seqlen = 0
        for seqlen in seqlen_list:
            cumulative_len.append(cumulative_len[-1] + seqlen)
            max_seqlen = max(max_seqlen, seqlen)
        self._cumulative_len = cumulative_len
        self._max_seqlen = max_seqlen

    def clear(self):
        """Reset stored metadata between batches."""
        self._cumulative_len = None
        self._max_seqlen = None

    @property
    def cumulative_len(self):
        # ``None`` until ``update`` has been called (or after ``clear``).
        return self._cumulative_len

    @property
    def max_seqlen(self):
        # ``None`` until ``update`` has been called (or after ``clear``).
        return self._max_seqlen
class MessageHub(ManagerMixin):
    """Named, globally reachable container for packed-batch attention
    metadata: cumulative sequence-length offsets and the longest
    sub-sequence length."""

    def __init__(self, name: str = '', **kwargs):
        super().__init__(name, **kwargs)
        self._cumulative_len = None
        self._max_seqlen = None

    def update(self, seqlen_list):
        """Recompute offsets and the max length from ``seqlen_list``.

        ``cumulative_len`` becomes ``[0, s0, s0+s1, ...]``;
        ``max_seqlen`` is 0 when the list is empty.
        """
        # Materialize once so a one-shot iterable is handled the same
        # way a list would be.
        lengths = list(seqlen_list)
        offsets = [0]
        for length in lengths:
            offsets.append(offsets[-1] + length)
        self._cumulative_len = offsets
        self._max_seqlen = max(lengths, default=0)

    def clear(self):
        """Drop any previously recorded batch metadata."""
        self._max_seqlen = None
        self._cumulative_len = None

    @property
    def cumulative_len(self):
        # ``None`` before the first ``update`` and after ``clear``.
        return self._cumulative_len

    @property
    def max_seqlen(self):
        # ``None`` before the first ``update`` and after ``clear``.
        return self._max_seqlen


class AttentionContext:
    """Static facade over the shared ``'attention_context'`` hub, so
    callers never need to fetch the singleton themselves."""

    # Resolved once at import time; every classmethod below delegates
    # to this single shared instance.
    message_hub = MessageHub.get_instance('attention_context')

    @classmethod
    def update(cls, seqlen_list):
        cls.message_hub.update(seqlen_list)

    @classmethod
    def clear(cls):
        cls.message_hub.clear()

    @classmethod
    def get_max_seqlen(cls):
        return cls.message_hub.max_seqlen

    @classmethod
    def get_cumulative_len(cls):
        return cls.message_hub.cumulative_len