-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
174 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Dict, Optional | ||
|
||
from data_juicer.utils.common_utils import nested_set | ||
from data_juicer.utils.constant import Fields, MetaKeys | ||
from data_juicer.utils.model_utils import get_model, prepare_model | ||
|
||
from ..base_op import OPERATORS, Mapper | ||
|
||
OP_NAME = 'query_intent_detection_mapper' | ||
|
||
|
||
@OPERATORS.register_module(OP_NAME) | ||
class QueryIntentDetectionMapper(Mapper): | ||
""" | ||
Mapper to predict user's Intent label in query. Input from query_key. | ||
Output intensity label and corresponding score for the query, which is | ||
store in 'intent.query_label' and 'intent.query_label_score' in | ||
Data-Juicer meta field. | ||
""" | ||
|
||
_accelerator = 'cuda' | ||
_batched_op = True | ||
|
||
DEFAULT_LABEL_TO_INTENSITY = {} | ||
|
||
def __init__( | ||
self, | ||
hf_model: str = 'Falconsai/intent_classification', | ||
zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', | ||
model_params: Dict = {}, | ||
zh_to_en_model_params: Dict = {}, | ||
*, | ||
label_to_intensity: Dict = None, | ||
**kwargs): | ||
""" | ||
Initialization method. | ||
:param hf_model: Hugginface model ID to predict sentiment intensity. | ||
:param zh_to_en_hf_model: Translation model from Chinese to English. | ||
If not None, translate the query from Chinese to English. | ||
:param model_params: model param for hf_model. | ||
:param zh_to_en_model_params: model param for zh_to_hf_model. | ||
:param label_to_intensity: Map the output labels to the intensities | ||
instead of the default mapper if not None. | ||
:param kwargs: Extra keyword arguments. | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.model_key = prepare_model(model_type='huggingface', | ||
pretrained_model_name_or_path=hf_model, | ||
return_pipe=True, | ||
pipe_task='text-classification', | ||
**model_params) | ||
|
||
if zh_to_en_hf_model is not None: | ||
self.zh_to_en_model_key = prepare_model( | ||
model_type='huggingface', | ||
pretrained_model_name_or_path=zh_to_en_hf_model, | ||
return_pipe=True, | ||
pipe_task='translation', | ||
**zh_to_en_model_params) | ||
else: | ||
self.zh_to_en_model_key = None | ||
|
||
if label_to_intensity is not None: | ||
self.label_to_intensity = label_to_intensity | ||
else: | ||
self.label_to_intensity = self.DEFAULT_LABEL_TO_INTENSITY | ||
|
||
def process_batched(self, samples, rank=None): | ||
queries = samples[self.query_key] | ||
|
||
if self.zh_to_en_model_key is not None: | ||
translater, _ = get_model(self.zh_to_en_model_key, rank, | ||
self.use_cuda()) | ||
results = translater(queries) | ||
queries = [item['translation_text'] for item in results] | ||
|
||
classifier, _ = get_model(self.model_key, rank, self.use_cuda()) | ||
results = classifier(queries) | ||
intensities = [ | ||
self.label_to_intensity[r['label']] | ||
if r['label'] in self.label_to_intensity else r['label'] | ||
for r in results | ||
] | ||
scores = [r['score'] for r in results] | ||
|
||
if Fields.meta not in samples: | ||
samples[Fields.meta] = [{} for val in intensities] | ||
for i in range(len(samples[Fields.meta])): | ||
samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], | ||
MetaKeys.query_intent_label, | ||
intensities[i]) | ||
samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], | ||
MetaKeys.query_intent_score, | ||
scores[i]) | ||
|
||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import unittest | ||
import json | ||
|
||
from loguru import logger | ||
|
||
from data_juicer.core.data import NestedDataset as Dataset | ||
from data_juicer.ops.mapper.query_intent_detection_mapper import QueryIntentDetectionMapper | ||
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, | ||
DataJuicerTestCaseBase) | ||
from data_juicer.utils.constant import Fields, MetaKeys | ||
from data_juicer.utils.common_utils import nested_access | ||
|
||
class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): | ||
|
||
hf_model = 'Falconsai/intent_classification' | ||
zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' | ||
|
||
def _run_op(self, op, samples, intensity_key, targets): | ||
dataset = Dataset.from_list(samples) | ||
dataset = dataset.map(op.process, batch_size=2) | ||
|
||
for sample, target in zip(dataset, targets): | ||
intensity = nested_access(sample[Fields.meta], intensity_key) | ||
self.assertEqual(intensity, target) | ||
|
||
def test_default(self): | ||
|
||
samples = [{ | ||
'query': '我要一个汉堡。' | ||
},{ | ||
'query': '你最近过得怎么样?' | ||
},{ | ||
'query': '它是正方形的。' | ||
} | ||
] | ||
targets = [1, 0, -1] | ||
|
||
op = QueryIntentDetectionMapper( | ||
hf_model = self.hf_model, | ||
zh_to_en_hf_model = self.zh_to_en_hf_model, | ||
) | ||
self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) | ||
|
||
def test_no_zh_to_en(self): | ||
|
||
samples = [{ | ||
'query': '它是正方形的。' | ||
},{ | ||
'query': 'It is square.' | ||
} | ||
] | ||
targets = [0, 1] | ||
|
||
op = QueryIntentDetectionMapper( | ||
hf_model = self.hf_model, | ||
zh_to_en_hf_model = None, | ||
) | ||
self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |