From 937b3f1cf5a9b9b3294920245bbaff4549f3d55a Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 14:17:54 +0800 Subject: [PATCH] query intent detection --- data_juicer/ops/mapper/__init__.py | 29 +++--- .../mapper/dialog_intent_detection_mapper.py | 1 - .../mapper/query_intent_detection_mapper.py | 98 +++++++++++++++++++ .../test_query_intent_detection_mapper.py | 61 ++++++++++++ 4 files changed, 174 insertions(+), 15 deletions(-) create mode 100644 data_juicer/ops/mapper/query_intent_detection_mapper.py create mode 100644 tests/ops/mapper/test_query_intent_detection_mapper.py diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 931cd7f2..af710bd8 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -36,6 +36,7 @@ from .punctuation_normalization_mapper import PunctuationNormalizationMapper from .python_file_mapper import PythonFileMapper from .python_lambda_mapper import PythonLambdaMapper +from .query_intent_detection_mapper import QueryIntentDetectionMapper from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper from .relation_identity_mapper import RelationIdentityMapper from .remove_bibliography_mapper import RemoveBibliographyMapper @@ -87,18 +88,18 @@ 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PairPreferenceMapper', 'PunctuationNormalizationMapper', 'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper', - 'RelationIdentityMapper', 'RemoveBibliographyMapper', - 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', - 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', - 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', - 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', - 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper', - 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', - 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', - 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', - 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', - 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', - 'WhitespaceNormalizationMapper' + 'QueryIntentDetectionMapper', 'RelationIdentityMapper', + 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', + 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', + 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', + 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', + 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', + 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', + 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', + 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', + 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', + 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', + 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', + 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', + 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index b32fd929..759e291b 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -28,7 +28,6 @@ class DialogIntentDetectionMapper(Mapper): '要求:\n' '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出' '。\n' - # '备选意图类别:[信息查找, 请求建议, 其他]\n' '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n' '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n' '意图类别:信息查找\n' diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py new file mode 100644 index 00000000..66290532 --- /dev/null +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -0,0 +1,98 @@ +from typing import Dict, Optional + +from data_juicer.utils.common_utils import nested_set +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'query_intent_detection_mapper' + + +@OPERATORS.register_module(OP_NAME) +class QueryIntentDetectionMapper(Mapper): + """ + Mapper to predict user's Intent label in query. Input from query_key. + Output intensity label and corresponding score for the query, which is + store in 'intent.query_label' and 'intent.query_label_score' in + Data-Juicer meta field. + """ + + _accelerator = 'cuda' + _batched_op = True + + DEFAULT_LABEL_TO_INTENSITY = {} + + def __init__( + self, + hf_model: str = 'Falconsai/intent_classification', + zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', + model_params: Dict = {}, + zh_to_en_model_params: Dict = {}, + *, + label_to_intensity: Dict = None, + **kwargs): + """ + Initialization method. + + :param hf_model: Hugginface model ID to predict sentiment intensity. + :param zh_to_en_hf_model: Translation model from Chinese to English. + If not None, translate the query from Chinese to English. + :param model_params: model param for hf_model. + :param zh_to_en_model_params: model param for zh_to_hf_model. + :param label_to_intensity: Map the output labels to the intensities + instead of the default mapper if not None. + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_model, + return_pipe=True, + pipe_task='text-classification', + **model_params) + + if zh_to_en_hf_model is not None: + self.zh_to_en_model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=zh_to_en_hf_model, + return_pipe=True, + pipe_task='translation', + **zh_to_en_model_params) + else: + self.zh_to_en_model_key = None + + if label_to_intensity is not None: + self.label_to_intensity = label_to_intensity + else: + self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY + + def process_batched(self, samples, rank=None): + queries = samples[self.query_key] + + if self.zh_to_en_model_key is not None: + translater, _ = get_model(self.zh_to_en_model_key, rank, + self.use_cuda()) + results = translater(queries) + queries = [item['translation_text'] for item in results] + + classifier, _ = get_model(self.model_key, rank, self.use_cuda()) + results = classifier(queries) + intensities = [ + self.label_to_intensity[r['label']] + if r['label'] in self.label_to_intensity else r['label'] + for r in results + ] + scores = [r['score'] for r in results] + + if Fields.meta not in samples: + samples[Fields.meta] = [{} for val in intensities] + for i in range(len(samples[Fields.meta])): + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_label, + intensities[i]) + samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], + MetaKeys.query_intent_score, + scores[i]) + + return samples diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py new file mode 100644 index 00000000..6f8494dd --- /dev/null +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -0,0 +1,61 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.common_utils import nested_access + +class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): + + hf_model = 'Falconsai/intent_classification' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' + + def _run_op(self, op, samples, intensity_key, targets): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for sample, target in zip(dataset, targets): + intensity = nested_access(sample[Fields.meta], intensity_key) + self.assertEqual(intensity, target) + + def test_default(self): + + samples = [{ + 'query': '我要一个汉堡。' + },{ + 'query': '你最近过得怎么样?' + },{ + 'query': '它是正方形的。' + } + ] + targets = [1, 0, -1] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + + def test_no_zh_to_en(self): + + samples = [{ + 'query': '它是正方形的。' + },{ + 'query': 'It is square.' + } + ] + targets = [0, 1] + + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = None, + ) + self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + +if __name__ == '__main__': + unittest.main()