Skip to content

Commit

Permalink
dialog sent intensity
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 12, 2024
1 parent 788a212 commit 6a43eec
Show file tree
Hide file tree
Showing 12 changed files with 282 additions and 10 deletions.
7 changes: 4 additions & 3 deletions data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .clean_html_mapper import CleanHtmlMapper
from .clean_ip_mapper import CleanIpMapper
from .clean_links_mapper import CleanLinksMapper
from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper
from .expand_macro_mapper import ExpandMacroMapper
from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper
from .extract_entity_relation_mapper import ExtractEntityRelationMapper
Expand Down Expand Up @@ -70,9 +71,9 @@
'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper',
'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper',
'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper',
'ExpandMacroMapper', 'ExtractEntityAttributeMapper',
'ExtractEntityRelationMapper', 'ExtractEventMapper',
'ExtractKeywordMapper', 'ExtractNicknameMapper',
'DialogSentimentIntensityMapper', 'ExpandMacroMapper',
'ExtractEntityAttributeMapper', 'ExtractEntityRelationMapper',
'ExtractEventMapper', 'ExtractKeywordMapper', 'ExtractNicknameMapper',
'ExtractSupportTextMapper', 'FixUnicodeMapper',
'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper',
'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper',
Expand Down
2 changes: 2 additions & 0 deletions data_juicer/ops/mapper/calibrate_qa_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(self,
:param reference_template: Template for formatting the reference text.
:param qa_pair_template: Template for formatting question-answer pairs.
:param output_pattern: Regular expression for parsing model output.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
Expand Down
199 changes: 199 additions & 0 deletions data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import re
from typing import Dict, Optional

from loguru import logger
from pydantic import PositiveInt

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.common_utils import nested_set
from data_juicer.utils.constant import MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

OP_NAME = 'dialog_sentiment_intensity_mapper'


# TODO: LLM-based inference.
@OPERATORS.register_module(OP_NAME)
class DialogSentimentIntensityMapper(Mapper):
"""
Mapper to predict user's sentiment intensity in dialog which is stored
in the history_key.
"""

DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n'
'要求:\n'
'- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,'
'-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n'
'- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n'
'用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n'
'情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n'
'情绪值:0\n'
'LLM:当然可以!可持续发展是指在满足当代人的需求的同时,不损害子孙后代满足其自'
'身需求的能力的发展模式。它包括经济发展、社会发展和环境保护三个主要方面。通过合'
'理利用资源和保护环境,我们可以确保未来的世代也能享有健全的生态系统和经济制度。\n'
'用户:谢谢你的解释!那你能告诉我一些普通人可以采取的可持续生活方式吗?\n'
'情绪分析:对回答感到满意,情绪正面。\n'
'情绪值:1\n'
'LLM:当然可以,普通人可以通过减少一次性产品的使用、选择公共交通或拼车、节约用'
'水、以及支持本地和可持续发展的企业等方式来践行可持续生活。此外,关注垃圾分类和'
'多用电子账单也是不错的选择。\n'
'用户:你提到支持本地企业,这一点我很感兴趣。能详细说说为什么这对可持续发展有促'
'进作用吗?\n'
'情绪分析:觉得回答实用且具体,情绪进一步转好。\n'
'情绪值:2\n'
'LLM:呃,我最近发现了一部新电影,讲述了一个关于外星人和地球土著合作保护环境的'
'故事。虽然它是科幻片,但很有启发性,推荐你去看看。\n'
'用户:什么吗,根本是答非所问。\n'
'情绪分析:LLM没有回应问题而是提到无关内容,导致用户情绪直线下降。\n'
'情绪值:-2\n'
'LLM:抱歉刚才的偏题!支持本地企业有助于减少长途运输产生的碳足迹,使供应链更加'
'环保。此外,本地企业也更有可能采用可持续的生产方式,同时促进社区经济的繁荣。\n'
'用户:还行吧,算你能够掰回来。\n'
'情绪分析:问题得到解答,问题偏题得到纠正,情绪稍有好转。\n'
'情绪值:-1\n')
DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
DEFAULT_ANALYSIS_TEMPLATE = '情绪分析:{analysis}\n'
DEFAULT_INTENSITY_TEMPLATE = '情绪值:{intensity}\n'
DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n'
DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)'

def __init__(self,
api_model: str = 'gpt-4o',
max_round: PositiveInt = 10,
intensity_key: str = MetaKeys.sentiment_intensity,
analysis_key: str = MetaKeys.sentiment_analysis,
*,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
query_template: Optional[str] = None,
response_template: Optional[str] = None,
analysis_template: Optional[str] = None,
intensity_template: Optional[str] = None,
analysis_pattern: Optional[str] = None,
intensity_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param max_round: The max num of round in the dialog to build the
prompt.
:param intensity_key: The output (nested) key of the sentiment
intensity. Defaults to '__dj__meta.sentiment.intensity'.
:param analysis_key: The output (nested) key of the sentiment
analysis. Defaults to '__dj__meta.sentiment.analysis'.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for the task.
:param query_template: Template for query part to build the input
prompt.
:param response_template: Template for response part to build the
input prompt.
:param analysis_template: Template for analysis part to build the
input prompt.
:param intensity_template: Template for intensity part to build the
input prompt.
:param analysis_pattern: Pattern to parse the return sentiment
analysis.
:param intensity_pattern: Pattern to parse the return sentiment
intensity.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)

self.max_round = max_round
self.intensity_key = intensity_key
self.analysis_key = analysis_key

self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
self.response_template = response_template or \
self.DEFAULT_RESPONSE_TEMPLATE
self.analysis_template = analysis_template or \
self.DEFAULT_ANALYSIS_TEMPLATE
self.intensity_template = intensity_template or \
self.DEFAULT_INTENSITY_TEMPLATE
self.analysis_pattern = analysis_pattern or \
self.DEFAULT_ANALYSIS_PATTERN
self.intensity_pattern = intensity_pattern or \
self.DEFAULT_INTENSITY_PATTERN

self.sampling_params = sampling_params

self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)

self.try_num = try_num

def build_input(self, history, query):
input_prompt = ''.join(history[-self.max_round * 4:])
input_prompt += self.query_template.format(query=query[0])

return input_prompt

def parse_output(self, response):
analysis = ''
intensity = 0

match = re.search(self.analysis_pattern, response)
if match:
analysis = match.group(1)

match = re.search(self.intensity_pattern, response)
if match:
intensity = int(match.group(1))

return analysis, intensity

def process_single(self, sample, rank=None):
client = get_model(self.model_key, rank=rank)

analysis_list = []
intensities = []
history = []

for qa in sample[self.history_key]:
input_prompt = self.build_input(history, qa[0])
messages = [{
'role': 'system',
'content': self.system_prompt,
}, {
'role': 'user',
'content': input_prompt,
}]

for _ in range(self.try_num):
try:
response = client(messages, **self.sampling_params)
analysis, intensity = self.parse_output(response)
if len(analysis) > 0:
break
except Exception as e:
logger.warning(f'Exception: {e}')

analysis_list.append(analysis)
intensities.append(intensity)

history.append(self.query_template.format(query=qa[0]))
history.append(self.analysis_template.format(analysis=analysis))
history.append(self.intensity_template.format(intensity=intensity))
history.append(self.response_template.format(response=qa[1]))

sample = nested_set(sample, self.analysis_key, analysis_list)
sample = nested_set(sample, self.intensity_key, intensities)

return sample
6 changes: 6 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
DEFAULT_PREFIX = '__dj__'


class MetaKeys(object):

sentiment_intensity = DEFAULT_PREFIX + 'meta.sentiment.intensity'
sentiment_analysis = DEFAULT_PREFIX + 'meta.sentiment.analysis'


class Fields(object):
stats = DEFAULT_PREFIX + 'stats__'
meta = DEFAULT_PREFIX + 'meta__'
Expand Down
64 changes: 64 additions & 0 deletions tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
import json

from loguru import logger

from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)
from data_juicer.utils.constant import MetaKeys
from data_juicer.utils.common_utils import nested_access

# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase):


def _run_op(self, op):

samples = [{
'history': [
(
'李莲花有口皆碑',
'「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。'
),
(
'是的,你确实是一个普通大夫,没什么值得夸耀的。',
'「委屈」你这话说的,我也是尽心尽力治病救人了。'
),
(
'你自己说的呀,我现在说了,你又不高兴了。',
'or of of of of or or and or of of of of of of of,,, '
),
(
'你在说什么我听不懂。',
'「委屈」我也没说什么呀,就是觉得你有点冤枉我了'
)
]
}]

dataset = Dataset.from_list(samples)
dataset = dataset.map(op.process, batch_size=2)
analysis_list = nested_access(dataset, MetaKeys.sentiment_analysis)
intensity_list = nested_access(dataset, MetaKeys.sentiment_intensity)

for analysis, intensity in zip(analysis_list, intensity_list):
logger.info(f'分析:{analysis}')
logger.info(f'情绪:{intensity}')

def default_test(self):
# before runing this test, set below environment variables:
# export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
# export OPENAI_API_KEY=your_key
op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct')
self._run_op(op)

def max_round_test(self):
op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct')
self._run_op(op)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_entity_attribute_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_entity_relation_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_event_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractEventMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_keyword_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractKeywordMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_nickname_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractNicknameMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_extract_support_text_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from data_juicer.utils.constant import Fields
from data_juicer.utils.common_utils import nested_access

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class ExtractSupportTextMapperTest(DataJuicerTestCaseBase):
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/mapper/test_relation_identity_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields

# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError.
# Skip tests for this OP.
# These tests have been tested locally.
@SKIPPED_TESTS.register_module()
class RelationIdentityMapperTest(DataJuicerTestCaseBase):
Expand Down

0 comments on commit 6a43eec

Please sign in to comment.