From 6a43eecd6599dee6e2e9c8f27b94c4588f3cf47d Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Thu, 12 Dec 2024 11:02:38 +0800 Subject: [PATCH] dialog sent intensity --- data_juicer/ops/mapper/__init__.py | 7 +- data_juicer/ops/mapper/calibrate_qa_mapper.py | 2 + .../dialog_sentiment_intensity_mapper.py | 199 ++++++++++++++++++ data_juicer/utils/constant.py | 6 + .../test_dialog_sentiment_intensity_mapper.py | 64 ++++++ .../test_extract_entity_attribute_mapper.py | 2 +- .../test_extract_entity_relation_mapper.py | 2 +- tests/ops/mapper/test_extract_event_mapper.py | 2 +- .../ops/mapper/test_extract_keyword_mapper.py | 2 +- .../mapper/test_extract_nickname_mapper.py | 2 +- .../test_extract_support_text_mapper.py | 2 +- .../mapper/test_relation_identity_mapper.py | 2 +- 12 files changed, 282 insertions(+), 10 deletions(-) create mode 100644 data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py create mode 100644 tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5a740d192..a994a4ba4 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -8,6 +8,7 @@ from .clean_html_mapper import CleanHtmlMapper from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper +from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper from .expand_macro_mapper import ExpandMacroMapper from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper from .extract_entity_relation_mapper import ExtractEntityRelationMapper @@ -70,9 +71,9 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', - 'ExtractEntityRelationMapper', 'ExtractEventMapper', - 'ExtractKeywordMapper', 'ExtractNicknameMapper', + 'DialogSentimentIntensityMapper', 'ExpandMacroMapper', + 'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper', + 'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'ExtractSupportTextMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 8480ee899..bf9686409 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++ b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -55,6 +55,8 @@ def __init__(self, :param reference_template: Template for formatting the reference text. :param qa_pair_template: Template for formatting question-answer pairs. :param output_pattern: Regular expression for parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py new file mode 100644 index 000000000..899bc2614 --- /dev/null +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -0,0 +1,199 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'dialog_sentiment_intensity_mapper' + + +# TODO: LLM-based inference. +@OPERATORS.register_module(OP_NAME) +class DialogSentimentIntensityMapper(Mapper): + """ + Mapper to predict user's sentiment intensity in dialog which is stored + in the history_key. + """ + + DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' + '要求:\n' + '- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,' + '-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n' + '- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n' + '用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n' + '情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n' + '情绪值:0\n' + 'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自' + '身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合' + '理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n' + '用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n' + '情绪分析:对回答感到满意,情绪正面。\n' + '情绪值:1\n' + 'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用' + '水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和' + '多用电子账单也是不错的选择。\n' + '用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促' + '进作用吗?\n' + '情绪分析:觉得回答实用且具体,情绪进一步转好。\n' + '情绪值:2\n' + 'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的' + '故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n' + '用户:什么吗,根本是答非所问。\n' + '情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n' + '情绪值:-2\n' + 'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加' + '环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n' + '用户:还行吧,算你能够掰回来。\n' + '情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n' + '情绪值:-1\n') + DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' + DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n' + DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n' + DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n' + DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)' + + def __init__(self, + api_model: str = 'gpt-4o', + max_round: PositiveInt = 10, + intensity_key: str = MetaKeys.sentiment_intensity, + analysis_key: str = MetaKeys.sentiment_analysis, + *, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + intensity_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + intensity_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + + :param api_model: API model name. + :param max_round: The max num of round in the dialog to build the + prompt. + :param intensity_key: The output (nested) key of the sentiment + intensity. Defaults to '__dj__meta.sentiment.intensity'. + :param analysis_key: The output (nested) key of the sentiment + analysis. Defaults to '__dj__meta.sentiment.analysis'. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param query_template: Template for query part to build the input + prompt. + :param response_template: Template for response part to build the + input prompt. + :param analysis_template: Template for analysis part to build the + input prompt. + :param intensity_template: Template for intensity part to build the + input prompt. + :param analysis_pattern: Pattern to parse the return sentiment + analysis. + :param intensity_pattern: Pattern to parse the return sentiment + intensity. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.max_round = max_round + self.intensity_key = intensity_key + self.analysis_key = analysis_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE + self.response_template = response_template or \ + self.DEFAULT_RESPONSE_TEMPLATE + self.analysis_template = analysis_template or \ + self.DEFAULT_ANALYSIS_TEMPLATE + self.intensity_template = intensity_template or \ + self.DEFAULT_INTENSITY_TEMPLATE + self.analysis_pattern = analysis_pattern or \ + self.DEFAULT_ANALYSIS_PATTERN + self.intensity_pattern = intensity_pattern or \ + self.DEFAULT_INTENSITY_PATTERN + + self.sampling_params = sampling_params + + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + + def build_input(self, history, query): + input_prompt = ''.join(history[-self.max_round * 4:]) + input_prompt += self.query_template.format(query=query[0]) + + return input_prompt + + def parse_output(self, response): + analysis = '' + intensity = 0 + + match = re.search(self.analysis_pattern, response) + if match: + analysis = match.group(1) + + match = re.search(self.intensity_pattern, response) + if match: + intensity = int(match.group(1)) + + return analysis, intensity + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + analysis_list = [] + intensities = [] + history = [] + + for qa in sample[self.history_key]: + input_prompt = self.build_input(history, qa[0]) + messages = [{ + 'role': 'system', + 'content': self.system_prompt, + }, { + 'role': 'user', + 'content': input_prompt, + }] + + for _ in range(self.try_num): + try: + response = client(messages, **self.sampling_params) + analysis, intensity = self.parse_output(response) + if len(analysis) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + analysis_list.append(analysis) + intensities.append(intensity) + + history.append(self.query_template.format(query=qa[0])) + history.append(self.analysis_template.format(analysis=analysis)) + history.append(self.intensity_template.format(intensity=intensity)) + history.append(self.response_template.format(response=qa[1])) + + sample = nested_set(sample, self.analysis_key, analysis_list) + sample = nested_set(sample, self.intensity_key, intensities) + + return sample diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 350181f41..219ba68c3 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -10,6 +10,12 @@ DEFAULT_PREFIX = '__dj__' +class MetaKeys(object): + + sentiment_intensity = DEFAULT_PREFIX + 'meta.sentiment.intensity' + sentiment_analysis = DEFAULT_PREFIX + 'meta.sentiment.analysis' + + class Fields(object): stats = DEFAULT_PREFIX + 'stats__' meta = DEFAULT_PREFIX + 'meta__' diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py new file mode 100644 index 000000000..8d37c974b --- /dev/null +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -0,0 +1,64 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.common_utils import nested_access + +# Skip tests for this OP. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): + + + def _run_op(self, op): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + analysis_list = nested_access(dataset, MetaKeys.sentiment_analysis) + intensity_list = nested_access(dataset, MetaKeys.sentiment_intensity) + + for analysis, intensity in zip(analysis_list, intensity_list): + logger.info(f'分析:{analysis}') + logger.info(f'情绪:{intensity}') + + def default_test(self): + # before runing this test, set below environment variables: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + def max_round_test(self): + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index f15b4ca3f..a2c156d48 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py index 40e3ca32d..0aed4fcee 100644 --- a/tests/ops/mapper/test_extract_entity_relation_mapper.py +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index aba40d73e..e936cb06c 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractEventMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py index 5836f902a..2501a46ca 100644 --- a/tests/ops/mapper/test_extract_keyword_mapper.py +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractKeywordMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 2911a1002..457a7d53b 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractNicknameMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py index 0445d2526..080dfd672 100644 --- a/tests/ops/mapper/test_extract_support_text_mapper.py +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -10,7 +10,7 @@ from data_juicer.utils.constant import Fields from data_juicer.utils.common_utils import nested_access -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py index d730cb79f..231b20ba1 100644 --- a/tests/ops/mapper/test_relation_identity_mapper.py +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -9,7 +9,7 @@ DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields -# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# Skip tests for this OP. # These tests have been tested locally. @SKIPPED_TESTS.register_module() class RelationIdentityMapperTest(DataJuicerTestCaseBase):