Skip to content

Commit

Permalink
change model
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 17, 2024
1 parent 8109c71 commit c749dcd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion data_juicer/ops/mapper/query_intent_detection_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class QueryIntentDetectionMapper(Mapper):

def __init__(
self,
hf_model: str = 'Falconsai/intent_classification',
hf_model:
str = 'bespin-global/klue-roberta-small-3i4k-intent-classification', # noqa: E501 E131
zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en',
model_params: Dict = {},
zh_to_en_model_params: Dict = {},
Expand Down
12 changes: 6 additions & 6 deletions tests/ops/mapper/test_query_intent_detection_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase):

hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Falconsai/intent_classification'
hf_model = '/mnt/workspace/shared/checkpoints/huggingface/bespin-global/klue-roberta-small-3i4k-intent-classification'
zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en'

def _run_op(self, op, samples, intensity_key, targets):
Expand All @@ -27,11 +27,11 @@ def _run_op(self, op, samples, intensity_key, targets):
def test_default(self):

samples = [{
'query': '我要一个汉堡。'
'query': '这样好吗?'
},{
'query': '你最近过得怎么样?'
'query': '把那只笔递给我。'
},{
'query': '它是正方形的。'
'query': '难道不是这样的吗?'
}
]
targets = [1, 0, -1]
Expand All @@ -45,9 +45,9 @@ def test_default(self):
def test_no_zh_to_en(self):

samples = [{
'query': '它是正方形的。'
'query': '这样好吗?'
},{
'query': 'It is square.'
'query': 'Is this okay?'
}
]
targets = [0, 1]
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/mapper/test_query_sentiment_detection_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase):

hf_model = '/mnt/workspace/shared/checkpoints/huggingface/mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'
zh_to_en_hf_model = '/mnt/workspace/shared/checkpoints/huggingface/Helsinki-NLP/opus-mt-zh-en'
hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis'
zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en'

def _run_op(self, op, samples, label_key, targets):
dataset = Dataset.from_list(samples)
Expand Down

0 comments on commit c749dcd

Please sign in to comment.