diff --git a/configs/config_all.yaml b/configs/config_all.yaml index d9951ca02..90fc18875 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -55,16 +55,16 @@ process: - audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters - calibrate_qa_mapper: # calibrate question-answer pairs based on reference text. api_model: 'gpt-4o' # API model name. - api_url: null # API URL. Defaults to DJ_API_URL environment variable. - api_key: null # API key. Defaults to DJ_API_KEY environment variable. + api_endpoint: null # URL endpoint for the API. response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. system_prompt: null # System prompt for the calibration task. input_template: null # Template for building the model input. reference_template: null # Template for formatting the reference text. qa_pair_template: null # Template for formatting question-answer pairs. output_pattern: null # Regular expression for parsing model output. - model_params: null # Parameters for initializing the model. - sampling_params: null # Extra parameters passed to the API call. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + model_params: {} # Parameters for initializing the model. + sampling_params: {} # Extra parameters passed to the API call. - calibrate_query_mapper: # calibrate query in question-answer pairs based on reference text. - calibrate_response_mapper: # calibrate response in question-answer pairs based on reference text. - chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. @@ -75,6 +75,81 @@ process: - clean_links_mapper: # remove web links from text. - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. + - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. + query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried. + query_attributes: ["人物性格"] # Attribute list to be queried. + api_model: 'gpt-4o' # API model name. + entity_key: '__dj__entity__' # The field name to store the given main entity for attribute extraction. + attribute_key: '__dj__attribute__' # The field name to store the given attribute to be extracted. + attribute_desc_key: '__dj__attribute_description__' # The field name to store the extracted attribute description. + support_text_key: '__dj__support_text__' # The field name to store the attribute support text extracted from the raw text. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt_template: null # System prompt template for the task. Needs to be specified for the given entity and attribute. + input_template: null # Template for building the model input. + attr_pattern_template: null # Pattern for parsing the attribute from output. Needs to be specified for the given attribute. + demo_pattern: null # Pattern for parsing the demonstration from output to support the attribute. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_entity_relation_mapper: # Extract entities and relations in the text for knowledge graph. + api_model: 'gpt-4o' # API model name. + entity_types: ['person', 'organization', 'location'] # Pre-defined entity types for knowledge graph. + entity_key: '__dj__entity__' # The field name to store the entities. + relation_key: '__dj__relation__' # The field name to store the relations between entities. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + prompt_template: null # The template of input prompt. + tuple_delimiter: null # Delimiter to separate items in outputs. + record_delimiter: null # Delimiter to separate records in outputs. + completion_delimiter: null # To mark the end of the output. + max_gleaning: 1 # the extra max num to call LLM to glean entities and relations. + continue_prompt: null # the prompt for gleaning entities and relations. + if_loop_prompt: null # the prompt to determine whether to stop gleaning. + entity_pattern: null # Regular expression for parsing entity record. + relation_pattern: null # Regular expression for parsing relation record. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_event_mapper: # Extract events and relevant characters in the text + api_model: 'gpt-4o' # API model name. + event_desc_key: '__dj__event_description__' # The field name to store the event descriptions. + relevant_char_key: '__dj__relevant_characters__' # The field name to store the relevant characters to the events. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_keyword_mapper: # Generate keywords for the text. + api_model: 'gpt-4o' # API model name. + keyword_key: '__dj__keyword__' # The field name to store the keywords. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + prompt_template: null # The template of input prompt. + completion_delimiter: null # To mark the end of the output. + output_pattern: null # Regular expression for parsing keywords. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} + - extract_nickname_mapper: # Extract nickname relationship in the text. + api_model: 'gpt-4o' # API model name. 
+ nickname_key: '__dj__nickname__' # The field name to store the nickname relationship. + api_endpoint: null # URL endpoint for the API. + response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'. + system_prompt: null # System prompt for the task. + input_template: null # Template for building the model input. + output_pattern: null # Regular expression for parsing model output. + try_num: 3 # The number of retry attempts when there is an API call error or output parsing error. + drop_text: false # If drop the text in the output. + model_params: {} # Parameters for initializing the API model. + sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - fix_unicode_mapper: # fix unicode errors in text. - generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples. hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs. @@ -87,14 +162,14 @@ process: qa_pair_template: null # Template for formatting a single QA pair within each example. output_pattern: null # Regular expression pattern to extract questions and answers from model response. enable_vllm: false # Whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. + model_params: {} # Parameters for initializing the model. sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - generate_qa_from_text_mapper: # mapper to generate question and answer pairs from text. hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # Model name on huggingface to generate question and answer pairs. output_pattern: null # Regular expression pattern to extract questions and answers from model response. enable_vllm: false # Whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. - sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} + model_params: {} # Parameters for initializing the model. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] @@ -163,8 +238,8 @@ process: qa_pair_template: null # Template for formatting the question and answer pair. output_pattern: null # Regular expression pattern to extract question and answer from model response. enable_vllm: false # whether to use vllm for inference acceleration. - model_params: null # Parameters for initializing the model. - sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} + model_params: {} # Parameters for initializing the model. + sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95} - optimize_query_mapper: # optimize query in question-answer pairs. - optimize_response_mapper: # optimize response in question-answer pairs. - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. @@ -197,6 +272,12 @@ process: substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language + - text_chunk_mapper: # Split input text to chunks. 
+ max_len: 2000 # Split the text into multiple chunks with at most this length if not None. + split_pattern: '\n\n' # Split at this pattern if it is not None, and force a cut if the chunk length still exceeds max_len. + overlap_len: 200 # Overlap length of the split texts if not split at the split pattern. + tokenizer: 'gpt-4o' # The tokenizer name from Hugging Face tokenizers. The text length is calculated as the token count if a tokenizer is provided; otherwise, the text length equals the string length. + trust_remote_code: True # whether to trust remote code when loading the Hugging Face tokenizer. - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only captioned sample in the final datasets and the original sample will be removed. It's True in default. mem_required: '30GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched diff --git a/data_juicer/ops/common/__init__.py b/data_juicer/ops/common/__init__.py index 1493b9ee5..cb2501f42 100644 --- a/data_juicer/ops/common/__init__.py +++ b/data_juicer/ops/common/__init__.py @@ -1,11 +1,13 @@ from .helper_func import (get_sentences_from_document, get_words_from_document, merge_on_whitespace_tab_newline, split_on_newline_tab_whitespace, split_on_whitespace, - strip, words_augmentation, words_refinement) + split_text_by_punctuation, strip, words_augmentation, + words_refinement) from .special_characters import SPECIAL_CHARACTERS __all__ = [ 'get_sentences_from_document', 'get_words_from_document', 'merge_on_whitespace_tab_newline', 'split_on_newline_tab_whitespace', - 'split_on_whitespace', 'strip', 'words_augmentation', 'words_refinement' + 'split_on_whitespace', 'strip', 'words_augmentation', 'words_refinement', + 'split_text_by_punctuation' ] diff --git a/data_juicer/ops/common/helper_func.py b/data_juicer/ops/common/helper_func.py index 58e43d36f..644188a7a 100644 --- a/data_juicer/ops/common/helper_func.py +++ b/data_juicer/ops/common/helper_func.py @@ -198,3 +198,19 @@ def get_sentences_from_document(document, model_func=None): else: sentences = document.splitlines() return '\n'.join(sentences) + + +def split_text_by_punctuation(text): + """ + Split text by any zh and en punctuation + + :param text: text to be split.
+ :return: sub texts split by any zh and en punctuation + """ + # any zh and en punctuation + punctuation_pattern = r'[\u3000-\u303f\uff00-\uffef]|[!"#$%&\'()*+,-./:;<=>?@[\\\]^_`{|}~]' # noqa: E501 + + result = re.split(punctuation_pattern, text) + result = [s.strip() for s in result if s.strip()] + + return result diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 25368d68b..41bf092a3 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -9,6 +9,11 @@ from .clean_ip_mapper import CleanIpMapper from .clean_links_mapper import CleanLinksMapper from .expand_macro_mapper import ExpandMacroMapper +from .extract_entity_attribute_mapper import ExtractEntityAttributeMapper +from .extract_entity_relation_mapper import ExtractEntityRelationMapper +from .extract_event_mapper import ExtractEventMapper +from .extract_keyword_mapper import ExtractKeywordMapper +from .extract_nickname_mapper import ExtractNicknameMapper from .fix_unicode_mapper import FixUnicodeMapper from .generate_qa_from_examples_mapper import GenerateQAFromExamplesMapper from .generate_qa_from_text_mapper import GenerateQAFromTextMapper @@ -37,6 +42,7 @@ RemoveWordsWithIncorrectSubstringsMapper from .replace_content_mapper import ReplaceContentMapper from .sentence_split_mapper import SentenceSplitMapper +from .text_chunk_mapper import TextChunkMapper from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper from .video_captioning_from_frames_mapper import \ VideoCaptioningFromFramesMapper @@ -59,18 +65,20 @@ 'AudioFFmpegWrappedMapper', 'CalibrateQAMapper', 'CalibrateQueryMapper', 'CalibrateResponseMapper', 'ChineseConvertMapper', 'CleanCopyrightMapper', 'CleanEmailMapper', 'CleanHtmlMapper', 'CleanIpMapper', 'CleanLinksMapper', - 'ExpandMacroMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', - 'GenerateQAFromTextMapper', 'ImageBlurMapper', - 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', - 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', - 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', - 'OptimizeQueryMapper', 'OptimizeResponseMapper', + 'ExpandMacroMapper', 'ExtractEntityAttributeMapper', + 'ExtractEntityRelationMapper', 'ExtractEventMapper', + 'ExtractKeywordMapper', 'ExtractNicknameMapper', 'FixUnicodeMapper', + 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', + 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', + 'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper', + 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', + 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SentenceSplitMapper', 'VideoCaptioningFromAudioMapper', + 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', diff --git a/data_juicer/ops/mapper/calibrate_qa_mapper.py b/data_juicer/ops/mapper/calibrate_qa_mapper.py index 625443f53..69b860e33 100644 --- a/data_juicer/ops/mapper/calibrate_qa_mapper.py +++
b/data_juicer/ops/mapper/calibrate_qa_mapper.py @@ -1,6 +1,9 @@ import re from typing import Dict, Optional +from loguru import logger +from pydantic import PositiveInt + from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper from data_juicer.utils.model_utils import get_model, prepare_model @@ -30,21 +33,22 @@ class CalibrateQAMapper(Mapper): def __init__(self, api_model: str = 'gpt-4o', *, - api_url: Optional[str] = None, + api_endpoint: Optional[str] = None, response_path: Optional[str] = None, system_prompt: Optional[str] = None, input_template: Optional[str] = None, reference_template: Optional[str] = None, qa_pair_template: Optional[str] = None, output_pattern: Optional[str] = None, - model_params: Optional[Dict] = None, - sampling_params: Optional[Dict] = None, + try_num: PositiveInt = 3, + model_params: Dict = {}, + sampling_params: Dict = {}, **kwargs): """ Initialization method. :param api_model: API model name. - :param api_url: URL endpoint for the API. + :param api_endpoint: URL endpoint for the API. :param response_path: Path to extract content from the API response. Defaults to 'choices.0.message.content'. :param system_prompt: System prompt for the calibration task. @@ -54,6 +58,7 @@ def __init__(self, :param output_pattern: Regular expression for parsing model output. :param model_params: Parameters for initializing the API model. :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} :param kwargs: Extra keyword arguments. """ super().__init__(**kwargs) @@ -65,15 +70,17 @@ def __init__(self, self.qa_pair_template = qa_pair_template or \ self.DEFAULT_QA_PAIR_TEMPLATE self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN - self.sampling_params = sampling_params or {} - model_params = model_params or {} + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', model=api_model, - url=api_url, + endpoint=api_endpoint, response_path=response_path, **model_params) + self.try_num = try_num + def build_input(self, sample): reference = self.reference_template.format(sample[self.text_key]) qa_pair = self.qa_pair_template.format(sample[self.query_key], @@ -89,7 +96,7 @@ def parse_output(self, raw_output): else: return None, None - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): client = get_model(self.model_key, rank=rank) messages = [{ @@ -99,9 +106,15 @@ def process_single(self, sample=None, rank=None): 'role': 'user', 'content': self.build_input(sample) }] - output = client(messages, **self.sampling_params) - - parsed_q, parsed_a = self.parse_output(output) + parsed_q, parsed_a = None, None + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + parsed_q, parsed_a = self.parse_output(output) + if parsed_q or parsed_a: + break + except Exception as e: + logger.warning(f'Exception: {e}') if parsed_q: sample[self.query_key] = parsed_q if parsed_a: diff --git a/data_juicer/ops/mapper/extract_entity_attribute_mapper.py b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py new file mode 100644 index 000000000..1fab935f9 --- /dev/null +++ b/data_juicer/ops/mapper/extract_entity_attribute_mapper.py @@ -0,0 +1,199 @@ +import re +from itertools import chain +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.constant import Fields +from 
data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'extract_entity_attribute_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractEntityAttributeMapper(Mapper): + """ + Extract attributes for given entities from the text + """ + + _batched_op = True + + DEFAULT_SYSTEM_PROMPT_TEMPLATE = ( + '给定一段文本,从文本中总结{entity}的{attribute},并且从原文摘录最能说明该{attribute}的代表性示例。\n' + '要求:\n' + '- 摘录的示例应该简短。\n' + '- 遵循如下的回复格式:\n' + '## {attribute}:\n' + '{entity}的{attribute}描述...\n' + '### 代表性示例1:\n' + '说明{entity}该{attribute}的原文摘录1...\n' + '### 代表性示例2:\n' + '说明{entity}该{attribute}的原文摘录2...\n' + '...\n') + + DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' + DEFAULT_ATTR_PATTERN_TEMPLATE = r'\#\#\s*{attribute}:\s*(.*?)(?=\#\#\#|\Z)' + DEFAULT_DEMON_PATTERN = r'\#\#\#\s*代表性示例(\d+):\s*(.*?)(?=\#\#\#|\Z)' + + def __init__(self, + query_entities: List[str] = [], + query_attributes: List[str] = [], + api_model: str = 'gpt-4o', + *, + entity_key: str = Fields.main_entity, + attribute_key: str = Fields.attribute, + attribute_desc_key: str = Fields.attribute_description, + support_text_key: str = Fields.attribute_support_text, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt_template: Optional[str] = None, + input_template: Optional[str] = None, + attr_pattern_template: Optional[str] = None, + demo_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param query_entities: Entity list to be queried. + :param query_attributes: Attribute list to be queried. + :param api_model: API model name. + :param entity_key: The field name to store the given main entity for + attribute extraction. It's "__dj__entity__" in default. + :param attribute_key: The field name to store the given + attribute to be extracted. It's "__dj__attribute__" in default. + :param attribute_desc_key: The field name to store the extracted + attribute description. It's "__dj__attribute_description__" in + default. + :param support_text_key: The field name to store the attribute + support text extracted from the raw text. It's + "__dj__support_text__" in default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt_template: System prompt template for the + task. Needs to be specified for the given entity and attribute. + :param input_template: Template for building the model input. + :param attr_pattern_template: Pattern for parsing the attribute from + output. Needs to be specified for the given attribute. + :param demo_pattern: Pattern for parsing the demonstration from + output to support the attribute. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments.
+ """ + super().__init__(**kwargs) + + self.query_entities = query_entities + self.query_attributes = query_attributes + + self.entity_key = entity_key + self.attribute_key = attribute_key + self.attribute_desc_key = attribute_desc_key + self.support_text_key = support_text_key + + self.system_prompt_template = system_prompt_template \ + or self.DEFAULT_SYSTEM_PROMPT_TEMPLATE + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.attr_pattern_template = attr_pattern_template \ + or self.DEFAULT_ATTR_PATTERN_TEMPLATE + self.demo_pattern = demo_pattern or self.DEFAULT_DEMON_PATTERN + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output, attribute_name): + + attribute_pattern = self.attr_pattern_template.format( + attribute=attribute_name) + pattern = re.compile(attribute_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + if matches: + attribute = matches[0].strip() + else: + attribute = '' + + pattern = re.compile(self.demo_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + demos = [demo.strip() for _, demo in matches if demo.strip()] + + return attribute, demos + + def _process_single_sample(self, text='', rank=None): + client = get_model(self.model_key, rank=rank) + + entities, attributes, descs, demo_lists = [], [], [], [] + for entity in self.query_entities: + for attribute in self.query_attributes: + system_prompt = self.system_prompt_template.format( + entity=entity, attribute=attribute) + input_prompt = self.input_template.format(text=text) + messages = [{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + + desc, demos = '', [] + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + desc, demos = self.parse_output(output, attribute) + if desc and len(demos) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + entities.append(entity) + attributes.append(attribute) + descs.append(desc) + demo_lists.append(demos) + + return entities, attributes, descs, demo_lists + + def process_batched(self, samples, rank=None): + + sample_num = len(samples[self.text_key]) + + entities, attributes, descs, demo_lists = [], [], [], [] + for text in samples[self.text_key]: + res = self._process_single_sample(text, rank=rank) + cur_ents, cur_attrs, cur_descs, cur_demos = res + entities.append(cur_ents) + attributes.append(cur_attrs) + descs.append(cur_descs) + demo_lists.append(cur_demos) + + if self.drop_text: + samples.pop(self.text_key) + + for key in samples: + samples[key] = [[samples[key][i]] * len(descs[i]) + for i in range(sample_num)] + samples[self.entity_key] = entities + samples[self.attribute_key] = attributes + samples[self.attribute_desc_key] = descs + samples[self.support_text_key] = demo_lists + + for key in samples: + samples[key] = list(chain(*samples[key])) + + return samples diff --git a/data_juicer/ops/mapper/extract_entity_relation_mapper.py b/data_juicer/ops/mapper/extract_entity_relation_mapper.py new file mode 100644 index 000000000..4b026f2a4 --- /dev/null +++ b/data_juicer/ops/mapper/extract_entity_relation_mapper.py @@ -0,0 +1,333 @@ +# This OP is modified from light RAG +# https://github.com/HKUDS/LightRAG + +# flake8: noqa: E501 + +import re +from typing import Dict, List, 
Optional + +from loguru import logger +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.common_utils import is_float +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..common import split_text_by_punctuation + +OP_NAME = 'extract_entity_relation_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractEntityRelationMapper(Mapper): + """ + Extract entities and relations in the text for knowledge graph. + """ + + DEFAULT_PROMPT_TEMPLATE = """-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity +- entity_type: One of the following types: [{entity_types}] +- entity_description: Comprehensive description of the entity's attributes and activities +Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>) + +2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. +For each pair of related entities, extract the following information: +- source_entity: name of the source entity, as identified in step 1 +- target_entity: name of the target entity, as identified in step 1 +- relationship_description: explanation as to why you think the source entity and the target entity are related to each other +- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity +- relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details +Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>) + +3. Return output in the language of the given text as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. + +4. When finished, output {completion_delimiter} + +###################### +-Examples- +###################### +Example 1: + +Entity_types: [person, technology, mission, organization, location] +Text: +``` +while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. + +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” + +The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. + +It was a small transformation, barely perceptible, but one that Alex noted with an inward nod.
They had all been brought here by different paths +``` +################ +Output: +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter} +("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter} +("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter} +("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter} +("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter} +("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter} +############################# +Example 2: + +Entity_types: [人物, 技术, 任务, 组织, 地点] +Text: +``` +他们不再是单纯的执行者;他们已成为某个超越星辰与条纹的领域的信息守护者。这一使命的提升不能被规则和既定协议所束缚——它需要一种新的视角,一种新的决心。 + +随着与华盛顿的通讯在背景中嗡嗡作响,对话中的紧张情绪通过嘟嘟声和静电噪音贯穿始终。团队站立着,一股不祥的气息笼罩着他们。显然,他们在接下来几个小时内做出的决定可能会重新定义人类在宇宙中的位置,或者将他们置于无知和潜在危险之中。 + +随着与星辰的联系变得更加牢固,小组开始处理逐渐成形的警告,从被动接受者转变为积极参与者。梅瑟后来的直觉占据了上风——团队的任务已经演变,不再仅仅是观察和报告,而是互动和准备。一场蜕变已经开始,而“杜尔塞行动”则以他们大胆的新频率震动,这种基调不是由世俗设定的 +``` +############# +Output: +("entity"{tuple_delimiter}"华盛顿"{tuple_delimiter}"地点"{tuple_delimiter}"华盛顿是正在接收通讯的地方,表明其在决策过程中的重要性。"){record_delimiter} +("entity"{tuple_delimiter}"杜尔塞行动"{tuple_delimiter}"任务"{tuple_delimiter}"杜尔塞行动被描述为一项已演变为互动和准备的任务,显示出目标和活动的重大转变。"){record_delimiter} +("entity"{tuple_delimiter}"团队"{tuple_delimiter}"组织"{tuple_delimiter}"团队被描绘成一群从被动观察者转变为积极参与者的人,展示了他们角色的动态变化。"){record_delimiter} +("relationship"{tuple_delimiter}"团队"{tuple_delimiter}"华盛顿"{tuple_delimiter}"团队收到来自华盛顿的通讯,这影响了他们的决策过程。"{tuple_delimiter}"决策、外部影响"{tuple_delimiter}7){record_delimiter} 
+("relationship"{tuple_delimiter}"团队"{tuple_delimiter}"杜尔塞行动"{tuple_delimiter}"团队直接参与杜尔塞行动,执行其演变后的目标和活动。"{tuple_delimiter}"任务演变、积极参与"{tuple_delimiter}9){completion_delimiter} +############################# +Example 3: + +Entity_types: [person, role, technology, organization, event, location, concept] +Text: +``` +their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. + +"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." + +Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." + +Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. + +The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation +``` +############# +Output: +("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter} +("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter} +("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter} +("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter} +("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter} +("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter} +("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter} +("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter} 
+("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter} +############################# +-Real Data- +###################### +Entity_types: [{entity_types}] +Text: +``` +{input_text} +``` +###################### +Output: +""" + DEFAULT_CONTINUE_PROMPT = 'MANY entities were missed in the last extraction. Add them below using the same format:\n' + DEFAULT_IF_LOOP_PROMPT = 'It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.\n' + + DEFAULT_ENTITY_TYPES = ['organization', 'person', 'geo', 'event'] + DEFAULT_TUPLE_DELIMITER = '<|>' + DEFAULT_RECORD_DELIMITER = '##' + DEFAULT_COMPLETION_DELIMITER = '<|COMPLETE|>' + DEFAULT_ENTITY_PATTERN = r'\("entity"(.*?)\)' + DEFAULT_RELATION_PATTERN = r'\("relationship"(.*?)\)' + + def __init__(self, + api_model: str = 'gpt-4o', + entity_types: List[str] = None, + *, + entity_key: str = Fields.entity, + relation_key: str = Fields.relation, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + prompt_template: Optional[str] = None, + tuple_delimiter: Optional[str] = None, + record_delimiter: Optional[str] = None, + completion_delimiter: Optional[str] = None, + max_gleaning: NonNegativeInt = 1, + continue_prompt: Optional[str] = None, + if_loop_prompt: Optional[str] = None, + entity_pattern: Optional[str] = None, + relation_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param entity_types: Pre-defined entity types for knowledge graph. + :param entity_key: The field name to store the entities. It's + "__dj__entity__" in default. + :param relation_key: The field name to store the relations between + entities. It's "__dj__relation__" in default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param prompt_template: The template of input prompt. + :param tuple_delimiter: Delimiter to separate items in outputs. + :param record_delimiter: Delimiter to separate records in outputs. + :param completion_delimiter: To mark the end of the output. + :param max_gleaning: the extra max num to call LLM to glean entities + and relations. + :param continue_prompt: the prompt for gleaning entities and + relations. + :param if_loop_prompt: the prompt to determine whether to stop + gleaning. + :param entity_pattern: Regular expression for parsing entity record. + :param relation_pattern: Regular expression for parsing relation + record. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. 
+ """ + super().__init__(**kwargs) + + self.entity_types = entity_types or self.DEFAULT_ENTITY_TYPES + + self.entity_key = entity_key + self.relation_key = relation_key + + self.prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE + self.tuple_delimiter = tuple_delimiter or self.DEFAULT_TUPLE_DELIMITER + self.record_delimiter = record_delimiter or self.DEFAULT_RECORD_DELIMITER + self.completion_delimiter = completion_delimiter or \ + self.DEFAULT_COMPLETION_DELIMITER + self.max_gleaning = max_gleaning + self.continue_prompt = continue_prompt or self.DEFAULT_CONTINUE_PROMPT + self.if_loop_prompt = if_loop_prompt or self.DEFAULT_IF_LOOP_PROMPT + self.entity_pattern = entity_pattern or self.DEFAULT_ENTITY_PATTERN + self.relation_pattern = relation_pattern or \ + self.DEFAULT_RELATION_PATTERN + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + entities, relations = [], [] + + def remove_outer_quotes(text): + if not text: + return text + if (text[0] == '"' and text[-1] == '"') or (text[0] == "'" + and text[-1] == "'"): + return text[1:-1] + else: + return text + + def split_by_tuple_delimiter(record): + items = record.split(self.tuple_delimiter) + items = [remove_outer_quotes(item.strip()) for item in items] + items = [item.strip() for item in items if item.strip()] + return tuple(items) + + entity_pattern = re.compile(self.entity_pattern, + re.VERBOSE | re.DOTALL) + matches = entity_pattern.findall(raw_output) + for record in matches: + items = split_by_tuple_delimiter(record) + if len(items) != 3: + continue + entities.append(items) + entities = list(set(entities)) + entities = [{ + Fields.entity_name: e[0], + Fields.entity_type: e[1], + Fields.entity_description: e[2] + } for e in entities] + + relation_pattern = re.compile(self.relation_pattern, + re.VERBOSE | re.DOTALL) + matches = relation_pattern.findall(raw_output) + for record in matches: + items = split_by_tuple_delimiter(record) + if len(items) != 5 or not is_float(items[4]): + continue + relations.append(items) + relations = list(set(relations)) + relations = [{ + Fields.source_entity: r[0], + Fields.target_entity: r[1], + Fields.relation_description: r[2], + Fields.relation_keywords: split_text_by_punctuation(r[3]), + Fields.relation_strength: float(r[4]) + } for r in relations] + + return entities, relations + + def add_message(self, messages, role, content): + return messages + [{'role': role, 'content': content}] + + def light_rag_extraction(self, messages, rank=None): + client = get_model(self.model_key, rank=rank) + + final_result = client(messages, **self.sampling_params) + history = self.add_message(messages, 'assistant', final_result) + + for glean_index in range(self.max_gleaning): + messages = self.add_message(history, 'user', self.continue_prompt) + glean_result = client(messages, **self.sampling_params) + history = self.add_message(messages, 'assistant', glean_result) + final_result += glean_result + + if glean_index == self.max_gleaning - 1: + break + + messages = self.add_message(history, 'user', self.if_loop_prompt) + if_loop_result = client(messages, **self.sampling_params) + if_loop_result = if_loop_result.strip().strip('"').strip( + "'").lower() + if if_loop_result != 'yes': + break + + return final_result + + def process_single(self, sample, rank=None): + + input_prompt = 
self.prompt_template.format( + tuple_delimiter=self.tuple_delimiter, + record_delimiter=self.record_delimiter, + completion_delimiter=self.completion_delimiter, + entity_types=', '.join(self.entity_types), + input_text=sample[self.text_key]) + messages = [{'role': 'user', 'content': input_prompt}] + + entities, relations = [], [] + for i in range(self.try_num): + try: + result = self.light_rag_extraction(messages, rank=rank) + entities, relations = self.parse_output(result) + if len(entities) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + sample[self.entity_key] = entities + sample[self.relation_key] = relations + return sample diff --git a/data_juicer/ops/mapper/extract_event_mapper.py b/data_juicer/ops/mapper/extract_event_mapper.py new file mode 100644 index 000000000..208684b2c --- /dev/null +++ b/data_juicer/ops/mapper/extract_event_mapper.py @@ -0,0 +1,171 @@ +import re +from itertools import chain +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..common import split_text_by_punctuation + +OP_NAME = 'extract_event_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractEventMapper(Mapper): + """ + Extract events and relevant characters in the text + """ + + _batched_op = True + + DEFAULT_SYSTEM_PROMPT = ('给定一段文本,对文本的情节进行分点总结,并抽取与情节相关的人物。\n' + '要求:\n' + '- 尽量不要遗漏内容,不要添加文本中没有的情节,符合原文事实\n' + '- 联系上下文说明前因后果,但仍然需要符合事实\n' + '- 不要包含主观看法\n' + '- 注意要尽可能保留文本的专有名词\n' + '- 注意相关人物需要在对应情节中出现\n' + '- 只抽取情节中的主要人物,不要遗漏情节的主要人物\n' + '- 总结格式如下:\n' + '### 情节1:\n' + '- **情节描述**: ...\n' + '- **相关人物**:人物1,人物2,人物3,...\n' + '### 情节2:\n' + '- **情节描述**: ...\n' + '- **相关人物**:人物1,人物2,...\n' + '### 情节3:\n' + '- **情节描述**: ...\n' + '- **相关人物**:人物1,...\n' + '...\n') + DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' + DEFAULT_OUTPUT_PATTERN = r""" + \#\#\#\s*情节(\d+):\s* + -\s*\*\*情节描述\*\*\s*:\s*(.*?)\s* + -\s*\*\*相关人物\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) + """ + + def __init__(self, + api_model: str = 'gpt-4o', + *, + event_desc_key: str = Fields.event_description, + relevant_char_key: str = Fields.relevant_characters, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param event_desc_key: The field name to store the event descriptions. + It's "__dj__event_description__" in default. + :param relevant_char_key: The field name to store the relevant + characters to the events. It's "__dj__relevant_characters__" in + default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param input_template: Template for building the model input. + :param output_pattern: Regular expression for parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. 
+ :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.event_desc_key = event_desc_key + self.relevant_char_key = relevant_char_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + + event_list, character_list = [], [] + + for match in matches: + _, desc, chars = match + chars = split_text_by_punctuation(chars) + if len(chars) > 0: + event_list.append(desc) + character_list.append(chars) + + return event_list, character_list + + def _process_single_sample(self, text='', rank=None): + client = get_model(self.model_key, rank=rank) + + input_prompt = self.input_template.format(text=text) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + + event_list, character_list = [], [] + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + event_list, character_list = self.parse_output(output) + if len(event_list) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + return event_list, character_list + + def process_batched(self, samples, rank=None): + + sample_num = len(samples[self.text_key]) + + events, characters = [], [] + for text in samples[self.text_key]: + cur_events, cur_characters = self._process_single_sample(text, + rank=rank) + events.append(cur_events) + characters.append(cur_characters) + + if self.drop_text: + samples.pop(self.text_key) + + for key in samples: + samples[key] = [[samples[key][i]] * len(events[i]) + for i in range(sample_num)] + samples[self.event_desc_key] = events + samples[self.relevant_char_key] = characters + + for key in samples: + samples[key] = list(chain(*samples[key])) + + return samples diff --git a/data_juicer/ops/mapper/extract_keyword_mapper.py b/data_juicer/ops/mapper/extract_keyword_mapper.py new file mode 100644 index 000000000..cb1814768 --- /dev/null +++ b/data_juicer/ops/mapper/extract_keyword_mapper.py @@ -0,0 +1,189 @@ +# flake8: noqa: E501 + +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..common import split_text_by_punctuation + +OP_NAME = 'extract_keyword_mapper' + + +# TODO: LLM-based inference. 
+@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractKeywordMapper(Mapper): + """ + Generate keywords for the text + """ + + # This prompt is modified from light RAG + # https://github.com/HKUDS/LightRAG + DEFAULT_PROMPT_TEMPLATE = """-Goal- +Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. + +-Steps- +1. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document. +Format the content-level key words as ("content_keywords" ) + +3. Return output in the language of the given text. + +4. When finished, output {completion_delimiter} + +###################### +-Examples- +###################### +Example 1: + +Text: +``` +while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order. + +Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.” + +The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce. + +It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths +``` +################ +Output: +("content_keywords" "power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter} +############################# +Example 2: + +Text: +``` +他们不再是单纯的执行者;他们已成为某个超越星辰与条纹的领域的信息守护者。这一使命的提升不能被规则和既定协议所束缚——它需要一种新的视角,一种新的决心。 + +随着与华盛顿的通讯在背景中嗡嗡作响,对话中的紧张情绪通过嘟嘟声和静电噪音贯穿始终。团队站立着,一股不祥的气息笼罩着他们。显然,他们在接下来几个小时内做出的决定可能会重新定义人类在宇宙中的位置,或者将他们置于无知和潜在危险之中。 + +随着与星辰的联系变得更加牢固,小组开始处理逐渐成形的警告,从被动接受者转变为积极参与者。梅瑟后来的直觉占据了上风——团队的任务已经演变,不再仅仅是观察和报告,而是互动和准备。一场蜕变已经开始,而“杜尔塞行动”则以他们大胆的新频率震动,这种基调不是由世俗设定的 +``` +############# +Output: +("content_keywords" "任务演变, 决策制定, 积极参与, 宇宙意义"){completion_delimiter} +############################# +Example 3: + +Entity_types: [person, role, technology, organization, event, location, concept] +Text: +``` +their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data. + +"It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives talking to strangers' a whole new meaning." + +Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back." + +Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. 
The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history. + +The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation +``` +############# +Output: +("content_keywords" "first contact, control, communication, cosmic significance"){completion_delimiter} +-Real Data- +###################### +Text: +``` +{input_text} +``` +###################### +Output: +""" + + DEFAULT_COMPLETION_DELIMITER = '<|COMPLETE|>' + DEFAULT_OUTPUT_PATTERN = r'\("content_keywords"(.*?)\)' + + def __init__(self, + api_model: str = 'gpt-4o', + *, + keyword_key: str = Fields.keyword, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + prompt_template: Optional[str] = None, + completion_delimiter: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param keyword_key: The field name to store the keywords. It's + "__dj__keyword__" in default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param prompt_template: The template of input prompt. + :param completion_delimiter: To mark the end of the output. + :param output_pattern: Regular expression for parsing keywords. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. 
+ """ + super().__init__(**kwargs) + + self.keyword_key = keyword_key + + self.prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE + self.completion_delimiter = completion_delimiter or \ + self.DEFAULT_COMPLETION_DELIMITER + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + keywords = [] + + output_pattern = re.compile(self.output_pattern, + re.VERBOSE | re.DOTALL) + matches = output_pattern.findall(raw_output) + for record in matches: + items = split_text_by_punctuation(record) + keywords.append(items) + + return keywords + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + input_prompt = self.prompt_template.format( + completion_delimiter=self.completion_delimiter, + input_text=sample[self.text_key]) + messages = [{'role': 'user', 'content': input_prompt}] + + keywords = [] + for i in range(self.try_num): + try: + result = client(messages, **self.sampling_params) + keywords = self.parse_output(result) + if len(keywords) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + sample[self.keyword_key] = keywords + if self.drop_text: + sample.pop(self.text_key) + + return sample diff --git a/data_juicer/ops/mapper/extract_nickname_mapper.py b/data_juicer/ops/mapper/extract_nickname_mapper.py new file mode 100644 index 000000000..b11cbab57 --- /dev/null +++ b/data_juicer/ops/mapper/extract_nickname_mapper.py @@ -0,0 +1,159 @@ +import re +from typing import Dict, Optional + +from loguru import logger +from pydantic import PositiveInt + +from data_juicer.ops.base_op import OPERATORS, UNFORKABLE, Mapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.model_utils import get_model, prepare_model + +OP_NAME = 'extract_nickname_mapper' + + +# TODO: LLM-based inference. +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class ExtractNicknameMapper(Mapper): + """ + Extract nickname relationship in the text. + """ + + DEFAULT_SYSTEM_PROMPT = ('给定你一段文本,你的任务是将人物之间的称呼方式(昵称)提取出来。\n' + '要求:\n' + '- 需要给出说话人对被称呼人的称呼,不要搞反了。\n' + '- 相同的说话人和被称呼人最多给出一个最常用的称呼。\n' + '- 请不要输出互相没有昵称的称呼方式。\n' + '- 输出格式如下:\n' + '```\n' + '### 称呼方式1\n' + '- **说话人**:...\n' + '- **被称呼人**:...\n' + '- **...对...的昵称**:...\n' + '### 称呼方式2\n' + '- **说话人**:...\n' + '- **被称呼人**:...\n' + '- **...对...的昵称**:...\n' + '### 称呼方式3\n' + '- **说话人**:...\n' + '- **被称呼人**:...\n' + '- **...对...的昵称**:...\n' + '...\n' + '```\n') + DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n' + DEFAULT_OUTPUT_PATTERN = r""" + \#\#\#\s*称呼方式(\d+)\s* + -\s*\*\*说话人\*\*\s*:\s*(.*?)\s* + -\s*\*\*被称呼人\*\*\s*:\s*(.*?)\s* + -\s*\*\*(.*?)对(.*?)的昵称\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) # for double check + """ + + def __init__(self, + api_model: str = 'gpt-4o', + *, + nickname_key: str = Fields.nickname, + api_endpoint: Optional[str] = None, + response_path: Optional[str] = None, + system_prompt: Optional[str] = None, + input_template: Optional[str] = None, + output_pattern: Optional[str] = None, + try_num: PositiveInt = 3, + drop_text: bool = False, + model_params: Dict = {}, + sampling_params: Dict = {}, + **kwargs): + """ + Initialization method. + :param api_model: API model name. + :param nickname_key: The field name to store the nickname + relationship. 
It's "__dj__nickname__" in default. + :param api_endpoint: URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'choices.0.message.content'. + :param system_prompt: System prompt for the task. + :param input_template: Template for building the model input. + :param output_pattern: Regular expression for parsing model output. + :param try_num: The number of retry attempts when there is an API + call error or output parsing error. + :param drop_text: If drop the text in the output. + :param model_params: Parameters for initializing the API model. + :param sampling_params: Extra parameters passed to the API call. + e.g {'temperature': 0.9, 'top_p': 0.95} + :param kwargs: Extra keyword arguments. + """ + super().__init__(**kwargs) + + self.nickname_key = nickname_key + + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE + self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN + + self.sampling_params = sampling_params + self.model_key = prepare_model(model_type='api', + model=api_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params) + + self.try_num = try_num + self.drop_text = drop_text + + def parse_output(self, raw_output): + pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL) + matches = pattern.findall(raw_output) + + nickname_relations = [] + + for match in matches: + _, role1, role2, role1_tmp, role2_tmp, nickname = match + # for double check + if role1.strip() != role1_tmp.strip() or role2.strip( + ) != role2_tmp.strip(): + continue + role1 = role1.strip() + role2 = role2.strip() + nickname = nickname.strip() + # is name but not nickname + if role2 == nickname: + continue + if role1 and role2 and nickname: + nickname_relations.append((role1, role2, nickname)) + nickname_relations = list(set(nickname_relations)) + + nickname_relations = [{ + Fields.source_entity: nr[0], + Fields.target_entity: nr[1], + Fields.relation_description: nr[2], + Fields.relation_keywords: ['nickname'], + Fields.relation_strength: None + } for nr in nickname_relations] + + return nickname_relations + + def process_single(self, sample, rank=None): + client = get_model(self.model_key, rank=rank) + + input_prompt = self.input_template.format(text=sample[self.text_key]) + messages = [{ + 'role': 'system', + 'content': self.system_prompt + }, { + 'role': 'user', + 'content': input_prompt + }] + nickname_relations = [] + for i in range(self.try_num): + try: + output = client(messages, **self.sampling_params) + nickname_relations = self.parse_output(output) + if len(nickname_relations) > 0: + break + except Exception as e: + logger.warning(f'Exception: {e}') + + sample[self.nickname_key] = nickname_relations + if self.drop_text: + sample.pop(self.text_key) + + return sample diff --git a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py index 4d7ff01bd..6f5ad7dab 100644 --- a/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py +++ b/data_juicer/ops/mapper/generate_qa_from_examples_mapper.py @@ -210,7 +210,7 @@ def parse_output(self, raw_output): output_qa_pairs.append((question.strip(), answer.strip())) return output_qa_pairs - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) random_qa_samples = random.sample(self.seed_qa_samples, diff --git 
a/data_juicer/ops/mapper/optimize_qa_mapper.py b/data_juicer/ops/mapper/optimize_qa_mapper.py index cd5a0aba7..3563a112b 100644 --- a/data_juicer/ops/mapper/optimize_qa_mapper.py +++ b/data_juicer/ops/mapper/optimize_qa_mapper.py @@ -113,7 +113,7 @@ def parse_output(self, raw_output): else: return None, None - def process_single(self, sample=None, rank=None): + def process_single(self, sample, rank=None): model, _ = get_model(self.model_key, rank, self.use_cuda()) input_prompt = self.build_input(sample) diff --git a/data_juicer/ops/mapper/text_chunk_mapper.py b/data_juicer/ops/mapper/text_chunk_mapper.py new file mode 100644 index 000000000..d3b9990ef --- /dev/null +++ b/data_juicer/ops/mapper/text_chunk_mapper.py @@ -0,0 +1,136 @@ +import re +from itertools import chain +from typing import Union + +from pydantic import NonNegativeInt, PositiveInt + +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'text_chunk_mapper' + + +@OPERATORS.register_module(OP_NAME) +class TextChunkMapper(Mapper): + """Split input text to chunks.""" + + _batched_op = True + + def __init__(self, + max_len: Union[PositiveInt, None] = None, + split_pattern: Union[str, None] = r'\n\n', + overlap_len: NonNegativeInt = 0, + tokenizer: Union[str, None] = None, + trust_remote_code: bool = False, + *args, + **kwargs): + """ + Initialization method. + + :param max_len: Split text into multi texts with this max len if not + None. + :param split_pattern: Make sure split in this pattern if it is not None + and force cut if the length exceeds max_len. + :param overlap_len: Overlap length of the split texts if not split in + the split pattern. + :param tokenizer: The tokenizer name of Hugging Face tokenizers. + The text length will be calculate as the token num if it is offerd. + Otherwise, the text length equals to string length. Support + tiktoken tokenizer (such as gpt-4o), dashscope tokenizer (such as + qwen2.5-72b-instruct) and huggingface tokenizer. 
+ :trust_remote_code: for loading huggingface model + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + if max_len is None and split_pattern is None: + raise ValueError('max_len and split_pattern cannot be both None') + + if max_len is not None and overlap_len >= max_len: + raise ValueError('overlap_len must be less than max_len') + + self.max_len = max_len + self.overlap_len = overlap_len + self.split_pattern = split_pattern + self.tokenizer_name = tokenizer + if tokenizer is not None: + self.model_key = prepare_model( + model_type='api', + model=tokenizer, + return_processor=True, + processor_config={'trust_remote_code': trust_remote_code}) + + def recursively_chunk(self, text): + if self.tokenizer_name is not None: + _, tokenizer = get_model(self.model_key) + tokens = tokenizer.encode(text) + total_len = len(tokens) + sub_text = tokenizer.decode(tokens[:self.max_len]) + else: + total_len = len(text) + sub_text = text[:self.max_len] + + if total_len <= self.max_len: + return [text] + + matches = list(re.finditer(self.split_pattern, sub_text)) + if not matches: + cur_text = sub_text + if self.tokenizer_name is not None: + left_text = tokenizer.decode(tokens[self.max_len - + self.overlap_len:]) + else: + left_text = text[self.max_len - self.overlap_len:] + else: + last_match = matches[-1] + cur_text = sub_text[:last_match.start()] + left_text = text[last_match.end():] + + return [cur_text] + self.recursively_chunk(left_text) + + def get_text_chunks(self, text, rank=None): + + if self.split_pattern is not None and self.max_len is None: + chunks = re.split(f'({self.split_pattern})', text) + chunks = [t for t in chunks if t.strip()] + elif self.split_pattern is None and self.max_len is not None: + tokens = text + total_len = len(text) + if self.tokenizer_name is not None: + _, tokenizer = get_model(self.model_key, rank=rank) + tokens = tokenizer.encode(text) + total_len = len(tokens) + if total_len <= self.max_len: + return [text] + chunks = [] + for start in range(0, total_len, self.max_len - self.overlap_len): + cur = tokens[start:start + self.max_len] + if self.tokenizer_name is not None: + cur = tokenizer.decode(cur) + chunks.append(cur) + else: + chunks = self.recursively_chunk(text) + + return chunks + + def process_batched(self, samples, rank=None): + + sample_num = len(samples[self.text_key]) + + samples[self.text_key] = [ + self.get_text_chunks(text, rank=rank) + for text in samples[self.text_key] + ] + + for key in samples: + if key != self.text_key: + samples[key] = [[samples[key][i]] * + len(samples[self.text_key][i]) + for i in range(sample_num)] + + for key in samples: + samples[key] = list(chain(*samples[key])) + + return samples diff --git a/data_juicer/utils/common_utils.py b/data_juicer/utils/common_utils.py index 5bd336b9b..959831c5d 100644 --- a/data_juicer/utils/common_utils.py +++ b/data_juicer/utils/common_utils.py @@ -19,3 +19,11 @@ def stats_to_number(s, reverse=True): return -sys.maxsize else: return sys.maxsize + + +def is_float(s): + try: + float(s) + return True + except Exception: + return False diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index d3634bc2f..ab88035b9 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -28,6 +28,43 @@ class Fields(object): # the name of directory to store the produced multimodal data multimodal_data_output_dir = DEFAULT_PREFIX + 'produced_data__' + # field names for info extraction + event_description = DEFAULT_PREFIX + 
'event_description__' + # # a list of characters relevant to the event + relevant_characters = DEFAULT_PREFIX + 'relevant_characters__' + # # the given main entity for attribute extraction + main_entity = DEFAULT_PREFIX + 'main_entity__' + # # the given attribute to be extracted + attribute = DEFAULT_PREFIX + 'attribute__' + # # the extracted attribute description + attribute_description = DEFAULT_PREFIX + 'attribute_description__' + # # extract from raw data for support the attribute + attribute_support_text = DEFAULT_PREFIX + 'attribute_support_text__' + # # the nickname relationship + nickname = DEFAULT_PREFIX + 'nickname__' + # # the entity for knowledge graph + entity = DEFAULT_PREFIX + 'entity__' + # # # the name of entity + entity_name = DEFAULT_PREFIX + 'entity_name__' + # # # the type of entity + entity_type = DEFAULT_PREFIX + 'entity_type__' + # # # the description of entity + entity_description = DEFAULT_PREFIX + 'entity_entity_description__' + # # the relationship for knowledge graph + relation = DEFAULT_PREFIX + 'relation__' + # # # the source entity of the relation + source_entity = DEFAULT_PREFIX + 'relation_source_entity__' + # # # the target entity of the relation + target_entity = DEFAULT_PREFIX + 'relation_target_entity__' + # # # the description of the relation + relation_description = DEFAULT_PREFIX + 'relation_description__' + # # # the keywords of the relation + relation_keywords = DEFAULT_PREFIX + 'relation_keywords__' + # # # the strength of the relation + relation_strength = DEFAULT_PREFIX + 'relation_strength__' + # # the keyword in a text + keyword = DEFAULT_PREFIX + 'keyword__' + class StatsKeysMeta(type): """ diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 7f9687079..eb521e619 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -109,25 +109,31 @@ def check_model(model_name, force=False): class APIModel: - def __init__(self, model, url=None, response_path=None, **kwargs): + def __init__(self, model, endpoint=None, response_path=None, **kwargs): """ Initializes an instance of the APIModel class. - :param model: The model name to use for the API. - :param url: URL endpoint for the API. If relative, it will be joined - with base_url or the OPENAI_BASE_URL environment variable. Defaults - to '/chat/completions' for OpenAI compatibility. - :param response_path: Dot-separated path to extract the desired - response content. Defaults to 'choices.0.message.content' for + :param model: The name of the model to be used for making API + calls. This should correspond to a valid model identifier + recognized by the API server. + :param endpoint: The URL endpoint for the API. If provided as a + relative path, it will be appended to the base URL (defined by the + `OPENAI_BASE_URL` environment variable or through an additional + `base_url` parameter). Defaults to '/chat/completions' for OpenAI compatibility. - :param kwargs: Additional arguments to configure the OpenAI client. + :param response_path: A dot-separated string specifying the path to + extract the desired content from the API response. The default + value is 'choices.0.message.content', which corresponds to the + typical structure of an OpenAI API response. + :param kwargs: Additional keyword arguments for configuring the + internal OpenAI client. 
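+
+        A rough usage sketch (assumes `OPENAI_API_KEY` and, if needed,
+        `OPENAI_BASE_URL` point to an OpenAI-compatible server; the prompt
+        content is illustrative only)::
+
+            model = APIModel(model='gpt-4o')
+            content = model([{'role': 'user', 'content': 'Hello'}],
+                            temperature=0.9)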
""" self.model = model - self.url = url or '/chat/completions' + self.endpoint = endpoint or '/chat/completions' self.response_path = response_path or 'choices.0.message.content' client_args = self._filter_arguments(openai.OpenAI, kwargs) - self.client = openai.OpenAI(**client_args) + self._client = openai.OpenAI(**client_args) def __call__(self, messages, **kwargs): """ @@ -150,11 +156,11 @@ def __call__(self, messages, **kwargs): stream_cls = openai.Stream[openai.types.chat.ChatCompletionChunk] try: - response = self.client.post(self.url, - body=body, - cast_to=httpx.Response, - stream=stream, - stream_cls=stream_cls) + response = self._client.post(self.endpoint, + body=body, + cast_to=httpx.Response, + stream=stream, + stream_cls=stream_cls) result = response.json() return self._nested_access(result, self.response_path) except Exception as e: @@ -205,37 +211,41 @@ def _filter_arguments(func, args_dict): def prepare_api_model(model, *, - url=None, + endpoint=None, response_path=None, return_processor=False, processor_config=None, **model_params): - """Creates a callable API model for interacting with OpenAI-compatible API. - The callable supports custom response parsing and works with proxy servers - that may be incompatible. - - :param model: The name of the model to interact with. - :param url: URL endpoint for the API. - :param response_path: The dot-separated path to extract desired content - from the API response. Defaults to 'choices.0.message.content'. + """ + Creates an instance of the APIModel for interacting with OpenAI-like APIs. + + :param model: The name of the model to be used for making API calls. + :param endpoint: The URL endpoint for the API. If provided as a relative + path, it will be appended to the base URL (defined by the + `OPENAI_BASE_URL` environment variable or through an additional + `base_url` parameter). By default, it is set to + '/chat/completions' for OpenAI compatibility. + :param response_path: A dot-separated string specifying the path to + extract desired content from the API response. The default value is + 'choices.0.message.content', which corresponds to the typical + structure of an OpenAI API response. :param return_processor: A boolean flag indicating whether to return a processor along with the model. The processor can be used for tasks like tokenization or encoding. Defaults to False. - :param processor_config: A dictionary containing configuration settings - for a specific processor from Hugging Face. It should include all - necessary parameters for initializing the processor. This parameter is - used only if `return_processor` is True. - :param model_params: Additional parameters to configure the API model. - :return: A tuple containing the callable API model object and optionally a - processor if `return_processor` is True. + :param processor_config: A dictionary containing configuration parameters + for initializing a Hugging Face processor. It is only relevant if + `return_processor` is set to True. + :param model_params: Additional parameters for configuring the API model. + :return: A callable APIModel instance, and optionally a processor + if `return_processor` is True. 
""" - model = APIModel(model=model, - url=url, - response_path=response_path, - **model_params) + client = APIModel(model=model, + endpoint=endpoint, + response_path=response_path, + **model_params) if not return_processor: - return model + return client def get_processor(): try: @@ -250,6 +260,13 @@ def get_processor(): except Exception: pass + try: + processor = transformers.AutoProcessor.from_pretrained( + pretrained_model_name_or_path=model, **processor_config) + return processor + except Exception: + pass + raise ValueError( 'Failed to initialize the processor. Please check the following:\n' # noqa: E501 "- For OpenAI models: Install 'tiktoken' via `pip install tiktoken`.\n" # noqa: E501 @@ -257,12 +274,13 @@ def get_processor(): "- For custom models: Use the 'processor_config' parameter to configure a Hugging Face processor." # noqa: E501 ) - if processor_config is not None: + if processor_config is not None \ + and 'pretrained_model_name_or_path' in processor_config: processor = transformers.AutoProcessor.from_pretrained( **processor_config) else: processor = get_processor() - return (model, processor) + return (client, processor) def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type, diff --git a/docs/Operators.md b/docs/Operators.md index 2a25c4847..7717ba434 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 52 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 58 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -67,6 +67,11 @@ All the specific operators are listed below, each featured with several capabili | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes IP addresses | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes links, such as those starting with http or ftp | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Expands macros usually defined at the top of TeX documents | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | +| extract_entity_attribute_mapper | 
![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract attributes for given entities from the text. | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | +| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract entities and relations in the text for knowledge graph. | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | +| extract_event_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract events and relevant characters in the text. | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | +| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Generate keywords for the text. | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | +| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Extract nickname relationship in the text. | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs based on examples. 
| [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate question and answer pairs from text. | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -93,6 +98,7 @@ All the specific operators are listed below, each featured with several capabili | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes words containing specified substrings | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | Splits and reorganizes sentences according to semantics | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | +| text_chunk_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Split input text to chunks. | [code](../data_juicer/ops/mapper/text_chunk_mapper.py) | [tests](../tests/ops/mapper/test_text_chunk_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Caption a video according to its audio streams based on Qwen-Audio model | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on an image-to-text model and sampled video frames. 
Captions from different frames will be concatenated to a single string | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 88d739d66..81aee2149 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 52 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 58 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -66,6 +66,11 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_ip_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除 IP 地址 | [code](../data_juicer/ops/mapper/clean_ip_mapper.py) | [tests](../tests/ops/mapper/test_clean_ip_mapper.py) | | clean_links_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Code](https://img.shields.io/badge/Code-590F08?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除链接,例如以 http 或 ftp 开头的 | [code](../data_juicer/ops/mapper/clean_links_mapper.py) | [tests](../tests/ops/mapper/test_clean_links_mapper.py) | | expand_macro_mapper | ![LaTeX](https://img.shields.io/badge/LaTeX-D99379?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 扩展通常在 TeX 文档顶部定义的宏 | [code](../data_juicer/ops/mapper/expand_macro_mapper.py) | [tests](../tests/ops/mapper/test_expand_macro_mapper.py) | +| extract_entity_attribute_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 给定主体和属性名,从文本中抽取主体的属性 | [code](../data_juicer/ops/mapper/extract_entity_attribute_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_attribute_mapper.py) | +| extract_entity_relation_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取知识图谱的实体和关系 | [code](../data_juicer/ops/mapper/extract_entity_relation_mapper.py) | [tests](../tests/ops/mapper/test_extract_entity_relation_mapper.py) | +| extract_event_mapper | 
![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 从文本中抽取出事件和事件相关人物 | [code](../data_juicer/ops/mapper/extract_event_mapper.py) | [tests](../tests/ops/mapper/test_extract_event_mapper.py) | +| extract_keyword_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 构造文本的关键词 | [code](../data_juicer/ops/mapper/extract_keyword_mapper.py) | [tests](../tests/ops/mapper/test_extract_keyword_mapper.py) | +| extract_nickname_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 抽取昵称称呼关系 | [code](../data_juicer/ops/mapper/extract_nickname_mapper.py) | [tests](../tests/ops/mapper/test_extract_nickname_mapper.py) | | fix_unicode_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | [code](../data_juicer/ops/mapper/fix_unicode_mapper.py) | [tests](../tests/ops/mapper/test_fix_unicode_mapper.py) | | generate_qa_from_examples_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 根据种子数据,生成新的对话样本。 | [code](../data_juicer/ops/mapper/generate_qa_from_examples_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_examples_mapper.py) | | generate_qa_from_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从文本中生成问答对 | [code](../data_juicer/ops/mapper/generate_qa_from_text_mapper.py) | [tests](../tests/ops/mapper/test_generate_qa_from_text_mapper.py) | @@ -92,6 +97,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除包含指定子字符串的单词 | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 
使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | 根据语义拆分和重组句子 | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | +| text_chunk_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 对文本进行分片处理 | [code](../data_juicer/ops/mapper/text_chunk_mapper.py) | [tests](../tests/ops/mapper/test_text_chunk_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | diff --git a/tests/ops/mapper/test_extract_entity_attribute_mapper.py b/tests/ops/mapper/test_extract_entity_attribute_mapper.py new file mode 100644 index 000000000..96f186d29 --- /dev/null +++ b/tests/ops/mapper/test_extract_entity_attribute_mapper.py @@ -0,0 +1,64 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_entity_attribute_mapper import ExtractEntityAttributeMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractEntityAttributeMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + query_entities = ["李莲花", "方多病"] + query_attributes = ["语言风格", "角色性格"] + + op = ExtractEntityAttributeMapper( + query_entities=query_entities, + query_attributes=query_attributes, + api_model=api_model, + response_path=response_path) + + raw_text = """△笛飞声独自坐在莲花楼屋顶上。李莲花边走边悠闲地给马喂草。方多病则走在一侧,却总不时带着怀疑地盯向楼顶的笛飞声。 +方多病走到李莲花身侧:我昨日分明看到阿飞神神秘秘地见了一人,我肯定他有什么瞒着我们。阿飞的来历我必须去查清楚! +李莲花继续悠然地喂草:放心吧,我认识他十几年了,对他一清二楚。 +方多病:认识十几年?你上次才说是一面之缘? +李莲花忙圆谎:见得不多,但知根知底。哎,这老马吃得也太多了。 +方多病一把夺过李莲花手中的草料:别转移话题!——快说! 
+李莲花:阿飞啊,脾气不太好,他......这十年也没出过几次门,所以见识短,你不要和他计较。还有他是个武痴,武功深藏不露,你平时别惹他。 +方多病:呵,阿飞武功高?编瞎话能不能用心点? +李莲花:可都是大实话啊。反正,我和他彼此了解得很。你就别瞎操心了。 +方多病很是质疑:(突然反应过来)等等!你说你和他认识十几年?你们彼此了解?!这么说,就我什么都不知道?! +△李莲花一愣,意外方多病是如此反应。 +方多病很是不爽:不行,你们现在投奔我,我必须对我的手下都了解清楚。现在换我来问你,你,李莲花究竟籍贯何处?今年多大?家里还有什么人?平时都有些什么喜好?还有,可曾婚配? +△此时的笛飞声正坐在屋顶,从他的位置远远地向李莲花和方多病二人看去,二人声音渐弱。 +李莲花:鄙人李莲花,有个兄弟叫李莲蓬,莲花山莲花镇莲花村人,曾经订过亲,但媳妇跟人跑子。这一辈子呢,没什么抱负理想,只想种种萝卜、逗逗狗,平时豆花爱吃甜的,粽子要肉的...... +方多病:没一句实话。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=1) + for sample in dataset: + logger.info(f'{sample[Fields.main_entity]} {sample[Fields.attribute]}: {sample[Fields.attribute_description]}') + self.assertNotEqual(sample[Fields.attribute_description], '') + self.assertNotEqual(len(sample[Fields.attribute_support_text]), 0) + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_entity_relation_mapper.py b/tests/ops/mapper/test_extract_entity_relation_mapper.py new file mode 100644 index 000000000..40e3ca32d --- /dev/null +++ b/tests/ops/mapper/test_extract_entity_relation_mapper.py @@ -0,0 +1,86 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_entity_relation_mapper import ExtractEntityRelationMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractEntityRelationMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, op): + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... 
+△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"entitis: {sample[Fields.entity]}") + logger.info(f"relations: {sample[Fields.relation]}") + + def test_default(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + op = ExtractEntityRelationMapper(api_model='qwen2.5-72b-instruct') + self._run_op(op) + + def test_entity_types(self): + op = ExtractEntityRelationMapper( + api_model='qwen2.5-72b-instruct', + entity_types=['人物', '组织', '地点', '物件', '武器', '武功'], + ) + self._run_op(op) + + def test_max_gleaning(self): + op = ExtractEntityRelationMapper( + api_model='qwen2.5-72b-instruct', + entity_types=['人物', '组织', '地点', '物件', '武器', '武功'], + max_gleaning=5, + ) + self._run_op(op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_event_mapper.py b/tests/ops/mapper/test_extract_event_mapper.py new file mode 100644 index 000000000..1652c8db2 --- /dev/null +++ b/tests/ops/mapper/test_extract_event_mapper.py @@ -0,0 +1,76 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_event_mapper import ExtractEventMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractEventMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractEventMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... 
+△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + self.assertNotEqual(len(dataset), 0) + for sample in dataset: + logger.info(f"event: {sample[Fields.event_description]}") + self.assertNotEqual(sample[Fields.event_description], '') + logger.info(f"characters: {sample[Fields.relevant_characters]}") + self.assertNotEqual(sample[Fields.relevant_characters], []) + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_keyword_mapper.py b/tests/ops/mapper/test_extract_keyword_mapper.py new file mode 100644 index 000000000..5836f902a --- /dev/null +++ b/tests/ops/mapper/test_extract_keyword_mapper.py @@ -0,0 +1,72 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_keyword_mapper import ExtractKeywordMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractKeywordMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractKeywordMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△芩婆走到中间,看着众人。 +芩婆:当年,我那老鬼漆木山与李相夷之父乃是挚交。原本李家隐世而居,一日为了救人,得罪附近山匪,夜里便遭了山匪所袭,唯有二子生还,流落街头。 +封磬震惊:二子?不是只有一个儿子吗? +芩婆:我和漆木山得知这个噩耗后,到处寻找李家那两个孩子的下落。只可惜等我们找他们时,李家长子李相显已经病死。 +李莲花似回忆起了什么:李相显...... +芩婆:我们只从乞丐堆里带回了年纪尚且未满四岁的李相夷,以及,(看向单孤刀)二个一直护着李相夷,与李相显年纪相仿的小乞丐...... +闪回/ +李相显将李且给他的玉佩塞给单孤刀,恳切托付:我没什么值钱的东西,这个玉佩是我唯一的家当了、送给你,我弟弟、相夷......求你照顾他一阵...... +△李相显还想再说什么已气绝而亡,小相夷唤着哥哥大哭,单孤刀愕然看着手里的玉佩有点不知所措。 +△话刚说完,哐当一声破庙门倒进来,几个其他少年乞丐进来。少年乞丐老大:这地儿不错,诶,你俩,出去! +△单孤刀把小相夷护在身后,抓住靠在墙边的木棍。单孤刀:这儿,是我,和我弟弟的。 +乞丐们要抢李相夷的馒头,小李相夷哭着死死护住自馒头不放。 +乞丐甲野蛮地抢:给我拿来! +小单孤刀:放开他! +△单孤刀用力撞向几个乞丐,救下小李相夷。乞丐甲:小子,活腻了! +△几个乞丐围攻小单孤刀,小单孤刀和众乞丐厮打到一起。突然其中一个乞丐掏出一把生锈的刀就朝单孤刀砍去、一个点燃火把棍戳他。单孤刀侧手一挡,火把棍在他手腕上烫出一道伤口,身后几根棍子打得他痛苦倒地! +/闪回结束 +△单孤刀拿着自己手里的玉佩看着,又看看自己手上的印记,不肯相信。单孤刀:胡说!全都是胡说!这些事我为何不知道?都是你在信口雌黄! +芩婆:那我问你,我们将你带回云隐山之前的事你又记得多少? +△单孤刀突然愣住,他意识到那之前的事自己竟都想不起来。 +芩婆:怎么?都想不起来了?(拽起单孤刀手腕,露出他的伤痕)你当日被你师父找到时,手腕上就受了伤,也正因为这处伤,高烧不退,醒来后便忘记了不少从前的事。 +△单孤刀呆住。 +芩婆:而相夷当年不过孩童,尚未到记事的年纪,很多事自然不知道。 +△李莲花得知真相,闭目叹息。 +△封磬震惊地看看单孤刀,又看看李莲花,终于想明白了一切,颓然、懊恼。 +封磬:自萱公主之子下落不明后,这近百年来我们整个家族都一直在不遗余力地寻找萱公主的子嗣后代,直到二十几年前终于让我寻得了线索,知道萱公主的曾孙被漆木山夫妇收为徒,但......我只知道萱公主之孙有一年约十岁的儿子,却不知......原来竟还有一幼子!我......我凭着南胤皇族的玉佩、孩子的年纪和他身上的印记来与主上相认,可没想到......这竟是一个错误!全错了! +△封磬神情复杂地看向李莲花,封磬:你,你才是我的主上...... 
+△封磬颓然地跪倒下来。 +△李莲花对眼前的一切有些意外、无措。 +笛飞声冷声:怪不得单孤刀的血对业火独毫无作用,李莲花的血才能毁掉这东西。 +△笛飞声不禁冷笑一下。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + sample = dataset[0] + logger.info(f"keywords: {sample[Fields.keyword]}") + + def test(self): + # before runing this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_extract_nickname_mapper.py b/tests/ops/mapper/test_extract_nickname_mapper.py new file mode 100644 index 000000000..635801155 --- /dev/null +++ b/tests/ops/mapper/test_extract_nickname_mapper.py @@ -0,0 +1,58 @@ +import unittest +import json + +from loguru import logger + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.extract_nickname_mapper import ExtractNicknameMapper +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) +from data_juicer.utils.constant import Fields + +# Skip tests for this OP in the GitHub actions due to unknown DistNetworkError. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class ExtractNicknameMapperTest(DataJuicerTestCaseBase): + + + def _run_op(self, api_model, response_path=None): + + op = ExtractNicknameMapper(api_model=api_model, + response_path=response_path) + + raw_text = """△李莲花又指出刚才门框上的痕迹。 +△李莲花:门框上也是人的掌痕和爪印。指力能嵌入硬物寸余,七分力道主上,三分力道垫下,还有辅以的爪式,看样子这还有昆仑派的外家功夫。 +方多病看着李莲花,愈发生疑os:通过痕迹就能判断出功夫和门派,这绝对只有精通武艺之人才能做到,李莲花你到底是什么人?! +笛飞声环顾四周:有朝月派,还有昆仑派,看来必是一群武林高手在这发生了决斗! +李莲花:如果是武林高手过招,为何又会出现如此多野兽的痕迹。方小宝,你可听过江湖上有什么门派是驯兽来斗?方小宝?方小宝? +方多病回过神:不、不曾听过。 +李莲花:还有这些人都去了哪里? +笛飞声:打架不管是输是赢,自然是打完就走。 +李莲花摇头:就算打完便走,但这里是客栈,为何这么多年一直荒在这里,甚至没人来收拾一下? +笛飞声:闹鬼?这里死过这么多人,楼下又画了那么多符,所以不敢进来? +△这时,梁上又出现有东西移动的声响,李莲花、笛飞声都猛然回头看去。 +""" + samples = [{ + 'text': raw_text, + }] + + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + result = dataset[0][Fields.nickname] + result = [( + d[Fields.source_entity], + d[Fields.target_entity], + d[Fields.relation_description]) + for d in result] + logger.info(f'result: {result}') + self.assertIn(("李莲花","方多病","方小宝"), result) + + def test(self): + # before runing this test, set below environment variables: + # export DJ_API_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions + # export DJ_API_KEY=your_key + self._run_op('qwen2.5-72b-instruct') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_text_chunk_mapper.py b/tests/ops/mapper/test_text_chunk_mapper.py new file mode 100644 index 000000000..8004d9ede --- /dev/null +++ b/tests/ops/mapper/test_text_chunk_mapper.py @@ -0,0 +1,364 @@ +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class TextChunkMapperTest(DataJuicerTestCaseBase): + + def _run_helper(self, op, samples, target): + dataset = Dataset.from_list(samples) + dataset = dataset.map(op.process, batch_size=2) + + for d, t in zip(dataset, target): + self.assertEqual(d['text'], t['text']) + + def test_naive_text_chunk(self): + + source = [ + { + 'text': "Today is Sunday and it's a happy day!" 
+ }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + }, + { + 'text': + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper(split_pattern='\n') + self._run_helper(op, source, target) + + def test_max_len_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and " + }, + { + 'text': "it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT" + }, + { + 'text': + '4, plusieurs manière' + }, + { + 'text': + "s d'accéder à ces fo" + }, + { + 'text': + 'nctionnalités sont c' + }, + { + 'text': + 'onçues simultanément' + }, + { + 'text': + '.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper(max_len=20, split_pattern=None) + self._run_helper(op, source, target) + + def test_max_len_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and " + }, + { + 'text': "d it's a happy day!" + }, + { + 'text': "Sur la plateforme MT" + }, + { + 'text': 'MT4, plusieurs maniè' + }, + { + 'text': "ières d'accéder à ce" + }, + { + 'text': 'ces fonctionnalités ' + }, + { + 'text': 's sont conçues simul' + }, + { + 'text': 'ultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper(max_len=20, overlap_len=2) + self._run_helper(op, source, target) + + def test_max_len_and_split_pattern_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and " + }, + { + 'text': "d it's a happy day!" + }, + { + 'text': "Sur la plateforme MT" + }, + { + 'text': 'MT4, plusieurs maniè' + }, + { + 'text': "ières d'accéder à " + }, + { + 'text': 'ces fonctionnalités ' + }, + { + 'text': 's sont conçues simul' + }, + { + 'text': 'ultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=20, + overlap_len=2, + split_pattern='\n' + ) + self._run_helper(op, source, target) + + def test_tokenizer_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières" + }, + { + 'text': "ières d'accéder à ces fonctionnalités" + }, + { + 'text': "ités sont conçues simultanément." + }, + { + 'text': '欢迎来到阿里巴巴!' 
+ }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='Qwen/Qwen-7B-Chat', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + def test_tiktoken_tokenizer_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières d" + }, + { + 'text': " d'accéder à ces fonctionnalités sont conçues simult" + }, + { + 'text': " simultanément." + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='gpt-4o', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + def test_dashscope_tokenizer_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières" + }, + { + 'text': "ières d'accéder à ces fonctionnalités" + }, + { + 'text': "ités sont conçues simultanément." + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern=None, + tokenizer='qwen2.5-72b-instruct', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + def test_all_text_chunk(self): + source = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à \n" + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + target = [ + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': "Sur la plateforme MT4, plusieurs manières" + }, + { + 'text': "ières d'accéder à " + }, + { + 'text': "ces fonctionnalités sont conçues simultan" + }, + { + 'text': "anément." + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + ] + op = TextChunkMapper( + max_len=10, + overlap_len=1, + split_pattern='\n', + tokenizer='Qwen/Qwen-7B-Chat', + trust_remote_code=True + ) + self._run_helper(op, source, target) + + +if __name__ == '__main__': + unittest.main()
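
The new mappers are meant to compose: long documents can be split with `text_chunk_mapper` first, and the LLM-based extraction ops then run on each chunk. Below is a minimal sketch of that flow in plain Python; it is a hypothetical standalone usage (in practice these ops are listed in a recipe's `process` section), the sample text is made up, and it assumes the OpenAI-compatible environment variables are set as in the tests above.

```python
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.text_chunk_mapper import TextChunkMapper
from data_juicer.ops.mapper.extract_nickname_mapper import ExtractNicknameMapper

# Split long texts into overlapping chunks by character length,
# then extract nickname relations from every chunk.
chunk_op = TextChunkMapper(max_len=2000, overlap_len=100, split_pattern='\n\n')
nickname_op = ExtractNicknameMapper(api_model='qwen2.5-72b-instruct')

# '...some long script text...' is a placeholder for real input data.
dataset = Dataset.from_list([{'text': '...some long script text...'}])
dataset = dataset.map(chunk_op.process, batch_size=2)   # batched op
dataset = dataset.map(nickname_op.process, batch_size=1)

for sample in dataset:
    print(sample[nickname_op.nickname_key])
```

Chunking first keeps each prompt within the model's context window; the extraction ops themselves process one sample at a time.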