Commit

[Ready] align sft formats & new ops (#454)
* align sft formats

* fix test

* minor fix

* improve tests assert

* pre-commit

* sort

* add associated ops

* add tests

* fix install mapping

* fix subclasses

* fix import

* fix format

* unify methods naming

* unify tests

* update model

* update docs

* refine model loading

* fix empty history schema

* fix device

* ensure `with_rank` is set properly

* fix diffusion model_params

* minor fix

* TODO: new OP tests to be checked
drcege authored Nov 5, 2024
1 parent d185d54 commit 65d7c91
Showing 29 changed files with 1,323 additions and 1,137 deletions.
58 changes: 29 additions & 29 deletions configs/config_all.yaml
@@ -61,30 +61,26 @@ process:
- clean_links_mapper: # remove web links from text.
- clean_copyright_mapper: # remove copyright comments.
- expand_macro_mapper: # expand macro definitions in Latex text.
- extract_qa_mapper: # mapper to extract question and answer pair from text.
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # model name on huggingface to extract question and answer pair.
pattern: null # regular expression pattern to search for within text.
qa_format: 'chatml' # Output format of question and answer pair.
enable_vllm: true # Whether to use vllm for inference acceleration.
tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism.
max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config.
max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration.
sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- fix_unicode_mapper: # fix unicode errors in text.
- generate_instruction_mapper: # generate new instruction text data.
hf_model: 'Qwen/Qwen-7B-Chat' # model name on huggingface to generate instruction.
seed_file: 'demos/data/demo-dataset-chatml.jsonl' # Seed file as instruction samples to generate new instructions, chatml format.
instruct_num: 3 # the number of generated samples.
similarity_threshold: 0.7 # the similarity score threshold between the generated samples and the seed samples.Range from 0 to 1. Samples with similarity score less than this threshold will be kept.
prompt_template: null # Prompt template for generate samples. Please make sure the template contains "{augmented_data}", which corresponds to the augmented samples.
qa_pair_template: null # Prompt template for generate question and answer pair description. Please make sure the template contains two "{}" to format question and answer. Default: '【问题】\n{}\n【回答】\n{}\n'.
example_template: null # Prompt template for generate examples. Please make sure the template contains "{qa_pairs}", which corresponds to the question and answer pair description generated by param `qa_pair_template`.
qa_extraction_pattern: null # Regular expression pattern for parsing question and answer from model response.
enable_vllm: true # Whether to use vllm for inference acceleration.
tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism.
max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config.
max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration.
- generate_qa_from_examples_mapper: # mapper to generate question and answer pairs from examples.
hf_model: 'Qwen/Qwen2.5-7B-Instruct' # Model name on huggingface to generate question and answer pairs.
seed_file: 'demos/data/demo-dataset-chatml.jsonl' # Path to the seed file in chatml format.
example_num: 3 # The number of randomly selected seed examples.
similarity_threshold: 0.7 # the similarity score threshold between the generated samples and the seed examples. Range from 0 to 1. Samples with similarity score less than this threshold will be kept.
system_prompt: null # System prompt for guiding the generation task.
input_template: null # Template for building the input prompt.
example_template: null # Template for formatting each QA example.
qa_pair_template: null # Template for formatting a single QA pair within each example.
output_pattern: null # Regular expression pattern to extract questions and answers from model response.
enable_vllm: false # Whether to use vllm for inference acceleration.
model_params: null # Parameters for initializing the model.
sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- generate_qa_from_text_mapper: # mapper to generate question and answer pairs from text.
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa' # Model name on huggingface to generate question and answer pairs.
output_pattern: null # Regular expression pattern to extract questions and answers from model response.
enable_vllm: false # Whether to use vllm for inference acceleration.
model_params: null # Parameters for initializing the model.
sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- image_blur_mapper: # mapper to blur images.
p: 0.2 # probability of the image being blured
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
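
The `similarity_threshold` comment for `generate_qa_from_examples_mapper` says that generated samples scoring below the threshold against the seed examples are kept. The following is a minimal sketch of that rule only. It assumes a stand-in similarity function based on `difflib.SequenceMatcher` and assumes a candidate must score below the threshold against every seed; the metric and the aggregation the operator actually uses are not shown in this diff, and `similarity` and `keep_generated` are hypothetical names.

```python
from difflib import SequenceMatcher
from typing import Sequence


def similarity(a: str, b: str) -> float:
    """Stand-in similarity score in [0, 1]; NOT the metric used by the op."""
    return SequenceMatcher(None, a, b).ratio()


def keep_generated(candidate: str,
                   seeds: Sequence[str],
                   threshold: float = 0.7) -> bool:
    """Keep a candidate only if it scores below `threshold` against every seed.

    Assumption: the comparison is made against each seed separately.
    """
    return all(similarity(candidate, seed) < threshold for seed in seeds)


seeds = ['What is the capital of France?']
# An exact copy of a seed has similarity 1.0, so it is rejected.
print(keep_generated('What is the capital of France?', seeds))
```

A lower threshold makes the check stricter, since fewer near-duplicates of the seed examples pass it.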
@@ -146,13 +142,17 @@ process:
delete_random_char: false # whether to open the augmentation method of deleting random characters from the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据增强"
swap_random_char: false # whether to open the augmentation method of swapping random contiguous characters in the original texts. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有5种不同的数据强增方法"
replace_equivalent_num: false # whether to open the augmentation method of replacing random numbers with their equivalent representations in the original texts. **Notice**: Only for numbers for now. e.g. "这里一共有5种不同的数据增强方法" --> "这里一共有伍种不同的数据增强方法"
- optimize_instruction_mapper: # optimize instruction.
hf_model: 'alibaba-pai/Qwen2-7B-Instruct-Refine' # model name on huggingface to optimize instruction
enable_vllm: true # whether to use vllm for inference acceleration.
tensor_parallel_size: null # It is only valid when enable_vllm is True. The number of GPUs to use for distributed execution with tensor parallelism.
max_model_len: null # It is only valid when enable_vllm is True. Model context length. If unspecified, will be automatically derived from the model config.
max_num_seqs: 256 # It is only valid when enable_vllm is True. Maximum number of sequences to be processed in a single iteration.
sampling_params: {} # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- optimize_qa_mapper: # optimize question-answer pairs.
hf_model: 'Qwen/Qwen2.5-7B-Instruct' # model name on huggingface.
system_prompt: null # System prompt for guiding the optimization task.
input_template: null # Template for building the input for the model.
qa_pair_template: null # Template for formatting the question and answer pair.
output_pattern: null # Regular expression pattern to extract question and answer from model response.
enable_vllm: false # whether to use vllm for inference acceleration.
model_params: null # Parameters for initializing the model.
sampling_params: null # Sampling parameters for text generation. e.g {'temperature': 0.9, 'top_p': 0.95}
- optimize_query_mapper: # optimize query in question-answer pairs.
- optimize_response_mapper: # optimize response in question-answer pairs.
- punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations.
- remove_bibliography_mapper: # remove bibliography from Latex text.
- remove_comments_mapper: # remove comments from Latex text, code, etc.
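
The `qa_pair_template` and `output_pattern` options work as a pair: the template formats a question-answer pair into text, and the pattern recovers pairs from a model response. The sketch below illustrates that round trip. The template string is the default quoted in the removed `generate_instruction_mapper` comment; the regular expression and the `parse_qa` helper are assumptions made for illustration, not the operators' actual defaults.

```python
import re

# Template quoted in the removed config comment; it takes a question and an answer.
QA_PAIR_TEMPLATE = '【问题】\n{}\n【回答】\n{}\n'

# Hypothetical pattern: capture the text after each marker, stopping at the
# next question marker or at the end of the response.
OUTPUT_PATTERN = r'【问题】\s*(.*?)\s*【回答】\s*(.*?)(?=【问题】|\Z)'


def parse_qa(response: str):
    """Return a list of (question, answer) tuples found in `response`."""
    matches = re.findall(OUTPUT_PATTERN, response, re.DOTALL)
    return [(q.strip(), a.strip()) for q, a in matches]


response = (QA_PAIR_TEMPLATE.format('Q1', 'A1') +
            QA_PAIR_TEMPLATE.format('Q2', 'A2'))
print(parse_qa(response))  # [('Q1', 'A1'), ('Q2', 'A2')]
```

If a custom template is configured, the pattern generally has to change with it, otherwise no pairs will be extracted.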
21 changes: 15 additions & 6 deletions data_juicer/core/data.py
@@ -246,16 +246,23 @@ def map(self, *args, **kargs):

if inspect.ismethod(called_func):
# batched is required for fault-tolerant or batched OP
- if not called_func.__self__.turbo or hasattr(
+ if callable(getattr(
  called_func.__self__,
- 'is_batched_op') and called_func.__self__.is_batched_op():
+ 'is_batched_op')) and called_func.__self__.is_batched_op(
+ ) or not called_func.__self__.turbo:
  kargs['batched'] = True
  kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr(
  called_func.__self__, 'is_batched_op'
  ) and called_func.__self__.is_batched_op() else 1
  else:
  kargs['batched'] = False

+ # rank is required for cuda model loading
+ if callable(
+ getattr(called_func.__self__,
+ 'use_cuda')) and called_func.__self__.use_cuda():
+ kargs['with_rank'] = True
+
if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
kargs['new_fingerprint'] = new_fingerprint
@@ -300,10 +307,12 @@ def filter(self, *args, **kargs):
called_func = called_func.__wrapped__

# Batched is always required for fault tolerance
- if inspect.ismethod(
- called_func) and called_func.__self__.is_batched_op():
- kargs['batched'] = True
- kargs['batch_size'] = kargs.pop('batch_size', 1)
+ if inspect.ismethod(called_func):
+ if callable(getattr(
+ called_func.__self__,
+ 'is_batched_op')) and called_func.__self__.is_batched_op():
+ kargs['batched'] = True
+ kargs['batch_size'] = kargs.pop('batch_size', 1)

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
9 changes: 9 additions & 0 deletions data_juicer/ops/base_op.py
@@ -2,6 +2,7 @@
import traceback
from functools import wraps

import numpy as np
import pyarrow as pa
from loguru import logger

@@ -133,6 +134,11 @@ def __init__(self, *args, **kwargs):
self.image_key = kwargs.get('image_key', 'images')
self.audio_key = kwargs.get('audio_key', 'audios')
self.video_key = kwargs.get('video_key', 'videos')

self.query_key = kwargs.get('query_key', 'query')
self.response_key = kwargs.get('response_key', 'response')
self.history_key = kwargs.get('history_key', 'history')

self.batch_size = kwargs.get('batch_size', 1000)

# whether the model can be accelerated using cuda
@@ -210,6 +216,9 @@ def run(self, dataset):
dataset = NestedDataset(dataset)
return dataset

def empty_history(self):
return np.empty((0, 0), dtype=str)


class Mapper(OP):
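
The `base_op.py` hunks above add three configurable keys (`query_key`, `response_key`, `history_key`) and an `empty_history()` helper. The sketch below shows what that helper returns and how the key defaults resolve. `build_sample` is a hypothetical function, and treating an absent history as the empty array is an assumption; the diff does not show how operators consume these keys.

```python
import numpy as np


def empty_history():
    """Same expression as the new OP.empty_history() method in the diff."""
    return np.empty((0, 0), dtype=str)


def build_sample(query, response, history=None, **kwargs):
    """Hypothetical helper: assemble one sample using the OP key defaults."""
    query_key = kwargs.get('query_key', 'query')
    response_key = kwargs.get('response_key', 'response')
    history_key = kwargs.get('history_key', 'history')
    return {
        query_key: query,
        response_key: response,
        # Assumption: a missing history is represented by the empty array.
        history_key: empty_history() if history is None else history,
    }


h = empty_history()
print(h.shape, h.ndim, h.dtype.kind)  # (0, 0) 2 U
print(sorted(build_sample('What is 2 + 2?', '4')))
```

The array is two-dimensional with a string dtype even though it holds no elements, which is consistent with the commit note "fix empty history schema": the empty value still carries a nesting depth and an element type.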

11 changes: 3 additions & 8 deletions data_juicer/ops/common/__init__.py
@@ -5,12 +5,7 @@
from .special_characters import SPECIAL_CHARACTERS

__all__ = [
'get_sentences_from_document',
'get_words_from_document',
'merge_on_whitespace_tab_newline',
'split_on_newline_tab_whitespace',
'split_on_whitespace',
'strip',
'words_augmentation',
'words_refinement',
'get_sentences_from_document', 'get_words_from_document',
'merge_on_whitespace_tab_newline', 'split_on_newline_tab_whitespace',
'split_on_whitespace', 'strip', 'words_augmentation', 'words_refinement'
]
7 changes: 4 additions & 3 deletions data_juicer/ops/deduplicator/__init__.py
@@ -9,7 +9,8 @@
from .video_deduplicator import VideoDeduplicator

__all__ = [
'VideoDeduplicator', 'RayBasicDeduplicator', 'DocumentMinhashDeduplicator',
'RayImageDeduplicator', 'RayDocumentDeduplicator', 'DocumentDeduplicator',
'ImageDeduplicator', 'DocumentSimhashDeduplicator', 'RayVideoDeduplicator'
'DocumentDeduplicator', 'DocumentMinhashDeduplicator',
'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator',
'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator',
'VideoDeduplicator'
]
33 changes: 16 additions & 17 deletions data_juicer/ops/filter/__init__.py
@@ -44,21 +44,20 @@
from .words_num_filter import WordsNumFilter

__all__ = [
'ImageTextSimilarityFilter', 'VideoAspectRatioFilter',
'ImageTextMatchingFilter', 'ImageNSFWFilter', 'TokenNumFilter',
'TextLengthFilter', 'SpecifiedNumericFieldFilter', 'AudioNMFSNRFilter',
'VideoAestheticsFilter', 'PerplexityFilter', 'PhraseGroundingRecallFilter',
'MaximumLineLengthFilter', 'AverageLineLengthFilter',
'SpecifiedFieldFilter', 'VideoTaggingFromFramesFilter',
'TextEntityDependencyFilter', 'VideoResolutionFilter',
'AlphanumericFilter', 'ImageWatermarkFilter', 'ImageAestheticsFilter',
'AudioSizeFilter', 'StopWordsFilter', 'CharacterRepetitionFilter',
'ImageShapeFilter', 'VideoDurationFilter', 'TextActionFilter',
'VideoOcrAreaRatioFilter', 'VideoNSFWFilter', 'SpecialCharactersFilter',
'VideoFramesTextSimilarityFilter', 'ImageAspectRatioFilter',
'AudioDurationFilter', 'LanguageIDScoreFilter', 'SuffixFilter',
'ImageSizeFilter', 'VideoWatermarkFilter', 'WordsNumFilter',
'ImageFaceCountFilter', 'ImageFaceRatioFilter', 'FlaggedWordFilter',
'WordRepetitionFilter', 'VideoMotionScoreFilter',
'ImagePairSimilarityFilter'
'AlphanumericFilter', 'AudioDurationFilter', 'AudioNMFSNRFilter',
'AudioSizeFilter', 'AverageLineLengthFilter', 'CharacterRepetitionFilter',
'FlaggedWordFilter', 'ImageAestheticsFilter', 'ImageAspectRatioFilter',
'ImageFaceCountFilter', 'ImageFaceRatioFilter', 'ImageNSFWFilter',
'ImagePairSimilarityFilter', 'ImageShapeFilter', 'ImageSizeFilter',
'ImageTextMatchingFilter', 'ImageTextSimilarityFilter',
'ImageWatermarkFilter', 'LanguageIDScoreFilter', 'MaximumLineLengthFilter',
'PerplexityFilter', 'PhraseGroundingRecallFilter',
'SpecialCharactersFilter', 'SpecifiedFieldFilter',
'SpecifiedNumericFieldFilter', 'StopWordsFilter', 'SuffixFilter',
'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter',
'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter',
'VideoDurationFilter', 'VideoFramesTextSimilarityFilter',
'VideoMotionScoreFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter',
'VideoResolutionFilter', 'VideoTaggingFromFramesFilter',
'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter'
]