From 8ade9b55bbfe535453e2b88385c18e97f9fb1a3f Mon Sep 17 00:00:00 2001 From: Yilun Huang Date: Mon, 25 Nov 2024 10:28:47 +0800 Subject: [PATCH] Probe-based OP Fusion & Reordering (#464) * * init adapter * + add basic logic of workloads adaptation * * update the adaptation logic * + add unittests for Monitor * + add unittests for Monitor * * use multiprocessing to monitor resource utilization * + add unittests for adapter * * modification and fix for gece's comments * * bug fixed: there is no attr _name in FusedFilter * * support OP fusion based on probed speed of each OP * * support OP fusion based on probed speed of each OP * * fix bugs in fused OP speed calculation + add unit tests for probe-based OP fusion * * bug fixed * * bug fixed: enable batched when is_batched_op is True * * bug fixed: enable batched when is_batched_op is True * * expand the test dataset batch according to the num_proc * * support OP-wise adaptive batch size setting * * support batched processing for 4 image OPs * * set mp method for each OP during probing as well * + add visualization graphs for monitor results * * share the same context space for a batch * * share the same context space for a batch * * share the same context space for a batch but with idx info * * support batched process for fused filter * * bug fixed: add idx info to inter vars of word repetition filter, reduce the logic value for each op in fused filter * * extract batched to an outer func * modified for text_length_filter * * use a branch to decide which funcs are used * * update context for each sample as well * * modify for mapper and whitespace_normalization_mapper * * modify for two filters with context * * allow optional args for batched funcs * * restore to batched version and rename to xxx_batched * * restore to batched version and rename to xxx_batched * * restore to batched version and rename to xxx_batched * * update docs for this modification * * DO NOT allow to override the compute_stats or process methods 
in the subclass of Mapper and Filter * * rename the methods for the newly-added OP image_face_count_filter * * rename FusedFilter with "_batched" suffix * - bug fixed: update context organization for FusedFilter * * merge main into this branch * * support probe-based op fusion for ray mode * * restore to separate contexts of different samples * * restore to separate contexts of different samples * * avoid storing contexts from different samples into the same dict object * * set default fusion strategy to 'probe' - remove useless arguments * * set default fusion strategy to 'probe' - remove useless arguments * - skip tests with randomness --- configs/config_all.yaml | 2 + data_juicer/config/config.py | 22 + data_juicer/core/adapter.py | 51 +- data_juicer/core/analyzer.py | 15 +- data_juicer/core/data.py | 11 +- data_juicer/core/executor.py | 30 +- data_juicer/core/monitor.py | 24 + data_juicer/core/ray_executor.py | 17 +- .../ops/filter/image_aspect_ratio_filter.py | 2 + data_juicer/ops/filter/image_shape_filter.py | 2 + data_juicer/ops/filter/image_size_filter.py | 2 + .../filter/image_text_similarity_filter.py | 1 + data_juicer/ops/load.py | 10 +- data_juicer/ops/op_fusion.py | 130 ++- tests/core/test_adapter.py | 20 +- tests/ops/test_op_fusion.py | 950 +++++++++++++++++- 16 files changed, 1211 insertions(+), 78 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 90fc18875..eeb1ba1b2 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -21,9 +21,11 @@ open_tracer: false # whether to open th op_list_to_trace: [] # only ops in this list will be traced by tracer. If it's empty, all ops will be traced. Only available when tracer is opened. trace_num: 10 # number of samples to show the differences between datasets before and after each op. Only available when tracer is opened. op_fusion: false # whether to fuse operators that share the same intermediate variables automatically. 
Op fusion might reduce the memory requirements slightly but speed up the whole process. +fusion_strategy: 'probe' # OP fusion strategy. Support ['greedy', 'probe'] now. 'greedy' means keep the basic OP order and put the fused OP to the last of each fused OP group. 'probe' means Data-Juicer will probe the running speed for each OP at the beginning and reorder the OPs and fused OPs according to their probed speed (fast to slow). It's 'probe' in default. cache_compress: null # the compression method of the cache file, which can be specified in ['gzip', 'zstd', 'lz4']. If this parameter is None, the cache file will not be compressed. We recommend you turn on this argument when your input dataset is larger than tens of GB and your disk space is not enough. keep_stats_in_res_ds: false # whether to keep the computed stats in the result dataset. The intermediate fields to store the stats computed by Filters will be removed if it's False. It's False in default. keep_hashes_in_res_ds: false # whether to keep the computed hashes in the result dataset. The intermediate fields to store the hashes computed by Deduplicators will be removed if it's False. It's False in default. +adaptive_batch_size: false # whether to use adaptive batch sizes for each OP according to the probed results. It's False in default. # for multimodal data processing image_key: 'images' # key name of field to store the list of sample image paths. 
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 0b0487dc3..76a20b786 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -15,6 +15,7 @@ from loguru import logger from data_juicer.ops.base_op import OPERATORS +from data_juicer.ops.op_fusion import FUSION_STRATEGIES from data_juicer.utils.logger_utils import setup_logger from data_juicer.utils.mm_utils import SpecialTokens @@ -275,6 +276,22 @@ def init_configs(args: Optional[List[str]] = None): help='Whether to fuse operators that share the same intermediate ' 'variables automatically. Op fusion might reduce the memory ' 'requirements slightly but speed up the whole process.') + parser.add_argument( + '--fusion_strategy', + type=str, + default='probe', + help='OP fusion strategy. Support ["greedy", "probe"] now. "greedy" ' + 'means keep the basic OP order and put the fused OP to the last ' + 'of each fused OP group. "probe" means Data-Juicer will probe ' + 'the running speed for each OP at the beginning and reorder the ' + 'OPs and fused OPs according to their probed speed (fast to ' + 'slow). It\'s "probe" in default.') + parser.add_argument( + '--adaptive_batch_size', + type=bool, + default=False, + help='Whether to use adaptive batch sizes for each OP according to ' + 'the probed results. It\'s False in default.') parser.add_argument( '--process', type=List[Dict], @@ -436,6 +453,11 @@ def init_setup_from_cfg(cfg: Namespace): # The checkpoint mode is not compatible with op fusion for now. if cfg.op_fusion: cfg.use_checkpoint = False + cfg.fusion_strategy = cfg.fusion_strategy.lower() + if cfg.fusion_strategy not in FUSION_STRATEGIES: + raise NotImplementedError( + f'Unsupported OP fusion strategy [{cfg.fusion_strategy}]. 
' + f'Should be one of {FUSION_STRATEGIES}.') # update huggingface datasets cache directory only when ds_cache_dir is set from datasets import config diff --git a/data_juicer/core/adapter.py b/data_juicer/core/adapter.py index aa746a058..5ab6e6ec8 100644 --- a/data_juicer/core/adapter.py +++ b/data_juicer/core/adapter.py @@ -1,6 +1,9 @@ +from datasets import concatenate_datasets from datasets.config import DEFAULT_MAX_BATCH_SIZE from data_juicer.core.monitor import Monitor +from data_juicer.ops import UNFORKABLE +from data_juicer.utils.process_utils import setup_mp class Adapter: @@ -27,28 +30,43 @@ def execute_and_probe(dataset, operators, sample_interval=0.5): if operators is None or len(operators) == 0: return [] + # number of test samples + sample_num = len(dataset) + # resource utilization list resource_util_list = [] # probe for each OP + unforkable_operators = set(UNFORKABLE.modules.keys()) for op in operators: - # set num_proc to 1 for each OP to focus on the influence of batch - # size only. 
- old_num_proc = op.num_proc - op.num_proc = 1 + # select suitable mp method for each OP + mp_context = ['forkserver', 'spawn'] if ( + op.use_cuda() or op._name in unforkable_operators) else None + setup_mp(mp_context) + # expand the test dataset according to the runtime number of + # processes to ensure enough data for a batch and probe the true + # resource utilization for each OP + expanded_dataset = concatenate_datasets([dataset] * + op.runtime_np()) + + # set the test batch size and save the old one + if op.is_batched_op(): + old_batch_size = op.batch_size + op.batch_size = sample_num - # number of test samples - sample_num = len(dataset) # run single op and monitor the resource utilization - dataset, resource_util_per_op = Monitor.monitor_func( - op.run, args=(dataset, ), sample_interval=sample_interval) + _, resource_util_per_op = Monitor.monitor_func( + op.run, + args=(expanded_dataset, ), + sample_interval=sample_interval) # calculate speed resource_util_per_op[ 'speed'] = sample_num / resource_util_per_op['time'] resource_util_list.append(resource_util_per_op) - # restore to the original num_proc - op.num_proc = old_num_proc + # restore the batch size + if op.is_batched_op(): + op.batch_size = old_batch_size return resource_util_list @@ -96,11 +114,20 @@ def probe_small_batch(self, dataset, operators): current load and estimated OP speed, returning load factors and speed ranks for each OP. + Notice: the probe should be run with cache enabled. + :param dataset: The dataset to pre-execute small batch on :param operators: The OP list to be pre-execution and probe :return: A list of probe results for each OP and the length of data batch to probe. 
""" + # record the cache state and enable the cache + from datasets import (disable_caching, enable_caching, + is_caching_enabled) + previous_state = is_caching_enabled() + if not previous_state: + enable_caching() + # take a small batch data_batch = self.take_batch(dataset, self.cfg) # process and monitor the resource utilization @@ -108,6 +135,10 @@ def probe_small_batch(self, dataset, operators): # analyze resource utilization analysis_res = Monitor.analyze_resource_util_list(resource_util_list) + # if the cache is disabled before, disable it again + if not previous_state: + disable_caching() + return analysis_res, len(data_batch) def batch_size_strategy(self, load_analysis_res, base_bs=1, util_th=0.9): diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py index e9a6ef8d2..2ae4d3511 100644 --- a/data_juicer/core/analyzer.py +++ b/data_juicer/core/analyzer.py @@ -9,8 +9,10 @@ from data_juicer.config import init_configs from data_juicer.format import load_formatter from data_juicer.ops import Filter, load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils +from .adapter import Adapter from .exporter import Exporter @@ -88,7 +90,18 @@ def run(self, # extract processes logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + adapter = Adapter(self.cfg) + probe_res, _ = adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) # 2. 
stats precompute only for filter ops logger.info('Computing the stats of dataset...') diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 7e51bd1f8..9cef1fe89 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -216,8 +216,13 @@ def process(self, dataset.cleanup_cache_files() checkpointer.save_ckpt(dataset) if work_dir: - with open(os.path.join(work_dir, 'monitor.json'), 'w') as out: + monitor_dir = os.path.join(work_dir, 'monitor') + os.makedirs(monitor_dir, exist_ok=True) + with open(os.path.join(monitor_dir, 'monitor.json'), + 'w') as out: json.dump(resource_util_list, out) + Monitor.draw_resource_util_graph(resource_util_list, + monitor_dir) return dataset def map(self, *args, **kargs): @@ -251,9 +256,7 @@ def map(self, *args, **kargs): 'is_batched_op')) and called_func.__self__.is_batched_op( ) or not getattr(called_func.__self__, 'turbo', False): kargs['batched'] = True - kargs['batch_size'] = kargs.pop('batch_size', 1) if hasattr( - called_func.__self__, 'is_batched_op' - ) and called_func.__self__.is_batched_op() else 1 + kargs['batch_size'] = kargs.pop('batch_size', 1) else: kargs['batched'] = False diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 472a5e858..d9445dad0 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -11,6 +11,7 @@ from data_juicer.format.load import load_formatter from data_juicer.format.mixture_formatter import MixtureFormatter from data_juicer.ops import OPERATORS, load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils import cache_utils from data_juicer.utils.ckpt_utils import CheckpointManager @@ -18,6 +19,7 @@ FrequencySpecifiedFieldSelector from ..ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector +from .adapter import Adapter from .exporter import Exporter from .tracer import Tracer @@ -43,6 +45,8 @@ def __init__(self, cfg: Optional[Namespace] = None): self.tracer = None 
self.ckpt_manager = None + self.adapter = Adapter(self.cfg) + # only enable it when using cache if self.cfg.use_cache: logger.info(f'Using cache compression method: ' @@ -158,9 +162,31 @@ def run(self, load_data_np = self.cfg.np dataset = self.formatter.load_dataset(load_data_np, self.cfg) - # 2. extract processes + # 2. extract processes and optimize their orders logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + # OP fusion + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + probe_res, _ = self.adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) + + # adaptive batch size + if self.cfg.adaptive_batch_size: + # calculate the adaptive batch size + bs_per_op = self.adapter.adapt_workloads(dataset, ops) + assert len(bs_per_op) == len(ops) + # update the adaptive batch size + logger.info(f'Adapt batch sizes for each OP to {bs_per_op}') + for i, op in enumerate(ops): + if op.is_batched_op(): + op.batch_size = bs_per_op[i] # 3. 
data process # - If tracer is open, trace each op after it's processed diff --git a/data_juicer/core/monitor.py b/data_juicer/core/monitor.py index 7d2f7984c..67f8f62a5 100644 --- a/data_juicer/core/monitor.py +++ b/data_juicer/core/monitor.py @@ -1,3 +1,4 @@ +import os import time from functools import partial from multiprocessing import get_context @@ -28,6 +29,7 @@ class Monitor: '''python { 'time': 10, + 'sampling interval': 0.5, 'resource': [ { 'timestamp': xxx, @@ -50,6 +52,7 @@ class Monitor: '''python { 'time': 10, + 'sampling interval': 0.5, 'resource': [...], 'resource_analysis': { 'GPU free mem.': { @@ -118,6 +121,24 @@ def monitor_current_resources(): return resource_dict + @staticmethod + def draw_resource_util_graph(resource_util_list, store_dir): + import matplotlib.pyplot as plt + for idx, resource_util_dict in enumerate(resource_util_list): + resource_list = resource_util_dict['resource'] + interval = resource_util_dict['sampling interval'] + for focus_metric in Monitor.DYNAMIC_FIELDS: + fn = f'func_{idx}_{focus_metric.replace(" ", "_")}.jpg' + ylbl = '%' if focus_metric.endswith('util.') else 'MB' + metric_list = [item[focus_metric] for item in resource_list] + plt.plot([i * interval for i in range(len(metric_list))], + metric_list) + plt.title(focus_metric) + plt.xlabel('Time (s)') + plt.ylabel(ylbl) + plt.savefig(os.path.join(store_dir, fn), bbox_inches='tight') + plt.clf() + @staticmethod def analyze_resource_util_list(resource_util_list): """ @@ -209,6 +230,9 @@ def monitor_func(func, args=None, sample_interval=0.5): resource_util_dict['resource'] = mdict['resource'] + # record interval + resource_util_dict['sampling interval'] = sample_interval + # calculate speed resource_util_dict['time'] = end - start diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 6b93fd3dd..1d90e31b3 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -5,8 +5,11 @@ from data_juicer.config import 
init_configs from data_juicer.core.ray_data import RayDataset from data_juicer.ops import load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.lazy_loader import LazyLoader +from .adapter import Adapter + ray = LazyLoader('ray', 'ray') rd = LazyLoader('rd', 'ray.data') @@ -33,6 +36,8 @@ def __init__(self, cfg=None): self.work_dir = self.cfg.work_dir + self.adapter = Adapter(self.cfg) + # init ray logger.info('Initing Ray ...') ray.init(self.cfg.ray_address) @@ -62,7 +67,17 @@ def run(self, load_data_np=None): dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) # 2. extract processes logger.info('Preparing process operators...') - ops = load_ops(self.cfg.process, self.cfg.op_fusion) + ops = load_ops(self.cfg.process) + + if self.cfg.op_fusion: + probe_res = None + if self.cfg.fusion_strategy == 'probe': + logger.info('Probe the OP speed for OP reordering...') + probe_res, _ = self.adapter.probe_small_batch(dataset, ops) + + logger.info(f'Start OP fusion and reordering with strategy ' + f'[{self.cfg.fusion_strategy}]...') + ops = fuse_operators(ops, probe_res) # 3. data process logger.info('Processing data...') diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py index e069a1943..6e5cb8516 100644 --- a/data_juicer/ops/filter/image_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py @@ -14,6 +14,8 @@ class ImageAspectRatioFilter(Filter): AspectRatio = W / H. """ + _batched_op = True + def __init__(self, min_ratio: float = 0.333, max_ratio: float = 3.0, diff --git a/data_juicer/ops/filter/image_shape_filter.py b/data_juicer/ops/filter/image_shape_filter.py index 064929111..b265add30 100644 --- a/data_juicer/ops/filter/image_shape_filter.py +++ b/data_juicer/ops/filter/image_shape_filter.py @@ -15,6 +15,8 @@ class ImageShapeFilter(Filter): """Filter to keep samples with image shape (w, h) within specific ranges. 
""" + _batched_op = True + def __init__(self, min_width: int = 1, max_width: int = sys.maxsize, diff --git a/data_juicer/ops/filter/image_size_filter.py b/data_juicer/ops/filter/image_size_filter.py index f4ab8f760..fd8b7bcef 100644 --- a/data_juicer/ops/filter/image_size_filter.py +++ b/data_juicer/ops/filter/image_size_filter.py @@ -12,6 +12,8 @@ class ImageSizeFilter(Filter): specific range. """ + _batched_op = True + def __init__(self, min_size: str = '0', max_size: str = '1TB', diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index ac23330c3..d43c9bc3f 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -19,6 +19,7 @@ class ImageTextSimilarityFilter(Filter): within a specific range.""" _accelerator = 'cuda' + _batched_op = True def __init__(self, hf_clip: str = 'openai/clip-vit-base-patch32', diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py index cf10cc51a..e0a4fb0b8 100644 --- a/data_juicer/ops/load.py +++ b/data_juicer/ops/load.py @@ -1,15 +1,12 @@ from .base_op import OPERATORS -from .op_fusion import fuse_operators -def load_ops(process_list, op_fusion=False): +def load_ops(process_list): """ Load op list according to the process list from config file. :param process_list: A process list. Each item is an op name and its arguments. - :param op_fusion: whether to fuse ops that share the same intermediate - variables. :return: The op instance list. 
""" ops = [] @@ -19,10 +16,7 @@ def load_ops(process_list, op_fusion=False): ops.append(OPERATORS.modules[op_name](**args)) new_process_list.append(process) - # detect filter groups - if op_fusion: - new_process_list, ops = fuse_operators(new_process_list, ops) - + # store the OP configs into each OP for op_cfg, op in zip(new_process_list, ops): op._op_cfg = op_cfg diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py index 26aaa556e..489f90ab0 100644 --- a/data_juicer/ops/op_fusion.py +++ b/data_juicer/ops/op_fusion.py @@ -1,5 +1,6 @@ from typing import List +import numpy as np from loguru import logger from data_juicer.utils.constant import Fields, InterVars @@ -23,47 +24,49 @@ INTER_SAMPLED_FRAMES = Registry(InterVars.sampled_frames) # all -ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES, LOADED_VIDEOS] +ALL_INTER_VARS = [ + INTER_LINES, INTER_WORDS, LOADED_IMAGES, LOADED_VIDEOS, + INTER_SAMPLED_FRAMES +] +# supported fusion strategies +FUSION_STRATEGIES = {'greedy', 'probe'} -def fuse_operators(process_list, ops): + +def fuse_operators(ops, probe_res=None): """ Fuse the input ops list and return the fused ops list. - :param process_list: the list of original process definition, including op - names and args. :param ops: the corresponding list of op objects. + :param probe_res: the probed speed for each OP from Monitor. :return: a list of fused op objects. 
""" + if probe_res is None: + probe_res = [None for _ in range(len(ops))] # detect filter groups and try to fuse them - fused_op_def = [] fused_ops = [] filter_group = [] in_group = False - for process, op in zip(process_list, ops): + for op, op_probe in zip(ops, probe_res): if isinstance(op, Filter): if not in_group: in_group = True - filter_group.append((process, op)) + filter_group.append((op, op_probe)) elif in_group: # got a filter group, try to fuse them - fused_group_def, fused_group = fuse_filter_group(filter_group) - fused_op_def.extend(fused_group_def) + fused_group = fuse_filter_group(filter_group) fused_ops.extend(fused_group) filter_group = [] in_group = False # and add the current non-filter op into fused_ops - fused_op_def.append(process) fused_ops.append(op) else: # not a filter and not in a filter group, skip - fused_op_def.append(process) fused_ops.append(op) if in_group and len(filter_group) > 0: # the final filter group, try to fuse them - fused_group_def, fused_group = fuse_filter_group(filter_group) - fused_op_def.extend(fused_group_def) + fused_group = fuse_filter_group(filter_group) fused_ops.extend(fused_group) - return fused_op_def, fused_ops + return fused_ops def fuse_filter_group(original_filter_group): @@ -74,25 +77,25 @@ def fuse_filter_group(original_filter_group): definitions and objects. :return: the fused definitions and objects of the input filter group. 
""" - fused_group_def = [] fused_group = [] + group_speed = [] all_intermediate_vars = ALL_INTER_VARS all_fused_filters = { inter_vars: [] for inter_vars in all_intermediate_vars } # group these filters by their intermediate vars - for process, op in original_filter_group: - op_name, op_args = list(process.items())[0] + for op, probe_res in original_filter_group: + op_name = op._name for inter_vars in all_intermediate_vars: if op_name in inter_vars.modules: - all_fused_filters[inter_vars].append((process, op)) + all_fused_filters[inter_vars].append((op, probe_res)) break else: # first apply other filters to decrease the number of samples, so # we add them into the fused_group list directly - fused_group_def.append(process) fused_group.append(op) + group_speed.append(probe_res['speed'] if probe_res else 0) # try to fuse ops for each type of intermediate vars for inter_vars in all_intermediate_vars: @@ -102,40 +105,59 @@ def fuse_filter_group(original_filter_group): pass elif len(inter_vars_filter) > 1: # more than 1 ops share the same intermediate var, try to fuse them - defs, ops = zip(*inter_vars_filter) + ops, probe_res_list = zip(*inter_vars_filter) # new definition: new name and a definition list of fused op list - fused_filter_def = { - 'OpFusion:(%s)' % ','.join([ - list(process.items())[0][0] for process in defs - ]): - list(defs) - } + fused_filter_name = 'OpFusion:(%s)' % ','.join( + [op._name for op in ops]) logger.info(f'Ops are fused into one op ' - f'{list(fused_filter_def.keys())[0]}.') + f'{fused_filter_name}.') # use these ops to create a FusedFilter object, and add the fused # definition and op into the fused group - fused_filter = FusedFilter(ops) - fused_group_def.append(fused_filter_def) + fused_filter = FusedFilter(fused_filter_name, ops) + fused_filter._op_cfg = { + fused_filter_name: [op._op_cfg for op in ops] + } + fused_filter_speed = sum([ + 1.0 / probe_res['speed'] for probe_res in probe_res_list + if probe_res + ]) + if 
fused_filter_speed > 0: + fused_filter_speed = 1.0 / fused_filter_speed fused_group.append(fused_filter) + group_speed.append(fused_filter_speed) else: # only 1 op for this type of intermediate var, add it to the fused # group directly without fusion - fused_group_def.append(inter_vars_filter[0][0]) - fused_group.append(inter_vars_filter[0][1]) + fused_group.append(inter_vars_filter[0][0]) + probe_res = inter_vars_filter[0][1] + group_speed.append(probe_res['speed'] if probe_res else 0) + + # reorder according to the probed speed results in group_speed + # 'greedy': all speed data in group_speed will be 0, which will keep the + # current order of fused group + # 'probe': OPs in fused group will be reordered according to the speed data + # in group_speed in descending order + fused_group = [ + op for op, _ in sorted( + zip(fused_group, group_speed), key=lambda it: it[1], reverse=True) + ] - return fused_group_def, fused_group + return fused_group class FusedFilter(Filter): """A fused operator for filters.""" - def __init__(self, fused_filters: List): + _batched_op = True + + def __init__(self, name: str, fused_filters: List): """ Initialization method. :param fused_filters: a list of filters to be fused. 
""" super().__init__() + self._name = name self.fused_filters = fused_filters # set accelerator to 'cuda' if there exists any ops whose accelerator # is 'cuda' @@ -144,30 +166,40 @@ def __init__(self, fused_filters: List): if 'cuda' in accelerator_methods: self.accelerator = 'cuda' - def compute_stats_single(self, sample, rank=None): + # update num_proc with the min num_proc of all fusible filters + self.num_proc = min([op.runtime_np() for op in self.fused_filters]) + + def compute_stats_batched(self, samples, rank=None): import av # context for the intermediate vars - sample[Fields.context] = {} + num_samples = len(samples[Fields.stats]) + samples[Fields.context] = [{} for _ in range(num_samples)] for op in self.fused_filters: # open the context for these fused ops if op.accelerator == 'cuda': - sample = op.compute_stats(sample, rank=rank, context=True) + samples = op.compute_stats_batched(samples, + rank=rank, + context=True) else: - sample = op.compute_stats(sample, context=True) + samples = op.compute_stats_batched(samples, context=True) # clean up the contexts after processing # check if there are containers that need to be closed - for context_key in sample[Fields.context]: - if isinstance(sample[Fields.context][context_key], - av.container.InputContainer): - sample[Fields.context][context_key].streams.video[0].close() - sample[Fields.context][context_key].close() - _ = sample.pop(Fields.context) - return sample - - def process_single(self, sample): + for ctx in samples[Fields.context]: + for context_key in ctx: + if isinstance(ctx[context_key], av.container.InputContainer): + ctx[context_key].streams.video[0].close() + ctx[context_key].close() + _ = samples.pop(Fields.context) + return samples + + def process_batched(self, samples): # Only return True when all filters return True + res = None for op in self.fused_filters: - if not op.process(sample): - return False - return True + this_res = np.array(list(op.process_batched(samples))) + if res is not None: + 
res = np.logical_and(res, this_res) + else: + res = this_res + return res diff --git a/tests/core/test_adapter.py b/tests/core/test_adapter.py index 965355b96..4a58d882f 100644 --- a/tests/core/test_adapter.py +++ b/tests/core/test_adapter.py @@ -4,11 +4,12 @@ from datasets import load_dataset from loguru import logger from data_juicer.core import Adapter -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS from data_juicer.ops.mapper import FixUnicodeMapper from data_juicer.ops.filter import PerplexityFilter from data_juicer.ops.deduplicator import DocumentDeduplicator +@SKIPPED_TESTS.register_module() class AdapterTest(DataJuicerTestCaseBase): @classmethod @@ -177,6 +178,23 @@ def test_adapt_workloads(self): datasets.enable_caching() + def test_adapt_workloads_multiprocessing(self): + datasets.disable_caching() + # basic test + ds = load_dataset('json', data_files=self.test_file, split='train') + ops = [ + FixUnicodeMapper(num_proc=4), + PerplexityFilter(num_proc=4), + DocumentDeduplicator(num_proc=4), + ] # use some batched OPs later + + adapter = Adapter({'batch_size': 100}) + adapted_batch_sizes = adapter.adapt_workloads(ds, ops) + self.assertEqual(len(adapted_batch_sizes), len(ops)) + logger.info(adapted_batch_sizes) + + datasets.enable_caching() + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py index 13d633134..04fc2a50e 100644 --- a/tests/ops/test_op_fusion.py +++ b/tests/ops/test_op_fusion.py @@ -1,13 +1,15 @@ import unittest from data_juicer.ops.load import load_ops +from data_juicer.ops.op_fusion import fuse_operators from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class OpFusionTest(DataJuicerTestCaseBase): - def _run_op_fusion(self, original_process_list, target_process_list): - ops = load_ops(original_process_list, op_fusion=True) + def _run_op_fusion(self, 
original_process_list, target_process_list, probe_res=None): + ops = load_ops(original_process_list) + ops = fuse_operators(ops, probe_res) new_process_list = [op._op_cfg for op in ops] self.assertEqual(new_process_list, target_process_list) @@ -1014,6 +1016,950 @@ def test_different_intermediate_vars(self): ] self._run_op_fusion(original_process, target_process) + def test_regular_config_with_probe_res(self): + probed_speeds = [ + # single filter + {'speed': 100}, + + # mappers + {'speed': 2}, + {'speed': 1}, + {'speed': 4}, + {'speed': 5}, + {'speed': 3}, + + # filter groups + # fused OPs: ~2.56 + # single OP 1: 1 (slowest) + # single OP 2: 3 (fastest) + {'speed': 15}, # fusible + {'speed': 1}, + {'speed': 14}, # fusible + {'speed': 3}, + {'speed': 13}, # fusible + {'speed': 12}, # fusible + {'speed': 11}, # fusible + + # deduplicator + {'speed': 0.1}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, { + 'stopwords_filter': { + 
'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': 
'' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probed_speeds) + + def test_not_enough_fusible_ops_to_fuse_with_probe_res(self): + # still apply reordering: + # - ordinary ops + # - ops with InterVars.lines + # - ops with InterVars.words + probe_res_list = [ + {'speed': 3}, + {'speed': 1}, + {'speed': 4}, + {'speed': 2}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }] + target_process = [{ + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }] + self._run_op_fusion(original_process, target_process, 
probe_res_list) + + def test_multiple_groups_with_probe_res(self): + probe_res_list = [ + # group 1 + # fused filter will be put before the single filter + {'speed': 10}, + {'speed': 10}, + {'speed': 1}, + + # mappers + {'speed': 4}, + {'speed': 2}, + {'speed': 5}, + {'speed': 3}, + {'speed': 1}, + + # group 2 + # fused filter will be put after those two single filters + {'speed': 1}, # fusible + {'speed': 8}, + {'speed': 1}, # fusible + {'speed': 10}, + {'speed': 1}, # fusible + + # deduplicator + {'speed': 1}, + ] + + original_process = [{ + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 
0.0, + 'text_key': 'text' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'OpFusion:(stopwords_filter,flagged_words_filter)': [{ + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }] + }, + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 
'tokenization': False + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probe_res_list) + + def test_only_fusible_ops_with_probe_res(self): + probe_res_list = [ + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + {'speed': 1}, + ] + + original_process = [{ + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }] + target_process = [{ + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 
'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }] + self._run_op_fusion(original_process, target_process, probe_res_list) + + def test_different_intermediate_vars_with_probe_res(self): + probe_res_list = [ + # single filter + {'speed': 1}, + + # mappers + {'speed': 5}, + {'speed': 3}, + {'speed': 1}, + {'speed': 2}, + {'speed': 4}, + + # filter group + # single 1: 1 (2) + # single 2: 0.5 (3) + # group 1: 0.04 (4) + # group 2: 1.5 (1) + {'speed': 0.1}, # group 1 + {'speed': 1}, + {'speed': 3}, # group 2 + {'speed': 0.2}, # group 1 + {'speed': 0.5}, + {'speed': 0.3}, # group 1 + {'speed': 0.4}, # group 1 + {'speed': 3}, # group 2 + {'speed': 0.5}, # group 1 + + # deduplicator + {'speed': 1}, + ] + + original_process = [{ + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text' + } + }, { + 'word_repetition_filter': { + 'lang': 'en', + 
'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'maximum_line_length_filter': { + 'min_len': 20, + 'text_key': 'text' + } + }, { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + }, { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + }] + target_process = [ + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { + 'lang': 'en', + 'substrings': None, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'OpFusion:(average_line_length_filter,maximum_line_length_filter)': # noqa: E501 + [ + { + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text', + } + }, + { + 'maximum_line_length_filter': { + 'min_len': 20, + 'text_key': 'text', + } + } + ] + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' + } + }, + { + 
'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, + 'text_key': 'text' + } + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 + [ + { + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } + } + ] + self._run_op_fusion(original_process, target_process, probe_res_list) + if __name__ == '__main__': unittest.main()