diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 963fe099f..a54d31f94 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -79,8 +79,10 @@ process: - clean_copyright_mapper: # remove copyright comments. - dialog_intent_detection_mapper: # Mapper to generate user's intent labels in dialog. api_model: 'gpt-4o' # API model name. - intent_candidates: null # The output intent candidates. Use the intent labels of the open domain if it is None. + intent_candidates: null # The output intent candidates. Use open-domain intent labels if it is None. max_round: 10 # The max num of round in the dialog to build the prompt. + labels_key: 'dialog_intent_labels' # The key name in the meta field to store the output labels. It is 'dialog_intent_labels' in default. + analysis_key: 'dialog_intent_labels_analysis' # The key name in the meta field to store the corresponding analysis. It is 'dialog_intent_labels_analysis' in default. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. @@ -96,12 +98,16 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - dialog_sentiment_detection_mapper: # Mapper to generate user's sentiment labels in dialog. api_model: 'gpt-4o' # API model name. + sentiment_candidates: null # The output sentiment candidates. Use open-domain sentiment labels if it is None. max_round: 10 # The max num of round in the dialog to build the prompt. + labels_key: 'dialog_sentiment_labels' # The key name in the meta field to store the output labels. It is 'dialog_sentiment_labels' in default. + analysis_key: 'dialog_sentiment_labels_analysis' # The key name in the meta field to store the corresponding analysis. It is 'dialog_sentiment_labels_analysis' in default. api_endpoint: null # URL endpoint for the API.
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. query_template: null # Template for query part to build the input prompt. response_template: null # Template for response part to build the input prompt. + candidate_template: null # Template for sentiment candidates to build the input prompt. analysis_template: null # Template for analysis part to build the input prompt. labels_template: null # Template for labels part to build the input prompt. analysis_pattern: null # Pattern to parse the return sentiment analysis. @@ -112,6 +118,8 @@ process: - dialog_sentiment_intensity_mapper: # Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. api_model: 'gpt-4o' # API model name. max_round: 10 # The max num of round in the dialog to build the prompt. + intensities_key: null # The key name in the meta field to store the output sentiment intensities. It is 'dialog_sentiment_intensity' in default. + analysis_key: null # The key name in the meta field to store the corresponding analysis. It is 'dialog_sentiment_intensity_analysis' in default. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. @@ -126,12 +134,16 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - dialog_topic_detection_mapper: # Mapper to generate user's topic labels in dialog. api_model: 'gpt-4o' # API model name. + topic_candidates: null # The output topic candidates. Use open-domain topic labels if it is None. max_round: 10 # The max num of round in the dialog to build the prompt. + labels_key: 'dialog_topic_labels' # The key name in the meta field to store the output labels. It is 'dialog_topic_labels' in default.
+ analysis_key: 'dialog_topic_labels_analysis' # The key name in the meta field to store the corresponding analysis. It is 'dialog_topic_labels_analysis' in default. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. query_template: null # Template for query part to build the input prompt. response_template: null # Template for response part to build the input prompt. + candidate_template: null # Template for topic candidates to build the input prompt. analysis_template: null # Template for analysis part to build the input prompt. labels_template: null # Template for labels part to build the input prompt. analysis_pattern: null # Pattern to parse the return topic analysis. @@ -144,10 +156,10 @@ process: api_model: 'gpt-4o' # API model name. query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried. query_attributes: ["人物性格"] # Attribute list to be queried. - entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction. - entity_attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted. - attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description. - support_text_key: '__dj__support_text__' # The field name to store the attribute support text extracted from the raw text. + entity_key: 'entity' # The key name in the meta field to store the given main entity for attribute extraction. + entity_attribute_key: 'attribute' # The key name in the meta field to store the given attribute to be extracted. + attribute_desc_key: 'attribute_description' # The key name in the meta field to store the extracted attribute description. + support_text_key: 'support_text' # The key name in the meta field to store the attribute support text extracted from the raw text. 
api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt_template: null # System prompt template for the task. Need to be specified by given entity and attribute. @@ -161,8 +173,8 @@ process: - extract_entity_relation_mapper: # Extract entities and relations in the text for knowledge graph. api_model: 'gpt-4o' # API model name. entity_types: ['person', 'organization', 'location'] # Pre-defined entity types for knowledge graph. - entity_key: '__dj__entity__' # The field name to store the entities. - relation_key: '__dj__relation__' # The field name to store the relations between entities. + entity_key: 'entity' # The key name in the meta field to store the entities. + relation_key: 'relation' # The key name in the meta field to store the relations between entities. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. prompt_template: null # The template of input prompt. @@ -180,8 +192,8 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - extract_event_mapper: # Extract events and relevant characters in the text api_model: 'gpt-4o' # API model name. - event_desc_key: '__dj__event_description__' # The field name to store the event descriptions. - relevant_char_key: '__dj__relevant_characters__' # The field name to store the relevant characters to the events. + event_desc_key: 'event_description' # The key name in the meta field to store the event descriptions. + relevant_char_key: 'relevant_characters' # The key name in the meta field to store the relevant characters to the events. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. 
@@ -193,7 +205,7 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - extract_keyword_mapper: # Generate keywords for the text. api_model: 'gpt-4o' # API model name. - keyword_key: '__dj__keyword__' # The field name to store the keywords. + keyword_key: 'keyword' # The key name in the meta field to store the keywords. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. prompt_template: null # The template of input prompt. @@ -205,7 +217,7 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - extract_nickname_mapper: # Extract nickname relationship in the text. api_model: 'gpt-4o' # API model name. - nickname_key: '__dj__nickname__' # The field name to store the nickname relationship. + nickname_key: 'nickname' # The key name in the meta field to store the nickname relationship. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. @@ -217,8 +229,8 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - extract_support_text_mapper: # extract support sub text for a summary. api_model: 'gpt-4o' # API model name. - summary_key: '__dj__event_description__' # The field name to store the input summary. Support for nested keys such as "__dj__stats__.text_len". - support_text_key: '__dj__support_text__' # The field name to store the output support text for the summary. + summary_key: 'event_description' # The key name in the meta field to store the input summary. It's "event_description" in default. + support_text_key: 'support_text' # The key name in the meta field to store the output support text for the summary. 
It's "support_text" in default. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the task. @@ -279,13 +291,13 @@ process: keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated images in the final datasets and the original images will be removed. It's True in default. caption_key: null # the key name of fields in samples to store captions for each images, the caption guide the diffusion model to produce what the image is hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate caption if caption_key is null - mem_required: '8GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + mem_required: '8GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched - image_face_blur_mapper: # blur faces detected in images cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'. blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel - image_tagging_mapper: # Mapper to generate image tags. - tag_field_name: '__dj__image_tags__' # the field name to store the tags. It's "__dj__image_tags__" in default. + tag_field_name: 'image_tags' # the field name to store the tags. It's "image_tags" in default. mem_required: '9GB' - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. 
If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. @@ -333,7 +345,7 @@ process: model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. - - python_file_mapper: # executing Python lambda function defined in a file. + - python_file_mapper: # executing Python lambda function defined in a file. file_path: '' # The path to the Python file containing the function to be executed. function_name: 'process_single' # The name of the function defined in the file to be executed. - python_lambda_mapper: # executing Python lambda function on data samples. @@ -344,22 +356,27 @@ process: zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. model_params: {} # model param for hf_model. zh_to_en_model_params: {} # model param for zh_to_hf_model. + label_key: 'query_intent_label' # The key name in the meta field to store the output label. It is 'query_intent_label' in default. + score_key: 'query_intent_label_score' # The key name in the meta field to store the corresponding label score. It is 'query_intent_label_score' in default. - query_sentiment_detection_mapper: # Mapper to predict user's sentiment label ('negative', 'neutral' and 'positive') in query. hf_model: 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' # Hugginface model ID to predict sentiment label. zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. model_params: {} # model param for hf_model. zh_to_en_model_params: {} # model param for zh_to_hf_model. 
+ label_key: 'query_sentiment_label' # The key name in the meta field to store the output label. It is 'query_sentiment_label' in default. + score_key: 'query_sentiment_label_score' # The key name in the meta field to store the corresponding label score. It is 'query_sentiment_label_score' in default. - query_topic_detection_mapper: # Mapper to predict user's topic label in query. hf_model: 'dstefa/roberta-base_topic_classification_nyt_news' # Hugginface model ID to predict topic label. zh_to_en_hf_model: 'Helsinki-NLP/opus-mt-zh-en' # Translation model from Chinese to English. If not None, translate the query from Chinese to English. model_params: {} # model param for hf_model. zh_to_en_model_params: {} # model param for zh_to_hf_model. + label_key: 'query_topic_label' # The key name in the meta field to store the output label. It is 'query_topic_label' in default. + score_key: 'query_topic_label_score' # The key name in the meta field to store the corresponding label score. It is 'query_topic_label_score' in default. - relation_identity_mapper: # identify relation between two entity in the text. api_model: 'gpt-4o' # API model name. source_entity: '孙悟空' # The source entity of the relation to be dentified. target_entity: '猪八戒' # The target entity of the relation to be identified. - input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. - output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is input_key in default. + output_key: 'role_relation' # The output key in the meta field in the samples. It is 'role_relation' in default. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt_template: null # System prompt template for the task. Need to specify by entity1 and entity2. 
@@ -486,12 +503,12 @@ process: show_progress: false # whether to show progress from scenedetect - video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video. hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model. - tag_field_name: '__dj__video_audio_tags__' # the field name to store the tags. It's "__dj__video_audio_tags__" in default. + tag_field_name: 'video_audio_tags' # the field name to store the tags. It's "video_audio_tags" in default. mem_required: '500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched - video_tagging_from_frames_mapper: # Mapper to generate video tags from frames extracted from the video. frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. - tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. + tag_field_name: 'video_frame_tags' # the key name in the meta field to store the tags. It's "video_frame_tags" in default. mem_required: '9GB' - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. 
@@ -723,7 +740,7 @@ process: contain: any # require the videos containing 'any' or 'all' given tags. When tags equal to [], 'all' keeps all samples, 'any' keeps no sample. frame_sampling_method: all_keyframes # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. - tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. + tag_field_name: 'video_frame_tags' # the key name in the meta field to store the tags. It's "video_frame_tags" in default. any_or_all: any # keep this sample when any/all videos meet the filter condition mem_required: '9GB' - words_num_filter: # filter text with number of words out of specific range @@ -822,6 +839,7 @@ process: # Grouper ops. - naive_grouper: # Group all samples to one batched sample. - naive_reverse_grouper: # Split one batched sample to samples. + batch_meta_export_path: null # the path to export the batch meta. Just drop the batch meta if it is None. - key_value_grouper: # Group samples to batched samples according values in given keys. group_by_keys: null # Group samples according values in the keys. Support for nested keys such as "__dj__stats__.text_len". It is [self.text_key] in default. @@ -830,8 +848,8 @@ process: api_model: 'gpt-4o' # API model name. entity: '孙悟空' # The given entity. attribute: '人物经历' # The given attribute. - input_key: null # The input field key in the samples. 
Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. - output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + input_key: 'event_description' # The input key in the meta field of the samples. It is "event_description" in default. + output_key: 'entity_attribute' # The output key in the aggregation field of the samples. It is "entity_attribute" in default. word_limit: 100 # Prompt the output length. max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. api_endpoint: null # URL endpoint for the API. @@ -845,7 +863,7 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - meta_tags_aggregator: # Merge similar meta tags to one tag. api_model: 'gpt-4o' # API model name. - meta_tag_key: '__dj__meta__.query_sentiment_label' # The key of the meta tag to be mapped. + meta_tag_key: 'query_sentiment_label' # The key of the meta tag to be mapped. It is "query_sentiment_label" in default. target_tags: ['开心', '难过', '其他'] # The tags that is supposed to be mapped to. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -861,8 +879,8 @@ process: api_model: 'gpt-4o' # API model name. entity: '孙悟空' # The given entity. query_entity_type: '人物' # The type of queried relavant entities. - input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. - output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + input_key: 'event_description' # The input key in the meta field of the samples. It is "event_description" in default. 
+ output_key: 'most_relavant_entities' # The output key in the aggregation field of the samples. It is "most_relavant_entities" in default. max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -874,8 +892,8 @@ process: sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - nested_aggregator: # Considering the limitation of input length, nested aggregate contents for each given number of samples. api_model: 'gpt-4o' # API model name. - input_key: null # The input field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is text_key in default. - output_key: null # The output field key in the samples. Support for nested keys such as "__dj__stats__.text_len". It is same as the input_key in default. + input_key: 'event_description' # The input key in the meta field of the samples. It is "event_description" in default. + output_key: null # The output key in the aggregation field in the samples. It is same as the input_key in default. max_token_num: null # The max token num of the total tokens of the sub documents. Without limitation if it is None. api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
diff --git a/data_juicer/ops/aggregator/entity_attribute_aggregator.py b/data_juicer/ops/aggregator/entity_attribute_aggregator.py index 16ec5fd07..de39cd322 100644 --- a/data_juicer/ops/aggregator/entity_attribute_aggregator.py +++ b/data_juicer/ops/aggregator/entity_attribute_aggregator.py @@ -6,8 +6,8 @@ from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, - is_string_list, nested_access, - nested_set) + is_string_list) +from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model from .nested_aggregator import NestedAggregator @@ -53,8 +53,8 @@ def __init__(self, api_model: str = 'gpt-4o', entity: str = None, attribute: str = None, - input_key: str = None, - output_key: str = None, + input_key: str = MetaKeys.event_description, + output_key: str = BatchMetaKeys.entity_attribute, word_limit: PositiveInt = 100, max_token_num: Optional[PositiveInt] = None, *, @@ -73,12 +73,10 @@ def __init__(self, :param api_model: API model name. :param entity: The given entity. :param attribute: The given attribute. - :param input_key: The input field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is text_key - in default. - :param output_key: The output field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is same as the - input_key in default. + :param input_key: The input key in the meta field of the samples. + It is "event_description" in default. + :param output_key: The output key in the aggregation field of the + samples. It is "entity_attribute" in default. :param word_limit: Prompt the output length. :param max_token_num: The max token num of the total tokens of the sub documents. Without limitation if it is None. 
@@ -103,8 +101,8 @@ def __init__(self, self.entity = entity self.attribute = attribute - self.input_key = input_key or self.text_key - self.output_key = output_key or self.input_key + self.input_key = input_key + self.output_key = output_key self.word_limit = word_limit self.max_token_num = max_token_num @@ -131,7 +129,7 @@ def __init__(self, **model_params) self.try_num = try_num - self.nested_sum = NestedAggregator(model=api_model, + self.nested_sum = NestedAggregator(api_model=api_model, max_token_num=max_token_num, api_endpoint=api_endpoint, response_path=response_path, @@ -185,12 +183,21 @@ def attribute_summary(self, sub_docs, rank=None): def process_single(self, sample=None, rank=None): + if self.output_key in sample[Fields.batch_meta]: + return sample + + if Fields.meta not in sample or self.input_key not in sample[ + Fields.meta][0]: + logger.warning('The input key does not exist in the sample!') + return sample + + sub_docs = [d[self.input_key] for d in sample[Fields.meta]] # if not batched sample - sub_docs = nested_access(sample, self.input_key) if not is_string_list(sub_docs): + logger.warning('Require string meta as input!') return sample - sample = nested_set(sample, self.output_key, - self.attribute_summary(sub_docs, rank=rank)) + sample[Fields.batch_meta][self.output_key] = self.attribute_summary( + sub_docs, rank=rank) return sample diff --git a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py index 7ca49f505..be585f44f 100644 --- a/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py +++ b/data_juicer/ops/aggregator/most_relavant_entities_aggregator.py @@ -5,8 +5,8 @@ from pydantic import PositiveInt from data_juicer.ops.base_op import OPERATORS, Aggregator -from data_juicer.utils.common_utils import (is_string_list, nested_access, - nested_set) +from data_juicer.utils.common_utils import is_string_list +from data_juicer.utils.constant import 
BatchMetaKeys, Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model from ..common import split_text_by_punctuation @@ -44,8 +44,8 @@ def __init__(self, api_model: str = 'gpt-4o', entity: str = None, query_entity_type: str = None, - input_key: str = None, - output_key: str = None, + input_key: str = MetaKeys.event_description, + output_key: str = BatchMetaKeys.most_relavant_entities, max_token_num: Optional[PositiveInt] = None, *, api_endpoint: Optional[str] = None, @@ -62,12 +62,10 @@ def __init__(self, :param api_model: API model name. :param entity: The given entity. :param query_entity_type: The type of queried relavant entities. - :param input_key: The input field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is text_key - in default. - :param output_key: The output field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is same as the - input_key in default. + :param input_key: The input key in the meta field of the samples. + It is "event_description" in default. + :param output_key: The output key in the aggregation field of the + samples. It is "most_relavant_entities" in default. :param max_token_num: The max token num of the total tokens of the sub documents. Without limitation if it is None. :param api_endpoint: URL endpoint for the API. 
@@ -91,8 +89,8 @@ def __init__(self, self.entity = entity self.query_entity_type = query_entity_type - self.input_key = input_key or self.text_key - self.output_key = output_key or self.input_key + self.input_key = input_key + self.output_key = output_key self.max_token_num = max_token_num system_prompt_template = system_prompt_template or \ @@ -167,13 +165,22 @@ def query_most_relavant_entities(self, sub_docs, rank=None): def process_single(self, sample=None, rank=None): + if self.output_key in sample[Fields.batch_meta]: + return sample + + if Fields.meta not in sample or self.input_key not in sample[ + Fields.meta][0]: + logger.warning('The input key does not exist in the sample!') + return sample + + sub_docs = [d[self.input_key] for d in sample[Fields.meta]] + # if not batched sample - sub_docs = nested_access(sample, self.input_key) if not is_string_list(sub_docs): return sample - sample = nested_set( - sample, self.output_key, - self.query_most_relavant_entities(sub_docs, rank=rank)) + sample[Fields.batch_meta][ + self.output_key] = self.query_most_relavant_entities(sub_docs, + rank=rank) return sample diff --git a/data_juicer/ops/aggregator/nested_aggregator.py b/data_juicer/ops/aggregator/nested_aggregator.py index ab25e057d..f228ffd9c 100644 --- a/data_juicer/ops/aggregator/nested_aggregator.py +++ b/data_juicer/ops/aggregator/nested_aggregator.py @@ -5,7 +5,8 @@ from data_juicer.ops.base_op import OPERATORS, Aggregator from data_juicer.utils.common_utils import (avg_split_string_list_under_limit, - is_string_list, nested_access) + is_string_list) +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'nested_aggregator' @@ -47,7 +48,7 @@ class NestedAggregator(Aggregator): def __init__(self, api_model: str = 'gpt-4o', - input_key: str = None, + input_key: str = MetaKeys.event_description, output_key: str = None, max_token_num: Optional[PositiveInt] = None, *, @@ -63,12 +64,10 
@@ def __init__(self, """ Initialization method. :param api_model: API model name. - :param input_key: The input field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is text_key - in default. - :param output_key: The output field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is same as the - input_key in default. + :param input_key: The input key in the meta field of the samples. + It is "event_description" in default. + :param output_key: The output key in the aggregation field in the + samples. It is same as the input_key in default. :param max_token_num: The max token num of the total tokens of the sub documents. Without limitation if it is None. :param api_endpoint: URL endpoint for the API. @@ -165,11 +164,21 @@ def recursive_summary(self, sub_docs, rank=None): def process_single(self, sample=None, rank=None): + if self.output_key in sample[Fields.batch_meta]: + return sample + + if Fields.meta not in sample or self.input_key not in sample[ + Fields.meta][0]: + logger.warning('The input key does not exist in the sample!') + return sample + + sub_docs = [d[self.input_key] for d in sample[Fields.meta]] + # if not batched sample - sub_docs = nested_access(sample, self.input_key) if not is_string_list(sub_docs): return sample - sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank) + sample[Fields.batch_meta][self.output_key] = self.recursive_summary( + sub_docs, rank=rank) return sample diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 6c9dd17f5..f230e600b 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -633,6 +633,17 @@ def process_single(self, sample): def run(self, dataset, *, exporter=None, tracer=None): dataset = super(Aggregator, self).run(dataset) + # add batched meta field for OPs that produce aggregations + if Fields.batch_meta not in dataset.features: + from data_juicer.core.data import 
add_same_content_to_new_column + dataset = dataset.map(add_same_content_to_new_column, + fn_kwargs={ + 'new_column_name': Fields.batch_meta, + 'initial_value': {} + }, + num_proc=self.runtime_np(), + batch_size=self.batch_size, + desc='Adding new column for aggregation') new_dataset = dataset.map( self.process, num_proc=self.runtime_np(), diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index 2436d886c..4e37f1d7b 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -3,7 +3,7 @@ import numpy as np from pydantic import PositiveInt -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE, Filter) @@ -30,7 +30,7 @@ def __init__(self, contain: str = 'any', frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, - tag_field_name: str = Fields.video_frame_tags, + tag_field_name: str = MetaKeys.video_frame_tags, any_or_all: str = 'any', *args, **kwargs): @@ -55,8 +55,8 @@ def __init__(self, the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. - :param tag_field_name: the field name to store the tags. It's - "__dj__video_frame_tags__" in default. + :param tag_field_name: the key name to store the tags in the meta + field. It's "video_frame_tags" in default. :param any_or_all: keep this sample with 'any' or 'all' strategy of all videos. 'any': keep this sample if any videos meet the condition. 
'all': keep this sample only if all videos meet the diff --git a/data_juicer/ops/grouper/naive_reverse_grouper.py b/data_juicer/ops/grouper/naive_reverse_grouper.py index 2535205b9..ae7860cc6 100644 --- a/data_juicer/ops/grouper/naive_reverse_grouper.py +++ b/data_juicer/ops/grouper/naive_reverse_grouper.py @@ -1,3 +1,9 @@ +import json +import os + +from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import create_directory_if_not_exists + from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict @@ -5,14 +11,17 @@ class NaiveReverseGrouper(Grouper): """Split batched samples to samples. """ - def __init__(self, *args, **kwargs): + def __init__(self, batch_meta_export_path=None, *args, **kwargs): """ Initialization method. + :param batch_meta_export_path: the path to export the batch meta. + Just drop the batch meta if it is None. :param args: extra args :param kwargs: extra args """ super().__init__(*args, **kwargs) + self.batch_meta_export_path = batch_meta_export_path def process(self, dataset): @@ -20,7 +29,20 @@ def process(self, dataset): return dataset samples = [] + batch_metas = [] for sample in dataset: + if Fields.batch_meta in sample: + batch_metas.append(sample[Fields.batch_meta]) + sample = { + k: sample[k] + for k in sample if k != Fields.batch_meta + } samples.extend(convert_dict_list_to_list_dict(sample)) + if self.batch_meta_export_path is not None: + create_directory_if_not_exists( + os.path.dirname(self.batch_meta_export_path)) + with open(self.batch_meta_export_path, 'w') as f: + for batch_meta in batch_metas: + f.write(json.dumps(batch_meta, ensure_ascii=False) + '\n') return samples diff --git a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py index 7c8cba9ed..6926ef6c2 100644 --- a/data_juicer/ops/mapper/dialog_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_intent_detection_mapper.py @@ -4,8 +4,7 @@ from loguru import 
logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_set +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,14 +12,13 @@ # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class DialogIntentDetectionMapper(Mapper): """ Mapper to generate user's intent labels in dialog. Input from history_key, query_key and response_key. Output lists of - labels and analysis for queries in the dialog, which is - store in 'dialog_intent_labels' and - 'dialog_intent_labels_analysis' in Data-Juicer meta field. + labels and analysis for queries in the dialog. """ DEFAULT_SYSTEM_PROMPT = ( @@ -60,6 +58,8 @@ def __init__(self, intent_candidates: Optional[List[str]] = None, max_round: NonNegativeInt = 10, *, + labels_key: str = MetaKeys.dialog_intent_labels, + analysis_key: str = MetaKeys.dialog_intent_labels_analysis, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, @@ -82,6 +82,11 @@ def __init__(self, intent labels of the open domain if it is None. :param max_round: The max num of round in the dialog to build the prompt. + :param labels_key: The key name in the meta field to store the + output labels. It is 'dialog_intent_labels' in default. + :param analysis_key: The key name in the meta field to store the + corresponding analysis. It is 'dialog_intent_labels_analysis' + in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -111,6 +116,8 @@ def __init__(self, self.intent_candidates = intent_candidates self.max_round = max_round + self.labels_key = labels_key + self.analysis_key = analysis_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE @@ -167,6 +174,11 @@ def parse_output(self, response): return analysis, labels def process_single(self, sample, rank=None): + + meta = sample[Fields.meta] + if self.labels_key in meta and self.analysis_key in meta: + return sample + client = get_model(self.model_key, rank=rank) analysis_list = [] @@ -208,9 +220,7 @@ def process_single(self, sample, rank=None): history.append(self.labels_template.format(labels=labels)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}' # noqa: E501 - sample = nested_set(sample, analysis_key, analysis_list) - labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}' - sample = nested_set(sample, labels_key, labels_list) + meta[self.labels_key] = labels_list + meta[self.analysis_key] = analysis_list return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py index 33bccc5ce..11d7ddd2d 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_detection_mapper.py @@ -1,11 +1,10 @@ import re -from typing import Dict, Optional +from typing import Dict, List, Optional from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_set +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,14 +12,13 @@ # TODO: LLM-based inference. 
+@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class DialogSentimentDetectionMapper(Mapper): """ Mapper to generate user's sentiment labels in dialog. Input from history_key, query_key and response_key. Output lists of - labels and analysis for queries in the dialog, which is - store in 'dialog_sentiment_labels' and - 'dialog_sentiment_labels_analysis' in Data-Juicer meta field. + labels and analysis for queries in the dialog. """ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所具有的情绪。\n' @@ -29,35 +27,40 @@ class DialogSentimentDetectionMapper(Mapper): '。\n' '用户:最近工作压力好大,我觉得整个人都快被压垮了。\n' '情感分析:用户的言语中透露出明显的压力和疲惫感,可能还夹杂着一些无助和焦虑。\n' - '情感:压力、疲惫、无助、焦虑\n' + '情感类别:压力、疲惫、无助、焦虑\n' 'LLM:听起来你真的承受了很多,面临这种情况确实不容易。有没有考虑过找一些放松的' '方式,比如听音乐或者散步来减轻压力呢?\n' '用户:试过了,但是好像没什么效果,每天的事情都堆积如山。\n' '情感分析:用户感到无力解决现状,有挫败感,并且对尝试放松的方式失去信心。\n' - '情感:无力、挫败\n' + '情感类别:无力、挫败\n' 'LLM:我理解你的感受,有时候压力积累到一定程度确实让人难以承受。或许你可以尝试' '规划一下时间,把任务分成小块来完成,这样可能会减少一些压力感。\n' '用户:这个主意不错,我会试着让自己更有条理一些,谢谢你的建议。\n' '情感分析:用户对建议表现出认同和感激,同时展现出试图积极面对问题的态度。\n' - '情感:认同、感激、积极\n' + '情感类别:认同、感激、积极\n' 'LLM:不用谢,我很高兴能帮到你。记得给自己一些时间去适应新的计划,有任何需要' '随时可以跟我说哦!\n') DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_CANDIDATES_TEMPLATE = '备选情感类别:[{candidate_str}]' DEFAULT_ANALYSIS_TEMPLATE = '情感分析:{analysis}\n' - DEFAULT_LABELS_TEMPLATE = '情感:{labels}\n' + DEFAULT_LABELS_TEMPLATE = '情感类别:{labels}\n' DEFAULT_ANALYSIS_PATTERN = '情感分析:(.*?)\n' - DEFAULT_LABELS_PATTERN = '情感:(.*?)($|\n)' + DEFAULT_LABELS_PATTERN = '情感类别:(.*?)($|\n)' def __init__(self, api_model: str = 'gpt-4o', + sentiment_candidates: Optional[List[str]] = None, max_round: NonNegativeInt = 10, *, + labels_key: str = MetaKeys.dialog_sentiment_labels, + analysis_key: str = MetaKeys.dialog_sentiment_labels_analysis, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, query_template: Optional[str] = None, response_template: Optional[str] = None, + 
candidate_template: Optional[str] = None, analysis_template: Optional[str] = None, labels_template: Optional[str] = None, analysis_pattern: Optional[str] = None, @@ -70,8 +73,15 @@ def __init__(self, Initialization method. :param api_model: API model name. + :param sentiment_candidates: The output sentiment candidates. Use + open-domain sentiment labels if it is None. :param max_round: The max num of round in the dialog to build the prompt. + :param labels_key: The key name in the meta field to store the + output labels. It is 'dialog_sentiment_labels' in default. + :param analysis_key: The key name in the meta field to store the + corresponding analysis. It is + 'dialog_sentiment_labels_analysis' in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -80,6 +90,8 @@ def __init__(self, prompt. :param response_template: Template for response part to build the input prompt. + :param candidate_template: Template for sentiment candidates to + build the input prompt. :param analysis_template: Template for analysis part to build the input prompt. 
:param labels_template: Template for labels part to build the @@ -97,12 +109,17 @@ def __init__(self, """ super().__init__(**kwargs) + self.sentiment_candidates = sentiment_candidates self.max_round = max_round + self.labels_key = labels_key + self.analysis_key = analysis_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE self.response_template = response_template or \ self.DEFAULT_RESPONSE_TEMPLATE + self.candidate_template = candidate_template or \ + self.DEFAULT_CANDIDATES_TEMPLATE self.analysis_template = analysis_template or \ self.DEFAULT_ANALYSIS_TEMPLATE self.labels_template = labels_template or \ @@ -123,10 +140,16 @@ def __init__(self, self.try_num = try_num def build_input(self, history, query): - if self.max_round > 0: - input_prompt = ''.join(history[-self.max_round * 4:]) + + if self.sentiment_candidates: + input_prompt = self.candidate_template.format( + candidate_str=','.join(self.sentiment_candidates)) else: input_prompt = '' + + if self.max_round > 0: + input_prompt += ''.join(history[-self.max_round * 4:]) + input_prompt += self.query_template.format(query=query[0]) return input_prompt @@ -146,6 +169,11 @@ def parse_output(self, response): return analysis, labels def process_single(self, sample, rank=None): + + meta = sample[Fields.meta] + if self.labels_key in meta and self.analysis_key in meta: + return sample + client = get_model(self.model_key, rank=rank) analysis_list = [] @@ -187,9 +215,7 @@ def process_single(self, sample, rank=None): history.append(self.labels_template.format(labels=labels)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels_analysis}' # noqa: E501 - sample = nested_set(sample, analysis_key, analysis_list) - labels_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_labels}' - sample = nested_set(sample, labels_key, labels_list) + meta[self.labels_key] = 
labels_list + meta[self.analysis_key] = analysis_list return sample diff --git a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py index 198314ee3..f9ffebc28 100644 --- a/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py +++ b/data_juicer/ops/mapper/dialog_sentiment_intensity_mapper.py @@ -4,8 +4,7 @@ from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_set +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,20 +12,21 @@ # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class DialogSentimentIntensityMapper(Mapper): """ Mapper to predict user's sentiment intensity (from -5 to 5 in default prompt) in dialog. Input from history_key, query_key and response_key. Output lists of intensities and analysis for queries in - the dialog, which is store in 'dialog_sentiment_intensity' and - 'dialog_sentiment_intensity_analysis' in Data-Juicer meta field. + the dialog. 
""" DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户的情绪变化。\n' '要求:\n' '- 用户情绪值是-5到5之间到整数,-5表示极度负面,5表示极度正面,' '-5到5之间数值表示情绪从负面逐渐到正面的变化过程,0代表情呈绪中性。\n' + '- 只输出当轮对话的分析,不要继续构造对话。\n' '- 需要先进行分析,然后确定用户的情绪值,下面是一个样例,请模仿样例格式输出。\n' '用户:你好,我对可持续发展的定义有点模糊,帮我解释一下?\n' '情绪分析:刚开始,还没得到LLM回复,用户情绪呈中性。\n' @@ -61,29 +61,38 @@ class DialogSentimentIntensityMapper(Mapper): DEFAULT_ANALYSIS_PATTERN = '情绪分析:(.*?)\n' DEFAULT_INTENSITY_PATTERN = '情绪值:(.*?)($|\n)' - def __init__(self, - api_model: str = 'gpt-4o', - max_round: NonNegativeInt = 10, - *, - api_endpoint: Optional[str] = None, - response_path: Optional[str] = None, - system_prompt: Optional[str] = None, - query_template: Optional[str] = None, - response_template: Optional[str] = None, - analysis_template: Optional[str] = None, - intensity_template: Optional[str] = None, - analysis_pattern: Optional[str] = None, - intensity_pattern: Optional[str] = None, - try_num: PositiveInt = 3, - model_params: Dict = {}, - sampling_params: Dict = {}, - **kwargs): + def __init__( + self, + api_model: str = 'gpt-4o', + max_round: NonNegativeInt = 10, + *, + intensities_key: str = MetaKeys.dialog_sentiment_intensity, + analysis_key: str = MetaKeys.dialog_sentiment_intensity_analysis, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + query_template: Optional[str] = None, + response_template: Optional[str] = None, + analysis_template: Optional[str] = None, + intensity_template: Optional[str] = None, + analysis_pattern: Optional[str] = None, + intensity_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): """ Initialization method. :param api_model: API model name. :param max_round: The max num of round in the dialog to build the prompt. + :param intensities_key: The key name in the meta field to store + the output sentiment intensities. It is + 'dialog_sentiment_intensity' in default. 
+ :param analysis_key: The key name in the meta field to store the + corresponding analysis. It is + 'dialog_sentiment_intensity_analysis' in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -110,6 +119,8 @@ def __init__(self, super().__init__(**kwargs) self.max_round = max_round + self.intensities_key = intensities_key + self.analysis_key = analysis_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE @@ -158,6 +169,11 @@ def parse_output(self, response): return analysis, intensity def process_single(self, sample, rank=None): + + meta = sample[Fields.meta] + if self.intensities_key in meta and self.analysis_key in meta: + return sample + client = get_model(self.model_key, rank=rank) analysis_list = [] @@ -199,9 +215,7 @@ def process_single(self, sample, rank=None): history.append(self.intensity_template.format(intensity=intensity)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity_analysis}' # noqa: E501 - sample = nested_set(sample, analysis_key, analysis_list) - intensity_key = f'{Fields.meta}.{MetaKeys.dialog_sentiment_intensity}' - sample = nested_set(sample, intensity_key, intensities) + meta[self.intensities_key] = intensities + meta[self.analysis_key] = analysis_list return sample diff --git a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py index 7e8ee0b54..7fd613df3 100644 --- a/data_juicer/ops/mapper/dialog_topic_detection_mapper.py +++ b/data_juicer/ops/mapper/dialog_topic_detection_mapper.py @@ -1,11 +1,10 @@ import re -from typing import Dict, Optional +from typing import Dict, List, Optional from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import 
OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_set +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -13,14 +12,13 @@ # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class DialogTopicDetectionMapper(Mapper): """ Mapper to generate user's topic labels in dialog. Input from history_key, query_key and response_key. Output lists of - labels and analysis for queries in the dialog, which is - store in 'dialog_sentiment_labels' and - 'dialog_sentiment_labels_analysis' in Data-Juicer meta field. + labels and analysis for queries in the dialog. """ DEFAULT_SYSTEM_PROMPT = ('请判断用户和LLM多轮对话中用户所讨论的话题。\n' @@ -47,6 +45,7 @@ class DialogTopicDetectionMapper(Mapper): '。\n') DEFAULT_QUERY_TEMPLATE = '用户:{query}\n' DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n' + DEFAULT_CANDIDATES_TEMPLATE = '备选话题类别:[{candidate_str}]' DEFAULT_ANALYSIS_TEMPLATE = '话题分析:{analysis}\n' DEFAULT_LABELS_TEMPLATE = '话题类别:{labels}\n' DEFAULT_ANALYSIS_PATTERN = '话题分析:(.*?)\n' @@ -54,13 +53,17 @@ class DialogTopicDetectionMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', + topic_candidates: Optional[List[str]] = None, max_round: NonNegativeInt = 10, *, + labels_key: str = MetaKeys.dialog_topic_labels, + analysis_key: str = MetaKeys.dialog_topic_labels_analysis, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, query_template: Optional[str] = None, response_template: Optional[str] = None, + candidate_template: Optional[str] = None, analysis_template: Optional[str] = None, labels_template: Optional[str] = None, analysis_pattern: Optional[str] = None, @@ -73,8 +76,15 @@ def __init__(self, Initialization method. :param api_model: API model name. + :param topic_candidates: The output topic candidates. 
Use + open-domain topic labels if it is None. :param max_round: The max num of round in the dialog to build the prompt. + :param labels_key: The key name in the meta field to store the + output labels. It is 'dialog_topic_labels' in default. + :param analysis_key: The key name in the meta field to store the + corresponding analysis. It is 'dialog_topic_labels_analysis' + in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -83,13 +93,15 @@ def __init__(self, prompt. :param response_template: Template for response part to build the input prompt. + :param candidate_template: Template for topic candidates to + build the input prompt. :param analysis_template: Template for analysis part to build the input prompt. :param labels_template: Template for labels part to build the input prompt. - :param analysis_pattern: Pattern to parse the return sentiment + :param analysis_pattern: Pattern to parse the return topic analysis. - :param labels_pattern: Pattern to parse the return sentiment + :param labels_pattern: Pattern to parse the return topic labels. :param try_num: The number of retry attempts when there is an API call error or output parsing error. 
@@ -100,12 +112,17 @@ def __init__(self, """ super().__init__(**kwargs) + self.topic_candidates = topic_candidates self.max_round = max_round + self.labels_key = labels_key + self.analysis_key = analysis_key self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE self.response_template = response_template or \ self.DEFAULT_RESPONSE_TEMPLATE + self.candidate_template = candidate_template or \ + self.DEFAULT_CANDIDATES_TEMPLATE self.analysis_template = analysis_template or \ self.DEFAULT_ANALYSIS_TEMPLATE self.labels_template = labels_template or \ @@ -127,11 +144,15 @@ def __init__(self, def build_input(self, history, query): - if self.max_round > 0: - input_prompt = ''.join(history[-self.max_round * 4:]) + if self.topic_candidates: + input_prompt = self.candidate_template.format( + candidate_str=','.join(self.topic_candidates)) else: input_prompt = '' + if self.max_round > 0: + input_prompt += ''.join(history[-self.max_round * 4:]) + input_prompt += self.query_template.format(query=query[0]) return input_prompt @@ -151,6 +172,11 @@ def parse_output(self, response): return analysis, labels def process_single(self, sample, rank=None): + + meta = sample[Fields.meta] + if self.labels_key in meta and self.analysis_key in meta: + return sample + client = get_model(self.model_key, rank=rank) analysis_list = [] @@ -192,9 +218,7 @@ def process_single(self, sample, rank=None): history.append(self.labels_template.format(labels=labels)) history.append(self.response_template.format(response=qa[1])) - analysis_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels_analysis}' # noqa: E501 - sample = nested_set(sample, analysis_key, analysis_list) - labels_key = f'{Fields.meta}.{MetaKeys.dialog_topic_labels}' - sample = nested_set(sample, labels_key, labels_list) + meta[self.labels_key] = labels_list + meta[self.analysis_key] = analysis_list return sample diff --git 
a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py index 0fc76b11f..8bdeeaa0d 100644 --- a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -4,14 +4,15 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.constant import Fields +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'extract_entity_attribute_mapper' # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityAttributeMapper(Mapper): """ @@ -45,10 +46,10 @@ def __init__(self, query_entities: List[str] = [], query_attributes: List[str] = [], *, - entity_key: str = Fields.main_entities, - attribute_key: str = Fields.attributes, - attribute_desc_key: str = Fields.attribute_descriptions, - support_text_key: str = Fields.attribute_support_texts, + entity_key: str = MetaKeys.main_entities, + attribute_key: str = MetaKeys.attributes, + attribute_desc_key: str = MetaKeys.attribute_descriptions, + support_text_key: str = MetaKeys.attribute_support_texts, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, @@ -65,16 +66,18 @@ def __init__(self, :param api_model: API model name. :param query_entities: Entity list to be queried. :param query_attributes: Attribute list to be queried. - :param entity_key: The field name to store the given main entity for - attribute extraction. It's "__dj__entity__" in default. - :param entity_attribute_key: The field name to store the given - attribute to be extracted. It's "__dj__attribute__" in default. 
- :param attribute_desc_key: The field name to store the extracted - attribute description. It's "__dj__attribute_description__" in + :param entity_key: The key name in the meta field to store the + given main entity for attribute extraction. It's "entity" in default. - :param support_text_key: The field name to store the attribute - support text extracted from the raw text. It's - "__dj__support_text__" in default. + :param attribute_key: The key name in the meta field to + store the given attribute to be extracted. It's "attribute" + in default. + :param attribute_desc_key: The key name in the meta field to store + the extracted attribute description. It's + "attribute_description" in default. + :param support_text_key: The key name in the meta field to store + the attribute support text extracted from the raw text. + It's "support_text" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'.
@@ -172,15 +175,22 @@ def _process_single_text(self, text='', rank=None): def process_single(self, sample, rank=None): + # check if it's generated already + if set([ + self.entity_key, self.attribute_key, self.attribute_desc_key, + self.support_text_key + ]) <= set(sample[Fields.meta].keys()): + return sample + res = self._process_single_text(sample[self.text_key], rank=rank) entities, attributes, descs, demo_lists = res if self.drop_text: sample.pop(self.text_key) - sample[self.entity_key] = entities - sample[self.attribute_key] = attributes - sample[self.attribute_desc_key] = descs - sample[self.support_text_key] = demo_lists + sample[Fields.meta][self.entity_key] = entities + sample[Fields.meta][self.attribute_key] = attributes + sample[Fields.meta][self.attribute_desc_key] = descs + sample[Fields.meta][self.support_text_key] = demo_lists return sample diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py index 6350101ac..edf897381 100644 --- a/data_juicer/ops/mapper/extract_entity_relation_mapper.py +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -9,9 +9,9 @@ from loguru import logger from pydantic import NonNegativeInt, PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper from data_juicer.utils.common_utils import is_float -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model from ..common import split_text_by_punctuation @@ -20,6 +20,7 @@ # TODO: LLM-based inference. 
+@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEntityRelationMapper(Mapper): """ @@ -149,8 +150,8 @@ def __init__(self, api_model: str = 'gpt-4o', entity_types: List[str] = None, *, - entity_key: str = Fields.entity, - relation_key: str = Fields.relation, + entity_key: str = MetaKeys.entity, + relation_key: str = MetaKeys.relation, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, prompt_template: Optional[str] = None, @@ -171,10 +172,10 @@ def __init__(self, Initialization method. :param api_model: API model name. :param entity_types: Pre-defined entity types for knowledge graph. - :param entity_key: The field name to store the entities. It's - "__dj__entity__" in default. + :param entity_key: The key name to store the entities in the meta + field. It's "entity" in default. :param relation_key: The field name to store the relations between - entities. It's "__dj__relation__" in default. + entities. It's "relation" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -256,9 +257,9 @@ def split_by_tuple_delimiter(record): entities.append(items) entities = list(set(entities)) entities = [{ - Fields.entity_name: e[0], - Fields.entity_type: e[1], - Fields.entity_description: e[2] + MetaKeys.entity_name: e[0], + MetaKeys.entity_type: e[1], + MetaKeys.entity_description: e[2] } for e in entities] relation_pattern = re.compile(self.relation_pattern, @@ -271,11 +272,16 @@ def split_by_tuple_delimiter(record): relations.append(items) relations = list(set(relations)) relations = [{ - Fields.source_entity: r[0], - Fields.target_entity: r[1], - Fields.relation_description: r[2], - Fields.relation_keywords: split_text_by_punctuation(r[3]), - Fields.relation_strength: float(r[4]) + MetaKeys.source_entity: + r[0], + MetaKeys.target_entity: + r[1], + MetaKeys.relation_description: + r[2], + MetaKeys.relation_keywords: + split_text_by_punctuation(r[3]), + MetaKeys.relation_strength: + float(r[4]) } for r in relations] return entities, relations @@ -309,6 +315,11 @@ def light_rag_extraction(self, messages, rank=None): def process_single(self, sample, rank=None): + # check if it's generated already + if self.entity_key in sample[ + Fields.meta] and self.relation_key in sample[Fields.meta]: + return sample + input_prompt = self.prompt_template.format( tuple_delimiter=self.tuple_delimiter, record_delimiter=self.record_delimiter, @@ -327,6 +338,6 @@ def process_single(self, sample, rank=None): except Exception as e: logger.warning(f'Exception: {e}') - sample[self.entity_key] = entities - sample[self.relation_key] = relations + sample[Fields.meta][self.entity_key] = entities + sample[Fields.meta][self.relation_key] = relations return sample diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py index fddf4fed1..948a3cf04 100644 --- a/data_juicer/ops/mapper/extract_event_mapper.py +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -1,12 +1,12 @@ +import copy import re -from itertools import 
chain from typing import Dict, Optional from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.constant import Fields +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model from ..common import split_text_by_punctuation @@ -15,6 +15,7 @@ # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractEventMapper(Mapper): """ @@ -52,8 +53,8 @@ class ExtractEventMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', *, - event_desc_key: str = Fields.event_description, - relevant_char_key: str = Fields.relevant_characters, + event_desc_key: str = MetaKeys.event_description, + relevant_char_key: str = MetaKeys.relevant_characters, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, @@ -67,11 +68,11 @@ def __init__(self, """ Initialization method. :param api_model: API model name. - :param event_desc_key: The field name to store the event descriptions. - It's "__dj__event_description__" in default. + :param event_desc_key: The key name to store the event descriptions + in the meta field. It's "event_description" in default. :param relevant_char_key: The field name to store the relevant - characters to the events. It's "__dj__relevant_characters__" in - default. + characters to the events in the meta field. It's + "relevant_characters" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -146,7 +147,10 @@ def _process_single_sample(self, text='', rank=None): def process_batched(self, samples, rank=None): - sample_num = len(samples[self.text_key]) + # check if it's generated already + if self.event_desc_key in samples[Fields.meta][ + 0] and self.relevant_char_key in samples[Fields.meta][0]: + return samples events, characters = [], [] for text in samples[self.text_key]: @@ -158,13 +162,24 @@ def process_batched(self, samples, rank=None): if self.drop_text: samples.pop(self.text_key) - for key in samples: - samples[key] = [[samples[key][i]] * len(events[i]) - for i in range(sample_num)] - samples[self.event_desc_key] = events - samples[self.relevant_char_key] = characters - - for key in samples: - samples[key] = list(chain(*samples[key])) - - return samples + new_samples = [] + for i in range(len(events)): + for event, character in zip(events[i], characters[i]): + cur_sample = { + key: copy.deepcopy(samples[key][i]) + for key in samples + } + cur_sample[Fields.meta][self.event_desc_key] = event + cur_sample[Fields.meta][self.relevant_char_key] = character + new_samples.append(cur_sample) + + if len(new_samples) == 0: + logger.warning('Extract Not event in the batch of samples!') + return samples + + res_samples = {} + keys = new_samples[0].keys() + for key in keys: + res_samples[key] = [s[key] for s in new_samples] + + return res_samples diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py index 24e3e127e..2b727f0ac 100644 --- a/data_juicer/ops/mapper/extract_keyword_mapper.py +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -6,8 +6,8 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.constant import Fields +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, 
prepare_model from ..common import split_text_by_punctuation @@ -16,6 +16,7 @@ # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractKeywordMapper(Mapper): """ @@ -102,7 +103,7 @@ class ExtractKeywordMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', *, - keyword_key: str = Fields.keyword, + keyword_key: str = MetaKeys.keyword, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, prompt_template: Optional[str] = None, @@ -116,8 +117,8 @@ def __init__(self, """ Initialization method. :param api_model: API model name. - :param keyword_key: The field name to store the keywords. It's - "__dj__keyword__" in default. + :param keyword_key: The key name to store the keywords in the meta + field. It's "keyword" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. @@ -164,6 +165,11 @@ def parse_output(self, raw_output): return keywords def process_single(self, sample, rank=None): + + # check if it's generated already + if self.keyword_key in sample[Fields.meta]: + return sample + client = get_model(self.model_key, rank=rank) input_prompt = self.prompt_template.format( @@ -181,7 +187,7 @@ def process_single(self, sample, rank=None): except Exception as e: logger.warning(f'Exception: {e}') - sample[self.keyword_key] = keywords + sample[Fields.meta][self.keyword_key] = keywords if self.drop_text: sample.pop(self.text_key) diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py index 20aeb94db..140f61011 100644 --- a/data_juicer/ops/mapper/extract_nickname_mapper.py +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -4,14 +4,15 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.constant import Fields +from 
data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'extract_nickname_mapper' # TODO: LLM-based inference. +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractNicknameMapper(Mapper): """ @@ -50,7 +51,7 @@ class ExtractNicknameMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', *, - nickname_key: str = Fields.nickname, + nickname_key: str = MetaKeys.nickname, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, @@ -64,8 +65,8 @@ def __init__(self, """ Initialization method. :param api_model: API model name. - :param nickname_key: The field name to store the nickname - relationship. It's "__dj__nickname__" in default. + :param nickname_key: The key name to store the nickname + relationship in the meta field. It's "nickname" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -121,16 +122,21 @@ def parse_output(self, raw_output): nickname_relations = list(set(nickname_relations)) nickname_relations = [{ - Fields.source_entity: nr[0], - Fields.target_entity: nr[1], - Fields.relation_description: nr[2], - Fields.relation_keywords: ['nickname'], - Fields.relation_strength: None + MetaKeys.source_entity: nr[0], + MetaKeys.target_entity: nr[1], + MetaKeys.relation_description: nr[2], + MetaKeys.relation_keywords: ['nickname'], + MetaKeys.relation_strength: None } for nr in nickname_relations] return nickname_relations def process_single(self, sample, rank=None): + + # check if it's generated already + if self.nickname_key in sample[Fields.meta]: + return sample + client = get_model(self.model_key, rank=rank) input_prompt = self.input_template.format(text=sample[self.text_key]) @@ -151,7 +157,7 @@ def process_single(self, sample, rank=None): except Exception as e: logger.warning(f'Exception: {e}') - sample[self.nickname_key] = nickname_relations + sample[Fields.meta][self.nickname_key] = nickname_relations if self.drop_text: sample.pop(self.text_key) diff --git a/data_juicer/ops/mapper/extract_support_text_mapper.py b/data_juicer/ops/mapper/extract_support_text_mapper.py index 34bdbe653..1b31123a1 100644 --- a/data_juicer/ops/mapper/extract_support_text_mapper.py +++ b/data_juicer/ops/mapper/extract_support_text_mapper.py @@ -3,15 +3,15 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_access, nested_set -from data_juicer.utils.constant import Fields +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'extract_support_text_mapper' # TODO: LLM-based inference. 
+@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class ExtractSupportTextMapper(Mapper): """ @@ -46,8 +46,8 @@ class ExtractSupportTextMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', *, - summary_key: str = Fields.event_description, - support_text_key: str = Fields.support_text, + summary_key: str = MetaKeys.event_description, + support_text_key: str = MetaKeys.support_text, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, @@ -60,12 +60,11 @@ def __init__(self, """ Initialization method. :param api_model: API model name. - :param summary_key: The field name to store the input summary. - Support for nested keys such as "__dj__stats__.text_len". - It's "__dj__event_description__" in default. - :param support_text_key: The field name to store the output - support text for the summary. It's "__dj__support_text__" in - default. + :param summary_key: The key name to store the input summary in the + meta field. It's "event_description" in default. + :param support_text_key: The key name to store the output + support text for the summary in the meta field. It's + "support_text" in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -98,9 +97,18 @@ def __init__(self, self.drop_text = drop_text def process_single(self, sample, rank=None): + + # check if it's generated already + if self.support_text_key in sample[Fields.meta]: + return sample + client = get_model(self.model_key, rank=rank) - summary = nested_access(sample, self.summary_key) + if self.summary_key not in sample[Fields.meta]: + logger.warning( + f'{self.summary_key} does not exist in the meta field!') + return sample + summary = sample[Fields.meta][self.summary_key] if not isinstance(summary, str): logger.warning('Unvalid input summary!') return sample @@ -128,5 +136,5 @@ def process_single(self, sample, rank=None): if not support_text: support_text = summary - sample = nested_set(sample, self.support_text_key, support_text) + sample[Fields.meta][self.support_text_key] = support_text return sample diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index dc2099b78..7b1f43125 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -2,7 +2,7 @@ import numpy as np -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.mm_utils import load_data_with_context, load_image from data_juicer.utils.model_utils import get_model, prepare_model @@ -27,13 +27,13 @@ class ImageTaggingMapper(Mapper): _accelerator = 'cuda' def __init__(self, - tag_field_name: str = Fields.image_tags, + tag_field_name: str = MetaKeys.image_tags, *args, **kwargs): """ Initialization method. :param tag_field_name: the field name to store the tags. It's - "__dj__image_tags__" in default. + "image_tags" in default. 
:param args: extra args :param kwargs: extra args """ diff --git a/data_juicer/ops/mapper/query_intent_detection_mapper.py b/data_juicer/ops/mapper/query_intent_detection_mapper.py index b0d240e2d..57bf5e3c9 100644 --- a/data_juicer/ops/mapper/query_intent_detection_mapper.py +++ b/data_juicer/ops/mapper/query_intent_detection_mapper.py @@ -1,21 +1,19 @@ from typing import Dict, Optional -from data_juicer.utils.common_utils import nested_set from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper OP_NAME = 'query_intent_detection_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class QueryIntentDetectionMapper(Mapper): """ Mapper to predict user's Intent label in query. Input from query_key. - Output intent label and corresponding score for the query, which is - store in 'query_intent_label' and 'query_intent_label_score' in - Data-Juicer meta field. + Output intent label and corresponding score for the query. """ _accelerator = 'cuda' @@ -28,6 +26,9 @@ def __init__( zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, + *, + label_key: str = MetaKeys.query_intent_label, + score_key: str = MetaKeys.query_intent_score, **kwargs): """ Initialization method. @@ -37,10 +38,18 @@ def __init__( If not None, translate the query from Chinese to English. :param model_params: model param for hf_model. :param zh_to_en_model_params: model param for zh_to_hf_model. + :param label_key: The key name in the meta field to store the + output label. It is 'query_intent_label' in default. + :param score_key: The key name in the meta field to store the + corresponding label score. It is 'query_intent_label_score' + in default. :param kwargs: Extra keyword arguments. 
""" super().__init__(**kwargs) + self.label_key = label_key + self.score_key = score_key + self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_model, return_pipe=True, @@ -58,6 +67,11 @@ def __init__( self.zh_to_en_model_key = None def process_batched(self, samples, rank=None): + + metas = samples[Fields.meta] + if self.label_key in metas[0] and self.score_key in metas[0]: + return samples + queries = samples[self.query_key] if self.zh_to_en_model_key is not None: @@ -71,14 +85,8 @@ def process_batched(self, samples, rank=None): labels = [r['label'] for r in results] scores = [r['score'] for r in results] - if Fields.meta not in samples: - samples[Fields.meta] = [{} for val in labels] - for i in range(len(samples[Fields.meta])): - samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], - MetaKeys.query_intent_label, - labels[i]) - samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], - MetaKeys.query_intent_score, - scores[i]) + for i in range(len(metas)): + metas[i][self.label_key] = labels[i] + metas[i][self.score_key] = scores[i] return samples diff --git a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py index 634bdeab3..20bc77fff 100644 --- a/data_juicer/ops/mapper/query_sentiment_detection_mapper.py +++ b/data_juicer/ops/mapper/query_sentiment_detection_mapper.py @@ -1,14 +1,14 @@ from typing import Dict, Optional -from data_juicer.utils.common_utils import nested_set from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper OP_NAME = 'query_sentiment_detection_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class QuerySentimentDetectionMapper(Mapper): """ @@ -29,6 +29,9 @@ def __init__( zh_to_en_hf_model: Optional[str] = 
'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, + *, + label_key: str = MetaKeys.query_sentiment_label, + score_key: str = MetaKeys.query_sentiment_score, **kwargs): """ Initialization method. @@ -38,10 +41,18 @@ def __init__( If not None, translate the query from Chinese to English. :param model_params: model param for hf_model. :param zh_to_en_model_params: model param for zh_to_hf_model. + :param label_key: The key name in the meta field to store the + output label. It is 'query_sentiment_label' in default. + :param score_key: The key name in the meta field to store the + corresponding label score. It is 'query_sentiment_label_score' + in default. :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) + self.label_key = label_key + self.score_key = score_key + self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_model, return_pipe=True, @@ -59,6 +70,11 @@ def __init__( self.zh_to_en_model_key = None def process_batched(self, samples, rank=None): + + metas = samples[Fields.meta] + if self.label_key in metas[0] and self.score_key in metas[0]: + return samples + queries = samples[self.query_key] if self.zh_to_en_model_key is not None: @@ -72,14 +88,8 @@ def process_batched(self, samples, rank=None): labels = [r['label'] for r in results] scores = [r['score'] for r in results] - if Fields.meta not in samples: - samples[Fields.meta] = [{} for val in labels] - for i in range(len(samples[Fields.meta])): - samples[Fields.meta][i] = nested_set( - samples[Fields.meta][i], MetaKeys.query_sentiment_label, - labels[i]) - samples[Fields.meta][i] = nested_set( - samples[Fields.meta][i], MetaKeys.query_sentiment_score, - scores[i]) + for i in range(len(metas)): + metas[i][self.label_key] = labels[i] + metas[i][self.score_key] = scores[i] return samples diff --git a/data_juicer/ops/mapper/query_topic_detection_mapper.py b/data_juicer/ops/mapper/query_topic_detection_mapper.py index 
8e5687ee3..e8aedf431 100644 --- a/data_juicer/ops/mapper/query_topic_detection_mapper.py +++ b/data_juicer/ops/mapper/query_topic_detection_mapper.py @@ -1,14 +1,14 @@ from typing import Dict, Optional -from data_juicer.utils.common_utils import nested_set from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper OP_NAME = 'query_topic_detection_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class QueryTopicDetectionMapper(Mapper): """ @@ -28,6 +28,9 @@ def __init__( zh_to_en_hf_model: Optional[str] = 'Helsinki-NLP/opus-mt-zh-en', model_params: Dict = {}, zh_to_en_model_params: Dict = {}, + *, + label_key: str = MetaKeys.query_topic_label, + score_key: str = MetaKeys.query_topic_score, **kwargs): """ Initialization method. @@ -37,10 +40,18 @@ def __init__( If not None, translate the query from Chinese to English. :param model_params: model param for hf_model. :param zh_to_en_model_params: model param for zh_to_hf_model. + :param label_key: The key name in the meta field to store the + output label. It is 'query_topic_label' in default. + :param score_key: The key name in the meta field to store the + corresponding label score. It is 'query_topic_label_score' + in default. :param kwargs: Extra keyword arguments. 
""" super().__init__(**kwargs) + self.label_key = label_key + self.score_key = score_key + self.model_key = prepare_model(model_type='huggingface', pretrained_model_name_or_path=hf_model, return_pipe=True, @@ -58,6 +69,11 @@ def __init__( self.zh_to_en_model_key = None def process_batched(self, samples, rank=None): + + metas = samples[Fields.meta] + if self.label_key in metas[0] and self.score_key in metas[0]: + return samples + queries = samples[self.query_key] if self.zh_to_en_model_key is not None: @@ -71,14 +87,8 @@ def process_batched(self, samples, rank=None): labels = [r['label'] for r in results] scores = [r['score'] for r in results] - if Fields.meta not in samples: - samples[Fields.meta] = [{} for val in labels] - for i in range(len(samples[Fields.meta])): - samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], - MetaKeys.query_topic_label, - labels[i]) - samples[Fields.meta][i] = nested_set(samples[Fields.meta][i], - MetaKeys.query_topic_score, - scores[i]) + for i in range(len(metas)): + metas[i][self.label_key] = labels[i] + metas[i][self.score_key] = scores[i] return samples diff --git a/data_juicer/ops/mapper/relation_identity_mapper.py b/data_juicer/ops/mapper/relation_identity_mapper.py index 29994d744..2370d7b8b 100644 --- a/data_juicer/ops/mapper/relation_identity_mapper.py +++ b/data_juicer/ops/mapper/relation_identity_mapper.py @@ -4,14 +4,15 @@ from loguru import logger from pydantic import PositiveInt -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.utils.common_utils import nested_access, nested_set +from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.model_utils import get_model, prepare_model OP_NAME = 'relation_identity_mapper' # TODO: LLM-based inference. 
+@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class RelationIdentityMapper(Mapper): """ @@ -40,9 +41,8 @@ def __init__(self, api_model: str = 'gpt-4o', source_entity: str = None, target_entity: str = None, - input_key: str = None, - output_key: str = None, *, + output_key: str = MetaKeys.role_relation, api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt_template: Optional[str] = None, @@ -60,12 +60,8 @@ def __init__(self, identified. :param target_entity: The target entity of the relation to be identified. - :param input_key: The input field key in the samples. Support for - nested keys such as "__dj__stats__.text_len". It is text_key - in default. - :param output_key: The output field key in the samples. Support - for nested keys such as "__dj__stats__.text_len". It is - input_key in default. + :param output_key: The output key in the meta field in the + samples. It is 'role_relation' in default. :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. 
@@ -89,8 +85,7 @@ def __init__(self, self.source_entity = source_entity self.target_entity = target_entity - self.input_key = input_key or self.text_key - self.output_key = output_key or self.input_key + self.output_key = output_key system_prompt_template = system_prompt_template or \ self.DEFAULT_SYSTEM_PROMPT_TEMPLATE @@ -125,9 +120,14 @@ def parse_output(self, raw_output): return relation def process_single(self, sample, rank=None): + + meta = sample[Fields.meta] + if self.output_key in meta: + return sample + client = get_model(self.model_key, rank=rank) - text = nested_access(sample, self.input_key) + text = sample[self.text_key] input_prompt = self.input_template.format(entity1=self.source_entity, entity2=self.target_entity, text=text) @@ -148,7 +148,8 @@ def process_single(self, sample, rank=None): except Exception as e: logger.warning(f'Exception: {e}') - sample = nested_set(sample, self.output_key, relation) + meta[self.output_key] = relation + if self.drop_text: sample.pop(self.text_key) diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py index 67eb7e234..c054f7695 100644 --- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py @@ -3,7 +3,7 @@ from pydantic import PositiveInt -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.lazy_loader import AUTOINSTALL from data_juicer.utils.mm_utils import SpecialTokens, remove_special_tokens from data_juicer.utils.model_utils import get_model, prepare_model @@ -197,17 +197,19 @@ def _process_single_sample(self, sample, rank=None): temp_sample = { self.text_key: chunk, self.video_key: loaded_video_keys[offset:offset + vid_count], + Fields.meta: {}, } captioned_text_list = [] # tag ops for op in self.tag_op_list: temp_sample = op.process(temp_sample, rank=rank) - 
if Fields.video_audio_tags in temp_sample: + if MetaKeys.video_audio_tags in temp_sample[Fields.meta]: captioned_text_list.extend( - temp_sample[Fields.video_audio_tags]) - if Fields.video_frame_tags in temp_sample: - for tag_list in temp_sample[Fields.video_frame_tags]: + temp_sample[Fields.meta][MetaKeys.video_audio_tags]) + if MetaKeys.video_frame_tags in temp_sample[Fields.meta]: + for tag_list in temp_sample[Fields.meta][ + MetaKeys.video_frame_tags]: captioned_text_list.extend(tag_list[self.keep_tag_num]) # cap ops for op in self.cap_op_list: diff --git a/data_juicer/ops/mapper/video_extract_frames_mapper.py b/data_juicer/ops/mapper/video_extract_frames_mapper.py index 4eb522abe..384ab3ee5 100644 --- a/data_juicer/ops/mapper/video_extract_frames_mapper.py +++ b/data_juicer/ops/mapper/video_extract_frames_mapper.py @@ -4,7 +4,7 @@ from pydantic import PositiveInt -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.file_utils import dict_to_hash from data_juicer.utils.mm_utils import ( SpecialTokens, close_video, extract_key_frames, @@ -12,12 +12,13 @@ extract_video_frames_uniformly_by_seconds, load_data_with_context, load_video) -from ..base_op import OPERATORS, Mapper +from ..base_op import OPERATORS, TAGGING_OPS, Mapper from ..op_fusion import LOADED_VIDEOS OP_NAME = 'video_extract_frames_mapper' +@TAGGING_OPS.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) @LOADED_VIDEOS.register_module(OP_NAME) class VideoExtractFramesMapper(Mapper): @@ -41,7 +42,7 @@ def __init__( frame_num: PositiveInt = 3, duration: float = 0, frame_dir: str = None, - frame_key=Fields.video_frames, + frame_key=MetaKeys.video_frames, *args, **kwargs, ): @@ -103,7 +104,7 @@ def _get_default_frame_dir(self, original_filepath): def process_single(self, sample, context=False): # check if it's generated already - if self.frame_key in sample: + if self.frame_key in sample[Fields.meta]: return sample # there 
is no videos in this sample @@ -168,6 +169,6 @@ def process_single(self, sample, context=False): for vid_key in videos: close_video(videos[vid_key]) - sample[self.frame_key] = json.dumps(video_to_frame_dir) + sample[Fields.meta][self.frame_key] = json.dumps(video_to_frame_dir) return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 7302953f2..8348e381e 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -1,7 +1,7 @@ import librosa import numpy as np -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.lazy_loader import AUTOINSTALL, LazyLoader from data_juicer.utils.mm_utils import extract_audio_from_video from data_juicer.utils.model_utils import get_model, prepare_model @@ -25,7 +25,7 @@ class VideoTaggingFromAudioMapper(Mapper): def __init__(self, hf_ast: str = 'MIT/ast-finetuned-audioset-10-10-0.4593', trust_remote_code: bool = False, - tag_field_name: str = Fields.video_audio_tags, + tag_field_name: str = MetaKeys.video_audio_tags, *args, **kwargs): """ @@ -34,7 +34,7 @@ def __init__(self, :param hf_ast: path to the HF model to tag from audios. :param trust_remote_code: whether to trust the remote code of HF models :param tag_field_name: the field name to store the tags. It's - "__dj__video_audio_tags__" in default. + "video_audio_tags" in default. 
:param args: extra args :param kwargs: extra args """ diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 31927e1b2..1d0aca2f7 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -3,7 +3,7 @@ import numpy as np from pydantic import PositiveInt -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.mm_utils import (close_video, extract_key_frames, extract_video_frames_uniformly, @@ -32,7 +32,7 @@ class VideoTaggingFromFramesMapper(Mapper): def __init__(self, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, - tag_field_name: str = Fields.video_frame_tags, + tag_field_name: str = MetaKeys.video_frame_tags, *args, **kwargs): """ @@ -52,7 +52,7 @@ def __init__(self, than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. :param tag_field_name: the field name to store the tags. It's - "__dj__video_frame_tags__" in default. + "video_frame_tags" in default. :param args: extra args :param kwargs: extra args """ diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 8a13ae361..dda159965 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -63,30 +63,6 @@ def nested_access(data, path, digit_allowed=True): return data -def nested_set(data: dict, path: str, val): - """ - Set the val to the nested data in the dot-separated path. - - :param data: A dictionary with nested format. - :param path: A dot-separated string representing the path to set. - :return: The nested data after the val set. 
- """ - keys = path.split('.') - cur = data - try: - for key in keys[:-1]: - if key not in cur: - cur[key] = {} - cur = cur[key] - if keys[-1] in cur: - logger.warning(f'Overwrite value in {path}!') - cur[keys[-1]] = val - except Exception: - logger.warning(f'Unvalid dot-separated path: {path}!') - return data - return data - - def is_string_list(var): """ return if the var is list of string. diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 32fce693b..11de97427 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -11,87 +11,100 @@ class Fields(object): + # for storing stats generated by filter op stats = DEFAULT_PREFIX + 'stats__' + # for storing metas generated by mapper op meta = DEFAULT_PREFIX + 'meta__' + # for storing metas of batch samples generated by aggregator op + batch_meta = DEFAULT_PREFIX + 'batch_meta__' context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' - # tags in meta - # video_frame_tags - video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' - # video_audio_tags - video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' - # image_tags - image_tags = DEFAULT_PREFIX + 'image_tags__' - - # video_frames - video_frames = DEFAULT_PREFIX + 'video_frames__' - # the name of the original file from which this sample was derived. 
source_file = DEFAULT_PREFIX + 'source_file__' # the name of directory to store the produced multimodal data multimodal_data_output_dir = DEFAULT_PREFIX + 'produced_data__' - # field names for info extraction - event_description = DEFAULT_PREFIX + 'event_description__' - # # a list of characters relevant to the event - relevant_characters = DEFAULT_PREFIX + 'relevant_characters__' - # # the given main entities for attribute extraction - main_entities = DEFAULT_PREFIX + 'main_entities__' - # # the given attributes to be extracted - attributes = DEFAULT_PREFIX + 'attributes__' - # # the extracted attribute descriptions - attribute_descriptions = DEFAULT_PREFIX + 'attribute_descriptions__' - # # extract from raw datas for support the attribute - attribute_support_texts = DEFAULT_PREFIX + 'attribute_support_texts__' - # # the nickname relationship - nickname = DEFAULT_PREFIX + 'nickname__' - # # the entity for knowledge graph - entity = DEFAULT_PREFIX + 'entity__' - # # # the name of entity - entity_name = DEFAULT_PREFIX + 'entity_name__' - # # # the type of entity - entity_type = DEFAULT_PREFIX + 'entity_type__' - # # # the description of entity - entity_description = DEFAULT_PREFIX + 'entity_entity_description__' - # # the relationship for knowledge graph - relation = DEFAULT_PREFIX + 'relation__' - # # # the source entity of the relation - source_entity = DEFAULT_PREFIX + 'relation_source_entity__' - # # # the target entity of the relation - target_entity = DEFAULT_PREFIX + 'relation_target_entity__' - # # # the description of the relation - relation_description = DEFAULT_PREFIX + 'relation_description__' - # # # the keywords of the relation - relation_keywords = DEFAULT_PREFIX + 'relation_keywords__' - # # # the strength of the relation - relation_strength = DEFAULT_PREFIX + 'relation_strength__' - # # the keyword in a text - keyword = DEFAULT_PREFIX + 'keyword__' - # # support text - support_text = DEFAULT_PREFIX + 'support_text__' + +class BatchMetaKeys(object): 
+ entity_attribute = 'entity_attribute' + most_relavant_entities = 'most_relavant_entities' class MetaKeys(object): + # === text related tags === + # # sentiment dialog_sentiment_intensity = 'dialog_sentiment_intensity' dialog_sentiment_intensity_analysis = 'dialog_sentiment_intensity_analysis' query_sentiment_label = 'query_sentiment_label' query_sentiment_score = 'query_sentiment_label_score' dialog_sentiment_labels = 'dialog_sentiment_labels' dialog_sentiment_labels_analysis = 'dialog_sentiment_labels_analysis' - + # # intent dialog_intent_labels = 'dialog_intent_labels' dialog_intent_labels_analysis = 'dialog_intent_labels_analysis' query_intent_label = 'query_intent_label' query_intent_score = 'query_intent_label_score' - + # # topic dialog_topic_labels = 'dialog_topic_labels' dialog_topic_labels_analysis = 'dialog_topic_labels_analysis' query_topic_label = 'query_topic_label' query_topic_score = 'query_topic_label_score' + # === multi-modal related tags === + # # video-frame tags + video_frame_tags = 'video_frame_tags' + # # video-audio tags + video_audio_tags = 'video_audio_tags' + # # video frames + video_frames = 'video_frames' + # # image tags + image_tags = 'image_tags' + + # === info extraction related tags === + # # for event extraction + event_description = 'event_description' + # # a list of characters relevant to the event + relevant_characters = 'relevant_characters' + # # the given main entities for attribute extraction + main_entities = 'main_entities' + # # the given attributes to be extracted + attributes = 'attributes' + # # the extracted attribute descriptions + attribute_descriptions = 'attribute_descriptions' + # # extract from raw datas for support the attribute + attribute_support_texts = 'attribute_support_texts' + # # the nickname relationship + nickname = 'nickname' + # # the entity for knowledge graph + entity = 'entity' + # # # the name of entity + entity_name = 'entity_name' + # # # the type of entity + entity_type = 'entity_type' + 
# # # the description of entity + entity_description = 'entity_entity_description' + # # the relationship for knowledge graph + relation = 'relation' + # # # the source entity of the relation + source_entity = 'relation_source_entity' + # # # the target entity of the relation + target_entity = 'relation_target_entity' + # # # the description of the relation + relation_description = 'relation_description' + # # # the keywords of the relation + relation_keywords = 'relation_keywords' + # # # the strength of the relation + relation_strength = 'relation_strength' + # # the keyword in a text + keyword = 'keyword' + # # support text + support_text = 'support_text' + # # role relation + role_relation = 'role_relation' + class StatsKeysMeta(type): """ @@ -179,7 +192,7 @@ def get_access_log(cls, dj_cfg=None): class StatsKeysConstant(object): - # text + # === text === alpha_token_ratio = 'alpha_token_ratio' alnum_ratio = 'alnum_ratio' avg_line_length = 'avg_line_length' @@ -198,7 +211,7 @@ class StatsKeysConstant(object): num_words = 'num_words' word_rep_ratio = 'word_rep_ratio' - # image + # === image === aspect_ratios = 'aspect_ratios' image_width = 'image_width' image_height = 'image_height' @@ -211,12 +224,12 @@ class StatsKeysConstant(object): image_watermark_prob = 'image_watermark_prob' image_pair_similarity = 'image_pair_similarity' - # audios + # === audios === audio_duration = 'audio_duration' audio_nmf_snr = 'audio_nmf_snr' audio_sizes = 'audio_sizes' - # videos + # === videos === video_duration = 'video_duration' video_aspect_ratios = 'video_aspect_ratios' video_width = 'video_width' @@ -228,7 +241,7 @@ class StatsKeysConstant(object): video_nsfw_score = 'video_nsfw_score' video_watermark_prob = 'video_watermark_prob' - # multimodal + # === multimodal === # image-text image_text_similarity = 'image_text_similarity' image_text_matching_score = 'image_text_matching_score' @@ -259,24 +272,24 @@ class HashKeys(object): class InterVars(object): - # text + # === text 
=== lines = DEFAULT_PREFIX + 'lines' words = DEFAULT_PREFIX + 'words' refined_words = DEFAULT_PREFIX + 'refined_words' - # image + # === image === loaded_images = DEFAULT_PREFIX + 'loaded_images' # Image - # audios + # === audios === loaded_audios = DEFAULT_PREFIX + 'loaded_audios' # (data, sampling_rate) - # videos - # InputContainer from av. - # Key: {video_path} + # === videos === + # # InputContainer from av. + # # Key: {video_path} loaded_videos = DEFAULT_PREFIX + 'loaded_videos' # sampled frames. - # Key: {video_path}-{frame_sampling_method}[-{frame_num}] - # {frame_num} is only used when {frame_sampling_method} is "uniform" + # # Key: {video_path}-{frame_sampling_method}[-{frame_num}] + # # {frame_num} is only used when {frame_sampling_method} is "uniform" sampled_frames = DEFAULT_PREFIX + 'sampled_frames' diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index 7a8618660..8f30da41a 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -133,6 +133,7 @@ def create_directory_if_not_exists(directory_path): :param directory_path: directory path to be create """ + directory_path = os.path.abspath(directory_path) try: os.makedirs(directory_path, exist_ok=True) except FileExistsError: diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 94b4440eb..6e32434fa 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -613,6 +613,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, ) -> Union[tuple, transformers.modeling_outputs.BaseModelOutputWithPooling]: """Flatten `pixel_values` along the batch and time dimension, @@ -654,6 +655,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=True, + interpolate_pos_encoding=interpolate_pos_encoding, ) # now restore the 
original dimensions diff --git a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml index eadac45da..da044ae75 100644 --- a/demos/role_playing_system_prompt/role_playing_system_prompt.yaml +++ b/demos/role_playing_system_prompt/role_playing_system_prompt.yaml @@ -33,23 +33,23 @@ process: api_model: 'qwen2.5-72b-instruct' entity: '李莲花' attribute: '身份背景' - input_key: '__dj__event_description__' - output_key: '__dj__role_background__' + input_key: 'event_description' + output_key: 'role_background' word_limit: 50 - entity_attribute_aggregator: api_model: 'qwen2.5-72b-instruct' entity: '李莲花' attribute: '主要经历' - input_key: '__dj__event_description__' - output_key: '__dj__role_experience__' + input_key: 'event_description' + output_key: 'role_experience' word_limit: 150 # most relavant roles summary from events - most_relavant_entities_aggregator: api_model: 'qwen2.5-72b-instruct' entity: '李莲花' query_entity_type: '人物' - input_key: '__dj__event_description__' - output_key: '__dj__important_relavant_roles__' + input_key: 'event_description' + output_key: 'important_relavant_roles' # generate the system prompt - python_file_mapper: file_path: 'path_to_system_prompt_gereration_python_file' diff --git a/demos/role_playing_system_prompt/system_prompt_generator.py b/demos/role_playing_system_prompt/system_prompt_generator.py index dc2738900..afbeb9bd4 100644 --- a/demos/role_playing_system_prompt/system_prompt_generator.py +++ b/demos/role_playing_system_prompt/system_prompt_generator.py @@ -7,13 +7,15 @@ from data_juicer.ops.aggregator import NestedAggregator from data_juicer.ops.aggregator import EntityAttributeAggregator from data_juicer.ops.mapper import RelationIdentityMapper -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys +from data_juicer.core.data import NestedDataset as Dataset + api_model = 'qwen2.5-72b-instruct' 
main_entity = "李莲花" query_attributes = ["语言风格", "角色性格", "角色武艺和能力"] -system_prompt_key = '__dj__system_prompt__' +system_prompt_key = 'system_prompt' example_num_limit = 5 max_relavant_roles_num = 5 @@ -30,9 +32,9 @@ api_model=api_model, try_num=3) -def dedup_sort_val_by_chunk_id(sample, id_key, val_key): +def dedup_sort_val_by_chunk_id(sample, id_key, meta_key): chunk_ids = sample[id_key] - vals = sample[val_key] + vals = [d[meta_key] for d in sample[Fields.meta]] id_to_val = {} for id, val in zip(chunk_ids, vals): id_to_val[id] = val @@ -42,10 +44,10 @@ def dedup_sort_val_by_chunk_id(sample, id_key, val_key): return list(chain(*sorted_vals)) def get_attributes(sample): - main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.main_entities) - attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attributes) - attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_descriptions) - attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.attribute_support_texts) + main_entities = dedup_sort_val_by_chunk_id(sample, 'chunk_id', MetaKeys.main_entities) + attribute_names = dedup_sort_val_by_chunk_id(sample, 'chunk_id', MetaKeys.attributes) + attribute_descs = dedup_sort_val_by_chunk_id(sample, 'chunk_id', MetaKeys.attribute_descriptions) + attribute_support_texts = dedup_sort_val_by_chunk_id(sample, 'chunk_id', MetaKeys.attribute_support_texts) attributes = {} support_texts = {} for attr in query_attributes: @@ -59,7 +61,7 @@ def get_attributes(sample): return attributes, support_texts def get_nicknames(sample): - nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', Fields.nickname) + nicknames = dedup_sort_val_by_chunk_id(sample, 'chunk_id', MetaKeys.nickname) nickname_map = {} for nr in nicknames: if nr[Fields.source_entity] == main_entity: @@ -85,8 +87,8 @@ def get_nicknames(sample): def get_system_prompt(sample): - main_role_identity = sample['__dj__role_background__'] - 
main_role_experience = sample['__dj__role_experience__'] + main_role_identity = sample[Fields.batch_meta]['role_background'] + main_role_experience = sample[Fields.batch_meta]['role_experience'] attributes, support_texts = get_attributes(sample) main_role_character = nested_sum.recursive_summary(attributes['角色性格']) main_role_skill = nested_sum.recursive_summary(attributes['角色武艺和能力']) @@ -104,34 +106,37 @@ def get_system_prompt(sample): nicknames = get_nicknames(sample) relation_detail = "" - relavant_roles = sample['__dj__important_relavant_roles__'] + relavant_roles = sample[Fields.batch_meta]['important_relavant_roles'] for role_name in relavant_roles[:max_relavant_roles_num]: if role_name == main_entity: continue - + + cur_sample = {k: sample[k] for k in sample if k != Fields.batch_meta} + + dataset = Dataset.from_list([cur_sample]) # get sub role identity op = EntityAttributeAggregator( api_model=api_model, entity=role_name, attribute='身份背景', - input_key='__dj__event_description__', - output_key='__dj__role_background__', + input_key='event_description', + output_key='role_background', word_limit=30 ) - sample = op.process_single(sample) - role_identity = sample['__dj__role_background__'].replace('\n', '') + dataset = op.run(dataset) + role_identity = dataset[0][Fields.batch_meta]['role_background'].replace('\n', '') # get sub role experience op = EntityAttributeAggregator( api_model=api_model, entity=role_name, attribute='主要经历', - input_key='__dj__event_description__', - output_key='__dj__role_experience__', + input_key='event_description', + output_key='role_experience', word_limit=100 ) - sample = op.process_single(sample) - role_experience = sample['__dj__role_experience__'].replace('\n', '') + dataset = op.run(dataset) + role_experience = dataset[0][Fields.batch_meta]['role_experience'].replace('\n', '') # get relation identity with main role role_info = role_info_template.format( @@ -143,7 +148,7 @@ def get_system_prompt(sample): api_model=api_model, 
source_entity=main_entity, target_entity=role_name, - output_key='__dj__relation_identity__' + output_key='relation_identity' ) if role_name in nicknames: cur_nicknames = '、'.join(nicknames[role_name]) @@ -157,8 +162,9 @@ def get_system_prompt(sample): nicknames = cur_nicknames ) tmp_sample = {'text': text} - tmp_sample = op.process_single(tmp_sample) - relation = tmp_sample['__dj__relation_identity__'] + dataset = Dataset.from_list([tmp_sample]) + dataset = op.run(dataset) + relation = dataset[0][Fields.meta]['relation_identity'] relation_detail += f"\n{role_name} (称呼:{cur_nicknames})" if relation: diff --git a/tests/ops/aggregator/test_entity_attribute_aggregator.py b/tests/ops/aggregator/test_entity_attribute_aggregator.py index 1f80da3a3..ff390b6fd 100644 --- a/tests/ops/aggregator/test_entity_attribute_aggregator.py +++ b/tests/ops/aggregator/test_entity_attribute_aggregator.py @@ -5,6 +5,7 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.aggregator import EntityAttributeAggregator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.constant import Fields, MetaKeys @SKIPPED_TESTS.register_module() @@ -28,12 +29,12 @@ def _run_helper(self, op, samples): def test_default_aggregator(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + 
{MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -47,12 +48,12 @@ def test_default_aggregator(self): def test_input_output(self): samples = [ { - 'sub_docs': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {'sub_docs': "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {'sub_docs': "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {'sub_docs': '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {'sub_docs': '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {'sub_docs': '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -68,12 +69,12 @@ def test_input_output(self): def test_max_token_num(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: 
'小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -88,12 +89,12 @@ def test_max_token_num(self): def test_word_limit_num(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -109,12 +110,12 @@ def test_word_limit_num(self): def test_example_prompt(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] diff --git 
a/tests/ops/aggregator/test_most_relavant_entities_aggregator.py b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py index 1d8678134..062cad43d 100644 --- a/tests/ops/aggregator/test_most_relavant_entities_aggregator.py +++ b/tests/ops/aggregator/test_most_relavant_entities_aggregator.py @@ -6,6 +6,8 @@ from data_juicer.ops.aggregator import MostRelavantEntitiesAggregator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.constant import Fields, MetaKeys + @SKIPPED_TESTS.register_module() class MostRelavantEntitiesAggregatorTest(DataJuicerTestCaseBase): @@ -28,12 +30,12 @@ def _run_helper(self, op, samples): def test_default_aggregator(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -48,15 +50,13 @@ def test_default_aggregator(self): def test_input_output(self): samples = [ { - 'dj_result':{ - 'events': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - 
'十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' - ] - } + Fields.meta: [ + {'events': "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {'events': "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {'events': '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {'events': '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {'events': '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} + ] }, ] @@ -64,20 +64,20 @@ def test_input_output(self): api_model='qwen2.5-72b-instruct', entity='李莲花', query_entity_type='人物', - input_key='dj_result.events', - output_key='dj_result.relavant_roles' + input_key='events', + output_key='relavant_roles' ) self._run_helper(op, samples) def test_max_token_num(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] diff --git a/tests/ops/aggregator/test_nested_aggregator.py b/tests/ops/aggregator/test_nested_aggregator.py index 6347652bc..0d16648df 100644 --- a/tests/ops/aggregator/test_nested_aggregator.py +++ b/tests/ops/aggregator/test_nested_aggregator.py @@ -6,6 +6,8 
@@ from data_juicer.ops.aggregator import NestedAggregator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.constant import Fields, MetaKeys + @SKIPPED_TESTS.register_module() class NestedAggregatorTest(DataJuicerTestCaseBase): @@ -28,12 +30,12 @@ def _run_helper(self, op, samples): def test_default_aggregator(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -45,12 +47,12 @@ def test_default_aggregator(self): def test_input_output(self): samples = [ { - 'sub_docs': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {'sub_docs': "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {'sub_docs': "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {'sub_docs': '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {'sub_docs': 
'十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {'sub_docs': '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -64,12 +66,12 @@ def test_input_output(self): def test_max_token_num_1(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -82,12 +84,12 @@ def test_max_token_num_1(self): def test_max_token_num_2(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: 
'小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] @@ -100,12 +102,12 @@ def test_max_token_num_2(self): def test_max_token_num_3(self): samples = [ { - 'text': [ - "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。", - "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。", - '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。', - '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。', - '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。' + Fields.meta: [ + {MetaKeys.event_description: "十年前,李相夷十五岁战胜西域天魔成为天下第一高手,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。"}, + {MetaKeys.event_description: "有人视李相夷为中原武林的希望,但也有人以战胜他为目标,包括魔教金鸳盟盟主笛飞声。笛飞声设计加害李相夷的师兄单孤刀,引得李相夷与之一战。"}, + {MetaKeys.event_description: '在东海的一艘船上,李相夷独自一人对抗金鸳盟的高手,最终击败了大部分敌人。笛飞声突然出现,两人激战,李相夷在战斗中中毒,最终被笛飞声重伤,船只爆炸,李相夷沉入大海。'}, + {MetaKeys.event_description: '十年后,李莲花在一个寒酸的莲花楼内醒来,表现出与李相夷截然不同的性格。他以神医的身份在小镇上行医,但生活贫困。'}, + {MetaKeys.event_description: '小镇上的皮影戏摊讲述李相夷和笛飞声的故事,孩子们争论谁赢了。风火堂管事带着人来找李莲花,要求他救治一个“死人”。'} ] }, ] diff --git a/tests/ops/grouper/test_naive_reverse_grouper.py b/tests/ops/grouper/test_naive_reverse_grouper.py index 29c06451d..34a77375b 100644 --- a/tests/ops/grouper/test_naive_reverse_grouper.py +++ b/tests/ops/grouper/test_naive_reverse_grouper.py @@ -1,18 +1,29 @@ import unittest +import json +import os from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.grouper.naive_reverse_grouper import NaiveReverseGrouper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.constant import Fields class NaiveReverseGrouperTest(DataJuicerTestCaseBase): - def _run_helper(self, op, samples, target): + def _run_helper(self, op, samples, target, meta_target=None, meta_path=None): dataset = Dataset.from_list(samples) new_dataset = op.run(dataset) for d, t in zip(new_dataset, target): self.assertEqual(d['text'], t['text']) + + if meta_target is not None: + batch_meta = [] + with open(meta_path) as f: + 
for line in f.readlines(): + batch_meta.append(json.loads(line)) + self.assertEqual(batch_meta, meta_target) + os.remove(meta_path) def test_one_batched_sample(self): @@ -78,6 +89,100 @@ def test_two_batch_sample(self): op = NaiveReverseGrouper() self._run_helper(op, source, target) + + def test_rm_unbatched_keys1(self): + source = [ + { + 'text':[ + "Today is Sunday and it's a happy day!", + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + ], + Fields.batch_meta: {'batch_size': 2}, + } + ] + + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + } + ] + + op = NaiveReverseGrouper() + self._run_helper(op, source, target) + + def test_rm_unbatched_keys2(self): + source = [ + { + 'text':[ + '欢迎来到阿里巴巴!' + ], + 'query':[ + 'Can I help you?' + ], + Fields.batch_meta: { + 'reponse':[ + 'No', + 'Yes' + ], + 'batch_size': 1, + } + }, + { + 'text':[ + 'Can I help you?' + ], + 'query':[ + '欢迎来到阿里巴巴!' 
+ ], + Fields.batch_meta: { + 'reponse':[ + 'No', + 'Yes' + ], + 'batch_size': 1, + } + } + ] + + target = [ + { + 'text': '欢迎来到阿里巴巴!', + 'query': 'Can I help you?', + }, + { + 'text': 'Can I help you?', + 'query': '欢迎来到阿里巴巴!', + }, + ] + + target_meta = [ + { + 'reponse':[ + 'No', + 'Yes' + ], + 'batch_size': 1, + }, + { + 'reponse':[ + 'No', + 'Yes' + ], + 'batch_size': 1, + } + ] + + export_path = '__dj__naive_reverse_grouper_test_file.jsonl' + op = NaiveReverseGrouper(export_path) + self._run_helper(op, source, target, + meta_target=target_meta, + meta_path=export_path) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/ops/mapper/test_dialog_intent_detection_mapper.py b/tests/ops/mapper/test_dialog_intent_detection_mapper.py index bc3a18752..5baece8f8 100644 --- a/tests/ops/mapper/test_dialog_intent_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_intent_detection_mapper.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access # Skip tests for this OP. # These tests have been tested locally. 
@@ -18,11 +17,13 @@ class TestDialogIntentDetectionMapper(DataJuicerTestCaseBase): # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # export OPENAI_API_KEY=your_key - def _run_op(self, op, samples, target_len): + def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels_analysis) - labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_intent_labels) + dataset = op.run(dataset) + labels_key = labels_key or MetaKeys.dialog_intent_labels + analysis_key = analysis_key or MetaKeys.dialog_intent_labels_analysis + labels_list = dataset[0][Fields.meta][labels_key] + analysis_list = dataset[0][Fields.meta][analysis_key] for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') @@ -165,6 +166,36 @@ def test_intent_candidates(self): ) self._run_op(op, samples, 4) + def test_rename_keys(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + labels_key = 'my_label' + analysis_key = 'my_analysis' + op = DialogIntentDetectionMapper(api_model='qwen2.5-72b-instruct', + labels_key=labels_key, + analysis_key=analysis_key) + self._run_op(op, samples, 4, labels_key, analysis_key) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py index b19bf6359..ac6236282 100644 --- a/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_detection_mapper.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils 
import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access # Skip tests for this OP. # These tests have been tested locally. @@ -18,11 +17,13 @@ class TestDialogSentimentDetectionMapper(DataJuicerTestCaseBase): # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # export OPENAI_API_KEY=your_key - def _run_op(self, op, samples, target_len): + def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels_analysis) - labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_labels) + dataset = op.run(dataset) + labels_key = labels_key or MetaKeys.dialog_sentiment_labels + analysis_key = analysis_key or MetaKeys.dialog_sentiment_labels_analysis + labels_list = dataset[0][Fields.meta][labels_key] + analysis_list = dataset[0][Fields.meta][analysis_key] for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') @@ -136,6 +137,63 @@ def test_query(self): max_round=1) self._run_op(op, samples, 4) + def test_sentiment_candidates(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + sentiment_candidates=['认可', '不满', '困惑']) + self._run_op(op, samples, 4) + + def test_rename_keys(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or 
and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + labels_key = 'my_label' + analysis_key = 'my_analysis' + op = DialogSentimentDetectionMapper(api_model='qwen2.5-72b-instruct', + labels_key=labels_key, + analysis_key=analysis_key) + self._run_op(op, samples, 4, labels_key, analysis_key) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py index a8953c3e4..ed7de409a 100644 --- a/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py +++ b/tests/ops/mapper/test_dialog_sentiment_intensity_mapper.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access # Skip tests for this OP. # These tests have been tested locally. @@ -18,11 +17,13 @@ class TestDialogSentimentIntensityMapper(DataJuicerTestCaseBase): # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # export OPENAI_API_KEY=your_key - def _run_op(self, op, samples, target_len): + def _run_op(self, op, samples, target_len, intensities_key=None, analysis_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity_analysis) - intensity_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_sentiment_intensity) + dataset = op.run(dataset) + intensities_key = intensities_key or MetaKeys.dialog_sentiment_intensity + analysis_key = analysis_key or MetaKeys.dialog_sentiment_intensity_analysis + intensity_list = dataset[0][Fields.meta][intensities_key] + analysis_list = dataset[0][Fields.meta][analysis_key] for analysis, intensity in zip(analysis_list, intensity_list): logger.info(f'分析:{analysis}') @@ -136,6 +137,36 @@ def 
test_query(self): max_round=1) self._run_op(op, samples, 4) + def test_rename_keys(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + intensities_key = 'my_intensity' + analysis_key = 'my_analysis' + op = DialogSentimentIntensityMapper(api_model='qwen2.5-72b-instruct', + intensities_key=intensities_key, + analysis_key=analysis_key) + self._run_op(op, samples, 4, intensities_key, analysis_key) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_dialog_topic_detection_mapper.py b/tests/ops/mapper/test_dialog_topic_detection_mapper.py index 887e96bad..f1dc1d9cb 100644 --- a/tests/ops/mapper/test_dialog_topic_detection_mapper.py +++ b/tests/ops/mapper/test_dialog_topic_detection_mapper.py @@ -8,7 +8,6 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access # Skip tests for this OP. # These tests have been tested locally. 
@@ -18,11 +17,14 @@ class TestDialogTopicDetectionMapper(DataJuicerTestCaseBase): # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 # export OPENAI_API_KEY=your_key - def _run_op(self, op, samples, target_len): + def _run_op(self, op, samples, target_len, labels_key=None, analysis_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) - analysis_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels_analysis) - labels_list = nested_access(dataset[0][Fields.meta], MetaKeys.dialog_topic_labels) + dataset = op.run(dataset) + + labels_key = labels_key or MetaKeys.dialog_topic_labels + analysis_key = analysis_key or MetaKeys.dialog_topic_labels_analysis + labels_list = dataset[0][Fields.meta][labels_key] + analysis_list = dataset[0][Fields.meta][analysis_key] for analysis, labels in zip(analysis_list, labels_list): logger.info(f'分析:{analysis}') @@ -136,6 +138,63 @@ def test_query(self): max_round=1) self._run_op(op, samples, 4) + def test_topic_candidates(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + topic_candidates=['评价', '沟通', '闲聊', '其他']) + self._run_op(op, samples, 4) + + def test_rename_keys(self): + + samples = [{ + 'history': [ + ( + '李莲花有口皆碑', + '「微笑」过奖了,我也就是个普通大夫,没什么值得夸耀的。' + ), + ( + '是的,你确实是一个普通大夫,没什么值得夸耀的。', + '「委屈」你这话说的,我也是尽心尽力治病救人了。' + ), + ( + '你自己说的呀,我现在说了,你又不高兴了。', + 'or of of of of or or and or of of of of of of of,,, ' + ), + ( + '你在说什么我听不懂。', + '「委屈」我也没说什么呀,就是觉得你有点冤枉我了' + ) + ] + }] + + labels_key = 'my_label' + analysis_key = 'my_analysis' + op = DialogTopicDetectionMapper(api_model='qwen2.5-72b-instruct', + labels_key=labels_key, 
+ analysis_key=analysis_key) + self._run_op(op, samples, 4, labels_key, analysis_key) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py index a2c156d48..9707b2beb 100644 --- a/tests/ops/mapper/test_extract_entity_attribute_mapper.py +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.extract_entity_attribute_mapper import ExtractEntityAttributeMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. @@ -47,12 +47,12 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=1) + dataset = op.run(dataset) for sample in dataset: - ents = sample[Fields.main_entities] - attrs = sample[Fields.attributes] - descs = sample[Fields.attribute_descriptions] - sups = sample[Fields.attribute_support_texts] + ents = sample[Fields.meta][MetaKeys.main_entities] + attrs = sample[Fields.meta][MetaKeys.attributes] + descs = sample[Fields.meta][MetaKeys.attribute_descriptions] + sups = sample[Fields.meta][MetaKeys.attribute_support_texts] for ent, attr, desc, sup in zip(ents, attrs, descs, sups): logger.info(f'{ent} {attr}: {desc}') self.assertNotEqual(desc, '') diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py index 0aed4fcee..a4c413a33 100644 --- a/tests/ops/mapper/test_extract_entity_relation_mapper.py +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.extract_entity_relation_mapper import ExtractEntityRelationMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, 
DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. @@ -54,10 +54,10 @@ def _run_op(self, op): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) sample = dataset[0] - logger.info(f"entitis: {sample[Fields.entity]}") - logger.info(f"relations: {sample[Fields.relation]}") + logger.info(f"entitis: {sample[Fields.meta][MetaKeys.entity]}") + logger.info(f"relations: {sample[Fields.meta][MetaKeys.relation]}") def test_default(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py index e936cb06c..4c7f47a2b 100644 --- a/tests/ops/mapper/test_extract_event_mapper.py +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.extract_event_mapper import ExtractEventMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. 
@@ -63,10 +63,10 @@ def _run_op(self, api_model, response_path=None): for sample in dataset: logger.info(f"chunk_id: {sample['chunk_id']}") self.assertEqual(sample['chunk_id'], 0) - logger.info(f"event: {sample[Fields.event_description]}") - self.assertNotEqual(sample[Fields.event_description], '') - logger.info(f"characters: {sample[Fields.relevant_characters]}") - self.assertNotEqual(sample[Fields.relevant_characters], []) + logger.info(f"event: {sample[Fields.meta][MetaKeys.event_description]}") + self.assertNotEqual(sample[Fields.meta][MetaKeys.event_description], '') + logger.info(f"characters: {sample[Fields.meta][MetaKeys.relevant_characters]}") + self.assertNotEqual(sample[Fields.meta][MetaKeys.relevant_characters], []) def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py index 2501a46ca..8528be5d4 100644 --- a/tests/ops/mapper/test_extract_keyword_mapper.py +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.extract_keyword_mapper import ExtractKeywordMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. 
@@ -57,9 +57,9 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) sample = dataset[0] - logger.info(f"keywords: {sample[Fields.keyword]}") + logger.info(f"keywords: {sample[Fields.meta][MetaKeys.keyword]}") def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py index 457a7d53b..a869bda92 100644 --- a/tests/ops/mapper/test_extract_nickname_mapper.py +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -7,7 +7,7 @@ from data_juicer.ops.mapper.extract_nickname_mapper import ExtractNicknameMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. 
@@ -37,12 +37,12 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) - result = dataset[0][Fields.nickname] + dataset = op.run(dataset) + result = dataset[0][Fields.meta][MetaKeys.nickname] result = [( - d[Fields.source_entity], - d[Fields.target_entity], - d[Fields.relation_description]) + d[MetaKeys.source_entity], + d[MetaKeys.target_entity], + d[MetaKeys.relation_description]) for d in result] logger.info(f'result: {result}') self.assertIn(("李莲花","方多病","方小宝"), result) diff --git a/tests/ops/mapper/test_extract_support_text_mapper.py b/tests/ops/mapper/test_extract_support_text_mapper.py index 080dfd672..d4d920fe8 100644 --- a/tests/ops/mapper/test_extract_support_text_mapper.py +++ b/tests/ops/mapper/test_extract_support_text_mapper.py @@ -7,8 +7,7 @@ from data_juicer.ops.mapper.extract_support_text_mapper import ExtractSupportTextMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields -from data_juicer.utils.common_utils import nested_access +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. 
@@ -18,11 +17,7 @@ class ExtractSupportTextMapperTest(DataJuicerTestCaseBase): def _run_op(self, api_model): - summary_key = 'data.event' - support_text_key = 'data.support_text' - op = ExtractSupportTextMapper(api_model=api_model, - summary_key=summary_key, - support_text_key=support_text_key) + op = ExtractSupportTextMapper(api_model=api_model) raw_text = """△芩婆走到中间,看着众人。 芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 @@ -59,15 +54,15 @@ def _run_op(self, api_model): event = "李相显托付单孤刀。" samples = [{ 'text': raw_text, - 'data':{ - 'event': event + Fields.meta:{ + MetaKeys.event_description: event } }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) sample = dataset[0] - logger.info(f"support_text: \n{nested_access(sample, support_text_key)}") + logger.info(f"support_text: \n{sample[Fields.meta][MetaKeys.support_text]}") def test(self): # before runing this test, set below environment variables: diff --git a/tests/ops/mapper/test_image_tagging_mapper.py b/tests/ops/mapper/test_image_tagging_mapper.py index d2bbddec2..5c1d3b9c4 100644 --- a/tests/ops/mapper/test_image_tagging_mapper.py +++ b/tests/ops/mapper/test_image_tagging_mapper.py @@ -5,7 +5,7 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.image_tagging_mapper import \ ImageTaggingMapper -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS # Skip tests for this OP in the GitHub actions due to OOM on the current runner @@ -42,7 +42,7 @@ def test(self): tgt_list = [{ 'images': [self.img1_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', 'chair', 'pillar', 'comfort', 'side table', 'floor', 'hardwood floor', 'headboard', 'linen', 'mattress', @@ -51,14 +51,14 @@ def 
test(self): }, { 'images': [self.img2_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'advertisement', 'back', 'bus', 'car', 'city bus', 'city street', 'curb', 'decker bus', 'drive', 'license plate', 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'alley', 'black', 'building', 'catch', 'person', 'pavement', 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] @@ -74,11 +74,11 @@ def test_no_images(self): tgt_list = [{ 'images': [], Fields.meta: { - Fields.image_tags: [[]]}, + MetaKeys.image_tags: [[]]}, }, { 'images': [self.img2_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'advertisement', 'back', 'bus', 'car', 'city bus', 'city street', 'curb', 'decker bus', 'drive', 'license plate', 'road', 'street scene', 'tour bus', 'travel', 'white']]}, @@ -138,7 +138,7 @@ def test_multi_process(self): tgt_list = [{ 'images': [self.img1_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', 'chair', 'pillar', 'comfort', 'side table', 'floor', 'hardwood floor', 'headboard', 'linen', 'mattress', @@ -147,14 +147,14 @@ def test_multi_process(self): }, { 'images': [self.img2_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'advertisement', 'back', 'bus', 'car', 'city bus', 'city street', 'curb', 'decker bus', 'drive', 'license plate', 'road', 'street scene', 'tour bus', 'travel', 'white']]}, }, { 'images': [self.img3_path], Fields.meta: { - Fields.image_tags: [[ + MetaKeys.image_tags: [[ 'alley', 'black', 'building', 'catch', 'person', 'pavement', 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']]}, }] diff --git a/tests/ops/mapper/test_query_intent_detection_mapper.py b/tests/ops/mapper/test_query_intent_detection_mapper.py index 92d0346a4..7d00eb2df 100644 --- 
a/tests/ops/mapper/test_query_intent_detection_mapper.py +++ b/tests/ops/mapper/test_query_intent_detection_mapper.py @@ -8,19 +8,23 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access class TestQueryIntentDetectionMapper(DataJuicerTestCaseBase): hf_model = 'bespin-global/klue-roberta-small-3i4k-intent-classification' zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' - def _run_op(self, op, samples, label_key, targets): + def _run_op(self, op, samples, targets, label_key=None, score_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) + + label_key = label_key or MetaKeys.query_intent_label + score_key = score_key or MetaKeys.query_intent_score for sample, target in zip(dataset, targets): - label = nested_access(sample[Fields.meta], label_key) + label = sample[Fields.meta][label_key] + score = sample[Fields.meta][score_key] + logger.info(f'{label}: {score}') self.assertEqual(label, target) def test_default(self): @@ -39,7 +43,7 @@ def test_default(self): hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, ) - self._run_op(op, samples, MetaKeys.query_intent_label, targets) + self._run_op(op, samples, targets) def test_no_zh_to_en(self): @@ -55,7 +59,29 @@ def test_no_zh_to_en(self): hf_model = self.hf_model, zh_to_en_hf_model = None, ) - self._run_op(op, samples, MetaKeys.query_intent_label, targets) + self._run_op(op, samples, targets) + + def test_rename_keys(self): + + samples = [{ + 'query': '这样好吗?' + },{ + 'query': '站住!' 
+ },{ + 'query': '今天阳光灿烂。' + } + ] + targets = ['question', 'command', 'statement'] + + label_key = 'my_label' + score_key = 'my_score' + op = QueryIntentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + label_key = label_key, + score_key = score_key, + ) + self._run_op(op, samples, targets, label_key, score_key) if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_query_sentiment_detection_mapper.py b/tests/ops/mapper/test_query_sentiment_detection_mapper.py index 62ed0f380..09834cfdc 100644 --- a/tests/ops/mapper/test_query_sentiment_detection_mapper.py +++ b/tests/ops/mapper/test_query_sentiment_detection_mapper.py @@ -8,19 +8,23 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access class TestQuerySentimentDetectionMapper(DataJuicerTestCaseBase): hf_model = 'mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis' zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' - def _run_op(self, op, samples, label_key, targets): + def _run_op(self, op, samples, targets, label_key=None, score_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) + + label_key = label_key or MetaKeys.query_sentiment_label + score_key = score_key or MetaKeys.query_sentiment_score for sample, target in zip(dataset, targets): - label = nested_access(sample[Fields.meta], label_key) + label = sample[Fields.meta][label_key] + score = sample[Fields.meta][score_key] + logger.info(f'{label}: {score}') self.assertEqual(label, target) def test_default(self): @@ -39,7 +43,7 @@ def test_default(self): hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + self._run_op(op, samples, targets) def test_no_zh_to_en(self): @@ -55,7 
+59,29 @@ def test_no_zh_to_en(self): hf_model = self.hf_model, zh_to_en_hf_model = None, ) - self._run_op(op, samples, MetaKeys.query_sentiment_label, targets) + self._run_op(op, samples, targets) + + def test_rename_keys(self): + + samples = [{ + 'query': '太棒了!' + },{ + 'query': '嗯嗯' + },{ + 'query': '没有希望。' + }, + ] + targets = ['positive', 'neutral', 'negative'] + + label_key = 'my_label' + score_key = 'my_score' + op = QuerySentimentDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + label_key = label_key, + score_key = score_key, + ) + self._run_op(op, samples, targets, label_key, score_key) if __name__ == '__main__': diff --git a/tests/ops/mapper/test_query_topic_detection_mapper.py b/tests/ops/mapper/test_query_topic_detection_mapper.py index 6304290c7..ba0b8932c 100644 --- a/tests/ops/mapper/test_query_topic_detection_mapper.py +++ b/tests/ops/mapper/test_query_topic_detection_mapper.py @@ -8,19 +8,23 @@ from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) from data_juicer.utils.constant import Fields, MetaKeys -from data_juicer.utils.common_utils import nested_access class TestQueryTopicDetectionMapper(DataJuicerTestCaseBase): hf_model = 'dstefa/roberta-base_topic_classification_nyt_news' zh_to_en_hf_model = 'Helsinki-NLP/opus-mt-zh-en' - def _run_op(self, op, samples, label_key, targets): + def _run_op(self, op, samples, targets, label_key=None, score_key=None): dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) + + label_key = label_key or MetaKeys.query_topic_label + score_key = score_key or MetaKeys.query_topic_score for sample, target in zip(dataset, targets): - label = nested_access(sample[Fields.meta], label_key) + label = sample[Fields.meta][label_key] + score = sample[Fields.meta][score_key] + logger.info(f'{label}: {score}') self.assertEqual(label, target) def test_default(self): @@ -37,7 +41,7 @@ def 
test_default(self): hf_model = self.hf_model, zh_to_en_hf_model = self.zh_to_en_hf_model, ) - self._run_op(op, samples, MetaKeys.query_topic_label, targets) + self._run_op(op, samples, targets) def test_no_zh_to_en(self): @@ -53,7 +57,27 @@ def test_no_zh_to_en(self): hf_model = self.hf_model, zh_to_en_hf_model = None, ) - self._run_op(op, samples, MetaKeys.query_topic_label, targets) + self._run_op(op, samples, targets) + + def test_rename_keys(self): + + samples = [{ + 'query': '今天火箭和快船的比赛谁赢了。' + },{ + 'query': '你最近身体怎么样。' + } + ] + targets = ['Sports', 'Health and Wellness'] + + label_key = 'my_label' + score_key = 'my_score' + op = QueryTopicDetectionMapper( + hf_model = self.hf_model, + zh_to_en_hf_model = self.zh_to_en_hf_model, + label_key = label_key, + score_key = score_key, + ) + self._run_op(op, samples, targets, label_key, score_key) if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_relation_identity_mapper.py b/tests/ops/mapper/test_relation_identity_mapper.py index 231b20ba1..3a243189b 100644 --- a/tests/ops/mapper/test_relation_identity_mapper.py +++ b/tests/ops/mapper/test_relation_identity_mapper.py @@ -7,20 +7,23 @@ from data_juicer.ops.mapper.relation_identity_mapper import RelationIdentityMapper from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, DataJuicerTestCaseBase) -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys # Skip tests for this OP. # These tests have been tested locally. 
@SKIPPED_TESTS.register_module() class RelationIdentityMapperTest(DataJuicerTestCaseBase): + # before running this test, set the environment variables below: + # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 + # export OPENAI_API_KEY=your_key - def _run_op(self, api_model, response_path=None): + def _run_op(self, api_model, output_key=MetaKeys.role_relation): op = RelationIdentityMapper(api_model=api_model, source_entity="李莲花", target_entity="方多病", - response_path=response_path) + output_key=output_key) raw_text = """李莲花原名李相夷,十五岁战胜西域天魔,十七岁建立四顾门,二十岁问鼎武林盟主,成为传奇人物。 在与金鸳盟盟主笛飞声的对决中,李相夷中毒重伤,沉入大海,十年后在莲花楼醒来,过起了市井生活。他帮助肉铺掌柜解决家庭矛盾,表现出敏锐的洞察力。 @@ -42,17 +45,17 @@ def _run_op(self, api_model, response_path=None): }] dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process, batch_size=2) + dataset = op.run(dataset) for data in dataset: for k in data: logger.info(f"{k}: {data[k]}") - def test(self): - # before runing this test, set below environment variables: - # export OPENAI_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 - # export OPENAI_API_KEY=your_key + def test_default(self): self._run_op('qwen2.5-72b-instruct') + def test_rename_key(self): + self._run_op('qwen2.5-72b-instruct', output_key='output') + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_video_extract_frames_mapper.py b/tests/ops/mapper/test_video_extract_frames_mapper.py index 7ae2dd29f..d6b227165 100644 --- a/tests/ops/mapper/test_video_extract_frames_mapper.py +++ b/tests/ops/mapper/test_video_extract_frames_mapper.py @@ -9,7 +9,7 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_extract_frames_mapper import \ VideoExtractFramesMapper -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -70,18 +70,20 @@ def
test_duration(self): vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) tgt_list = copy.deepcopy(ds_list) - tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) - tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) - tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}}) + tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}}) + tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}}) op = VideoExtractFramesMapper( frame_sampling_method='uniform', frame_num=frame_num, duration=0, - frame_dir=frame_dir) + frame_dir=frame_dir, + batch_size=2, + num_proc=1) dataset = Dataset.from_list(ds_list) - dataset = dataset.map(op.process, batch_size=2, num_proc=1) + dataset = op.run(dataset) res_list = dataset.to_list() self.assertEqual(res_list, tgt_list) self.assertListEqual( @@ -114,18 +116,20 @@ def test_uniform_sampling(self): vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) tgt_list = copy.deepcopy(ds_list) - tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) - tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) - tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}}) + tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}}) + tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}}) op = VideoExtractFramesMapper( frame_sampling_method='uniform', frame_num=frame_num, duration=10, - frame_dir=frame_dir) + frame_dir=frame_dir, + 
batch_size=2, + num_proc=1) dataset = Dataset.from_list(ds_list) - dataset = dataset.map(op.process, batch_size=2, num_proc=1) + dataset = op.run(dataset) res_list = dataset.to_list() self.assertEqual(res_list, tgt_list) self.assertListEqual( @@ -158,22 +162,24 @@ def test_all_keyframes_sampling(self): vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) tgt_list = copy.deepcopy(ds_list) - tgt_list[0].update({Fields.video_frames: - json.dumps({self.vid1_path: vid1_frame_dir})}) - tgt_list[1].update({Fields.video_frames: json.dumps({ + tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: + json.dumps({self.vid1_path: vid1_frame_dir})}}) + tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({ self.vid2_path: vid2_frame_dir, self.vid3_path: vid3_frame_dir - })}) - tgt_list[2].update({Fields.video_frames: - json.dumps({self.vid3_path: vid3_frame_dir})}) + })}}) + tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: + json.dumps({self.vid3_path: vid3_frame_dir})}}) op = VideoExtractFramesMapper( frame_sampling_method='all_keyframes', frame_dir=frame_dir, - duration=5) + duration=5, + batch_size=2, + num_proc=2) dataset = Dataset.from_list(ds_list) - dataset = dataset.map(op.process, batch_size=2, num_proc=2) + dataset = op.run(dataset) res_list = dataset.to_list() self.assertEqual(res_list, tgt_list) self.assertListEqual( @@ -205,6 +211,8 @@ def test_default_frame_dir(self): frame_sampling_method='uniform', frame_num=frame_num, duration=5, + batch_size=2, + num_proc=1 ) vid1_frame_dir = op._get_default_frame_dir(self.vid1_path) @@ -212,12 +220,12 @@ def test_default_frame_dir(self): vid3_frame_dir = op._get_default_frame_dir(self.vid3_path) tgt_list = copy.deepcopy(ds_list) - tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) - tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) - tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: 
vid3_frame_dir})}) + tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}}) + tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}}) + tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}}) dataset = Dataset.from_list(ds_list) - dataset = dataset.map(op.process, batch_size=2, num_proc=1) + dataset = op.run(dataset) res_list = dataset.to_list() frame_dir_prefix = self._get_default_frame_dir_prefix() diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py index 00a376170..1929557d0 100644 --- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -4,7 +4,7 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_tagging_from_audio_mapper import \ VideoTaggingFromAudioMapper -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -28,7 +28,7 @@ def _run_video_tagging_from_audio_mapper(self, op, source_list, target_list, - tag_field_name=Fields.video_audio_tags, + tag_field_name=MetaKeys.video_audio_tags, num_proc=1): dataset = Dataset.from_list(source_list) if Fields.meta not in dataset.features: diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index 31fc04c3b..97fb74d77 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -5,7 +5,7 @@ from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_tagging_from_frames_mapper import \ VideoTaggingFromFramesMapper -from 
data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS @@ -50,7 +50,7 @@ def test(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', 'sky' @@ -60,7 +60,7 @@ def test(self): f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', 'ball', 'person' @@ -70,7 +70,7 @@ def test(self): f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', 'conversation', 'round table', 'closet', 'computer', 'girl', 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', @@ -94,13 +94,13 @@ def test_no_video(self): f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [], Fields.meta: { - Fields.video_frame_tags: [[]]} + MetaKeys.video_frame_tags: [[]]} }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', 'ball', 'person' @@ -177,7 +177,7 @@ def test_uniform(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], Fields.meta: { - Fields.video_frame_tags: [[ + 
MetaKeys.video_frame_tags: [[ 'cartoon', 'animal', 'anime', 'game', 'screenshot', 'video game', 'cartoon character', 'robe', 'ray', 'text', 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']]} @@ -186,7 +186,7 @@ def test_uniform(self): f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', 'tennis racket', 'blind', 'game controller', 'remote', 'stand', @@ -197,7 +197,7 @@ def test_uniform(self): f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', @@ -231,7 +231,7 @@ def test_multi_process(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', 'sky' @@ -241,7 +241,7 @@ def test_multi_process(self): f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', 'ball', 'person' @@ -251,7 +251,7 @@ def test_multi_process(self): f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', 
'conversation', 'round table', 'closet', 'computer', 'girl', 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', @@ -286,7 +286,7 @@ def test_multi_chunk(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', 'videos': [self.vid1_path, self.vid2_path], Fields.meta: { - Fields.video_frame_tags: + MetaKeys.video_frame_tags: [[ 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', @@ -301,7 +301,7 @@ def test_multi_chunk(self): f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid2_path, self.vid3_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', 'ball', 'person' @@ -316,7 +316,7 @@ def test_multi_chunk(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid1_path, self.vid3_path], Fields.meta: { - Fields.video_frame_tags: [[ + MetaKeys.video_frame_tags: [[ 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', 'sky'