From c749dcd3885f7cc6b52e8c16d69672617f8c1cc8 Mon Sep 17 00:00:00 2001 From: Haibin <1400012807@pku.edu.cn> Date: Tue, 17 Dec 2024 15:06:00 +0800 Subject: [PATCH] change model --- .../ops/mapper/query_intent_detection_mapper.py | 3 ++- .../ops/mapper/test_query_intent_detection_mapper.py | 12 ++++++------ .../mapper/test_query_sentiment_detection_mapper.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index 66290532..badc7f49 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -25,7 +25,8 @@ class QueryIntentDetectionMapper(Mapper): def __init__( self, - hf_model: str = 'Falconsai/intent_classification', + hf_model: + str = 'bespin-global/klue-roberta-small-3i4k-intent-classification', # noqa: E501 E131 zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 5fac5ffc..592d3d0c 100644 --- a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -12,7 +12,7 @@ class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): - hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Falconsai/intent_classification' + hf_model = '/mnt/workspace/shared/checkpoints/huggingface/bespin-global/klue-roberta-small-3i4k-intent-classification' zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, intensity_key, targets): @@ -27,11 +27,11 @@ def _run_op(self, op, samples, intensity_key, targets): def test_default(self): samples = [{ - 'query': '我要一个汉堡。' + 'query': '这样好吗?' },{ - 'query': '你最近过得怎么样?' + 'query': '把那只笔递给我。' },{ - 'query': '它是正方形的。' + 'query': '难道不是这样的吗?' } ] targets = [1, 0, -1] @@ -45,9 +45,9 @@ def test_default(self): def test_no_zh_to_en(self): samples = [{ - 'query': '它是正方形的。' + 'query': '这样好吗?' },{ - 'query': 'It is square.' + 'query': 'Is this okay?' } ] targets = [0, 1] diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py index dcc29e25..62ed0f38 100644 --- a/tests/ops/mapper/test_query_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -12,8 +12,8 @@ class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): - hf_model = '/mnt/workspace/shared/checkpoints/huggingface/mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' - zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en' + hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' + zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' def _run_op(self, op, samples, label_key, targets): dataset = Dataset.from_list(samples)