From 46062f892a1a7ed2d9cccf266c57560167ebcfa3 Mon Sep 17 00:00:00 2001 From: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:29:00 +0800 Subject: [PATCH] Optimize ray mode performance (#442) --- data_juicer/core/ray_data.py | 74 ++++++++----- .../ops/filter/flagged_words_filter.py | 104 ++++++++++-------- .../ops/filter/image_aspect_ratio_filter.py | 90 ++++++++------- data_juicer/ops/filter/perplexity_filter.py | 2 +- 4 files changed, 152 insertions(+), 118 deletions(-) diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index 0c131561e..b252b5989 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -1,4 +1,5 @@ import os +from functools import partial import pyarrow as pa from loguru import logger @@ -13,28 +14,26 @@ rd = LazyLoader('rd', 'ray.data') -def is_valid_path(item, dataset_dir): - full_path = os.path.abspath(os.path.join(dataset_dir, item)) - return os.path.exists(full_path) +def get_abs_path(path, dataset_dir): + full_path = os.path.abspath(os.path.join(dataset_dir, path)) + if os.path.exists(full_path): + return full_path + else: + return path -def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): +def convert_to_absolute_paths(samples, dataset_dir, path_keys): + samples = samples.to_pydict() for key in path_keys: - if key not in dict_with_paths: - continue - if isinstance(dict_with_paths[key], list): - dict_with_paths[key] = [ - os.path.abspath(os.path.join(dataset_dir, item)) - if isinstance(item, str) and is_valid_path(dataset_dir, item) - else item for item in dict_with_paths[key] - ] - elif isinstance(dict_with_paths[key], str): - dict_with_paths[key] = os.path.abspath( - os.path.join(dataset_dir, - dict_with_paths[key])) if is_valid_path( - dict_with_paths[key], - dataset_dir) else dict_with_paths[key] - return dict_with_paths + for idx in range(len(samples[key])): + paths = samples[key][idx] + if isinstance(paths, str): + samples[key][idx] = get_abs_path(paths, dataset_dir) + elif isinstance(paths, list): + samples[key][idx] = [ + get_abs_path(item, dataset_dir) for item in paths + ] + return pa.Table.from_pydict(samples) # TODO: check path for nestdataset @@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg): Set all the path in input data to absolute path. Checks dataset_dir and project_dir for valid paths. """ - if not (cfg.video_key in dataset.columns() or cfg.image_key - in dataset.columns() or cfg.audio_key in dataset.columns()): - return dataset - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map(lambda item: convert_to_absolute_paths( - item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) - logger.info(f"transfer {dataset.count()} sample's paths") + path_keys = [] + columns = dataset.columns() + for key in [cfg.video_key, cfg.image_key, cfg.audio_key]: + if key in columns: + path_keys.append(key) + if len(path_keys) > 0: + dataset_dir = os.path.dirname(dataset_path) + dataset = dataset.map_batches(partial(convert_to_absolute_paths, + dataset_dir=dataset_dir, + path_keys=path_keys), + batch_format='pyarrow', + zero_copy_batch=True) return dataset def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset: + columns = dataset.columns() if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - columns = dataset.columns() if Fields.stats not in columns: - logger.info(f'columns {columns}') def process_batch_arrow(table: pa.Table) -> pa.Table: new_column_data = [{} for _ in range(len(table))] @@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc): return 1.0 / proc_per_gpu +def filter_batch(batch, filter_func): + mask = pa.array(filter_func(batch.to_pydict())) + return batch.filter(mask) + + class RayDataset(DJDataset): def __init__(self, @@ -122,7 +130,15 @@ def _run_single_op(self, op): if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) - self.data = self.data.filter(op.process) + if op.is_batched_op(): + self.data = self.data.map_batches(partial( + filter_batch, filter_func=op.process), + batch_format='pyarrow', + batch_size=batch_size, + num_gpus=num_gpus, + zero_copy_batch=True) + else: + self.data = self.data.filter(op.process) else: logger.error( 'Ray executor only support Filter and Mapper OPs for now') diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py index dfadb0737..406ae1a23 100644 --- a/data_juicer/ops/filter/flagged_words_filter.py +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -24,6 +24,8 @@ class FlaggedWordFilter(Filter): """Filter to keep samples with flagged-word ratio less than a specific max value.""" + _batched_op = True + def __init__(self, lang: str = 'en', tokenization: bool = False, @@ -72,53 +74,59 @@ def __init__(self, self.model_key = prepare_model(model_type='sentencepiece', lang=lang) - def compute_stats_single(self, sample, context=False): + def compute_stats_batched(self, samples, context=False): # check if it's computed already - if StatsKeys.flagged_words_ratio in sample[Fields.stats]: - return sample - - # try to get words from context + samples_list = samples[self.text_key] + samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.model_key}' - if context and words_key in sample[Fields.context]: - words = sample[Fields.context][words_key] - else: - tokenizer = get_model(self.model_key) - words = get_words_from_document( - sample[self.text_key], - token_func=tokenizer.encode_as_pieces if tokenizer else None) - if context: - sample[Fields.context][words_key] = words - - # try to get refined words from context - refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \ - f'{self.use_words_aug}-' \ - f'{self.words_aug_group_sizes}-' \ - f'{self.words_aug_join_char}' - if context and refined_words_key in sample[Fields.context]: - words = sample[Fields.context][refined_words_key] - else: - words = words_refinement( - words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS, - use_words_aug=self.use_words_aug, - words_aug_group_sizes=self.words_aug_group_sizes, - words_aug_join_char=self.words_aug_join_char) - if context: - sample[Fields.context][refined_words_key] = words - - flagged_words_ratio = (len( - [word - for word in words if word in self.FLAGGED_WORDS[self.lang]]) / - len(words)) if len(words) != 0 else 0.0 - - if flagged_words_ratio > 1.0: - flagged_words_ratio = 1.0 - - sample[Fields.stats][ - StatsKeys.flagged_words_ratio] = flagged_words_ratio - return sample - - def process_single(self, sample): - return sample[Fields.stats][ - StatsKeys.flagged_words_ratio] <= self.max_ratio + tokenizer = get_model(self.model_key) + for idx, stat in enumerate(samples_stats): + if StatsKeys.flagged_words_ratio in stat: + continue + if context and words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][words_key] + else: + words = get_words_from_document( + samples_list[idx], + token_func=tokenizer.encode_as_pieces + if tokenizer else None) + if context: + samples[Fields.context][idx][words_key] = words + # try to get refined words from context + refined_words_key = f'{InterVars.refined_words}' \ + '-True-SPECIAL_CHARS-' \ + f'{self.use_words_aug}-' \ + f'{self.words_aug_group_sizes}-' \ + f'{self.words_aug_join_char}' + if context and refined_words_key in samples[Fields.context][idx]: + words = samples[Fields.context][idx][refined_words_key] + else: + words = words_refinement( + words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS, + use_words_aug=self.use_words_aug, + words_aug_group_sizes=self.words_aug_group_sizes, + words_aug_join_char=self.words_aug_join_char) + if context: + samples[Fields.context][idx][refined_words_key] = words + + flagged_words_ratio = (len([ + word for word in words if word in self.FLAGGED_WORDS[self.lang] + ]) / len(words)) if len(words) != 0 else 0.0 + + if flagged_words_ratio > 1.0: + flagged_words_ratio = 1.0 + + samples_stats[idx][ + StatsKeys.flagged_words_ratio] = flagged_words_ratio + + return samples + + def process_batched(self, samples): + return list( + map( + lambda stat: stat[StatsKeys.flagged_words_ratio] <= self. + max_ratio, + samples[Fields.stats], + )) diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py index 6e5cb8516..d3b3785ee 100644 --- a/data_juicer/ops/filter/image_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py @@ -42,43 +42,53 @@ def __init__(self, f'Can only be one of ["any", "all"].') self.any = (any_or_all == 'any') - def compute_stats_single(self, sample, context=False): - # check if it's computed already - if StatsKeys.aspect_ratios in sample[Fields.stats]: - return sample - - # there is no image in this sample - if self.image_key not in sample or not sample[self.image_key]: - sample[Fields.stats][StatsKeys.aspect_ratios] = np.array( - [], dtype=np.float64) - return sample - - # load images - loaded_image_keys = sample[self.image_key] - sample, images = load_data_with_context(sample, context, - loaded_image_keys, load_image) - - # compute aspect ratios for each image with W/H - aspect_ratios = { - key: (images[key].width / images[key].height) - for key in images - } - sample[Fields.stats][StatsKeys.aspect_ratios] = [ - aspect_ratios[key] for key in loaded_image_keys - ] - return sample - - def process_single(self, sample): - aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios] - keep_bools = np.array([ - self.min_ratio <= aspect_ratio <= self.max_ratio - for aspect_ratio in aspect_ratios - ]) - if len(keep_bools) <= 0: - return True - - # different strategies - if self.any: - return keep_bools.any() - else: - return keep_bools.all() + def compute_stats_batched(self, samples, context=False): + image_list = samples[self.image_key] + samples_stats = samples[Fields.stats] + + for i, stat in enumerate(samples_stats): + # check if it's computed already + if StatsKeys.aspect_ratios in stat: + continue + + # there is no image in this sample + loaded_image_keys = image_list[i] + if not loaded_image_keys: + stat[StatsKeys.aspect_ratios] = np.array([], dtype=np.float64) + continue + + # load images + samples, images = load_data_with_context(samples, context, + loaded_image_keys, + load_image) + + # compute aspect ratios for each image with W/H + aspect_ratios = { + key: (images[key].width / images[key].height) + for key in images + } + stat[StatsKeys.aspect_ratios] = [ + aspect_ratios[key] for key in loaded_image_keys + ] + + return samples + + def process_batched(self, samples): + + def process_single(values): + keep_bools = np.array([ + self.min_ratio <= value <= self.max_ratio for value in values + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() + + return map( + lambda stat: process_single(stat[StatsKeys.aspect_ratios]), + samples[Fields.stats], + ) diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index 6a1e6e67e..6a6d74e16 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -45,6 +45,7 @@ def compute_stats_batched(self, samples, context=False): samples_list = samples[self.text_key] samples_stats = samples[Fields.stats] words_key = f'{InterVars.words}-{self.sp_model_key}' + tokenizer = get_model(self.sp_model_key) for idx, stat in enumerate(samples_stats): # check if it's computed already @@ -54,7 +55,6 @@ def compute_stats_batched(self, samples, context=False): if context and words_key in samples[Fields.context][idx]: words = samples[Fields.context][idx][words_key] else: - tokenizer = get_model(self.sp_model_key) words = get_words_from_document( samples_list[idx], token_func=tokenizer.encode_as_pieces