Skip to content

Commit

Permalink
query intent detection
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 17, 2024
1 parent 4a3ad39 commit 937b3f1
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 15 deletions.
29 changes: 15 additions & 14 deletions data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .punctuation_normalization_mapper import PunctuationNormalizationMapper
from .python_file_mapper import PythonFileMapper
from .python_lambda_mapper import PythonLambdaMapper
from .query_intent_detection_mapper import QueryIntentDetectionMapper
from .query_sentiment_detection_mapper import QuerySentimentDetectionMapper
from .relation_identity_mapper import RelationIdentityMapper
from .remove_bibliography_mapper import RemoveBibliographyMapper
Expand Down Expand Up @@ -87,18 +88,18 @@
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper',
'RelationIdentityMapper', 'RemoveBibliographyMapper',
'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper',
'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
'WhitespaceNormalizationMapper'
'QueryIntentDetectionMapper', 'RelationIdentityMapper',
'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
]
1 change: 0 additions & 1 deletion data_juicer/ops/mapper/dialog_intent_detection_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class DialogIntentDetectionMapper(Mapper):
'要求:\n'
'- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出'
'。\n'
# '备选意图类别:[信息查找, 请求建议, 其他]\n'
'用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n'
'意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n'
'意图类别:信息查找\n'
Expand Down
98 changes: 98 additions & 0 deletions data_juicer/ops/mapper/query_intent_detection_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Dict, Optional

from data_juicer.utils.common_utils import nested_set
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Mapper

OP_NAME = 'query_intent_detection_mapper'


@OPERATORS.register_module(OP_NAME)
class QueryIntentDetectionMapper(Mapper):
"""
Mapper to predict user's Intent label in query. Input from query_key.
Output intensity label and corresponding score for the query, which is
store in 'intent.query_label' and 'intent.query_label_score' in
Data-Juicer meta field.
"""

_accelerator = 'cuda'
_batched_op = True

DEFAULT_LABEL_TO_INTENSITY = {}

def __init__(
self,
hf_model: str = 'Falconsai/intent_classification',
zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
model_params: Dict = {},
zh_to_en_model_params: Dict = {},
*,
label_to_intensity: Dict = None,
**kwargs):
"""
Initialization method.
:param hf_model: Hugginface model ID to predict sentiment intensity.
:param zh_to_en_hf_model: Translation model from Chinese to English.
If not None, translate the query from Chinese to English.
:param model_params: model param for hf_model.
:param zh_to_en_model_params: model param for zh_to_hf_model.
:param label_to_intensity: Map the output labels to the intensities
instead of the default mapper if not None.
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)

self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model,
return_pipe=True,
pipe_task='text-classification',
**model_params)

if zh_to_en_hf_model is not None:
self.zh_to_en_model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=zh_to_en_hf_model,
return_pipe=True,
pipe_task='translation',
**zh_to_en_model_params)
else:
self.zh_to_en_model_key = None

if label_to_intensity is not None:
self.label_to_intensity = label_to_intensity
else:
self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY

def process_batched(self, samples, rank=None):
queries = samples[self.query_key]

if self.zh_to_en_model_key is not None:
translater, _ = get_model(self.zh_to_en_model_key, rank,
self.use_cuda())
results = translater(queries)
queries = [item['translation_text'] for item in results]

classifier, _ = get_model(self.model_key, rank, self.use_cuda())
results = classifier(queries)
intensities = [
self.label_to_intensity[r['label']]
if r['label'] in self.label_to_intensity else r['label']
for r in results
]
scores = [r['score'] for r in results]

if Fields.meta not in samples:
samples[Fields.meta] = [{} for val in intensities]
for i in range(len(samples[Fields.meta])):
samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
MetaKeys.query_intent_label,
intensities[i])
samples[Fields.meta][i] = nested_set(samples[Fields.meta][i],
MetaKeys.query_intent_score,
scores[i])

return samples
61 changes: 61 additions & 0 deletions tests/ops/mapper/test_query_intent_detection_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest
import json

from loguru import logger

from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
DataJuicerTestCaseBase)
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.common_utils import nested_access

class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase):

hf_model = 'Falconsai/intent_classification'
zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'

def _run_op(self, op, samples, intensity_key, targets):
dataset = Dataset.from_list(samples)
dataset = dataset.map(op.process, batch_size=2)

for sample, target in zip(dataset, targets):
intensity = nested_access(sample[Fields.meta], intensity_key)
self.assertEqual(intensity, target)

def test_default(self):

samples = [{
'query': '我要一个汉堡。'
},{
'query': '你最近过得怎么样?'
},{
'query': '它是正方形的。'
}
]
targets = [1, 0, -1]

op = QueryIntentDetectionMapper(
hf_model = self.hf_model,
zh_to_en_hf_model = self.zh_to_en_hf_model,
)
self._run_op(op, samples, MetaKeys.query_sentiment_label, targets)

def test_no_zh_to_en(self):

samples = [{
'query': '它是正方形的。'
},{
'query': 'It is square.'
}
]
targets = [0, 1]

op = QueryIntentDetectionMapper(
hf_model = self.hf_model,
zh_to_en_hf_model = None,
)
self._run_op(op, samples, MetaKeys.query_sentiment_label, targets)

if __name__ == '__main__':
unittest.main()

0 comments on commit 937b3f1

Please sign in to comment.