Optimize ray mode performance (#442)
pan-x-c authored Dec 13, 2024
1 parent 2e9a6cf commit 46062f8
Showing 4 changed files with 152 additions and 118 deletions.
74 changes: 45 additions & 29 deletions data_juicer/core/ray_data.py
@@ -1,4 +1,5 @@
 import os
+from functools import partial
 
 import pyarrow as pa
 from loguru import logger
@@ -13,28 +14,26 @@
 rd = LazyLoader('rd', 'ray.data')
 
 
-def is_valid_path(item, dataset_dir):
-    full_path = os.path.abspath(os.path.join(dataset_dir, item))
-    return os.path.exists(full_path)
+def get_abs_path(path, dataset_dir):
+    full_path = os.path.abspath(os.path.join(dataset_dir, path))
+    if os.path.exists(full_path):
+        return full_path
+    else:
+        return path
 
 
-def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys):
+def convert_to_absolute_paths(samples, dataset_dir, path_keys):
+    samples = samples.to_pydict()
     for key in path_keys:
-        if key not in dict_with_paths:
-            continue
-        if isinstance(dict_with_paths[key], list):
-            dict_with_paths[key] = [
-                os.path.abspath(os.path.join(dataset_dir, item))
-                if isinstance(item, str) and is_valid_path(dataset_dir, item)
-                else item for item in dict_with_paths[key]
-            ]
-        elif isinstance(dict_with_paths[key], str):
-            dict_with_paths[key] = os.path.abspath(
-                os.path.join(dataset_dir,
-                             dict_with_paths[key])) if is_valid_path(
-                                 dict_with_paths[key],
-                                 dataset_dir) else dict_with_paths[key]
-    return dict_with_paths
+        for idx in range(len(samples[key])):
+            paths = samples[key][idx]
+            if isinstance(paths, str):
+                samples[key][idx] = get_abs_path(paths, dataset_dir)
+            elif isinstance(paths, list):
+                samples[key][idx] = [
+                    get_abs_path(item, dataset_dir) for item in paths
+                ]
+    return pa.Table.from_pydict(samples)
 
 
 # TODO: check path for nestdataset
@@ -43,22 +42,26 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):
     Set all the path in input data to absolute path.
     Checks dataset_dir and project_dir for valid paths.
     """
-    if not (cfg.video_key in dataset.columns() or cfg.image_key
-            in dataset.columns() or cfg.audio_key in dataset.columns()):
-        return dataset
-    dataset_dir = os.path.dirname(dataset_path)
-    dataset = dataset.map(lambda item: convert_to_absolute_paths(
-        item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key]))
-    logger.info(f"transfer {dataset.count()} sample's paths")
+    path_keys = []
+    columns = dataset.columns()
+    for key in [cfg.video_key, cfg.image_key, cfg.audio_key]:
+        if key in columns:
+            path_keys.append(key)
+    if len(path_keys) > 0:
+        dataset_dir = os.path.dirname(dataset_path)
+        dataset = dataset.map_batches(partial(convert_to_absolute_paths,
+                                              dataset_dir=dataset_dir,
+                                              path_keys=path_keys),
+                                      batch_format='pyarrow',
+                                      zero_copy_batch=True)
     return dataset
 
 
 def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
-    columns = dataset.columns()
     if dataset_path:
         dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
+    columns = dataset.columns()
     if Fields.stats not in columns:
-        logger.info(f'columns {columns}')
 
         def process_batch_arrow(table: pa.Table) -> pa.Table:
             new_column_data = [{} for _ in range(len(table))]
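
For reference, a minimal runnable sketch (not part of this commit; the 'images' column name is made up) of the pyarrow round trip that convert_to_absolute_paths performs inside map_batches: each batch arrives as a pa.Table, is edited as a plain dict of column lists, and is returned as a new table.

import pyarrow as pa

# toy batch with one list-typed column, standing in for a path column
batch = pa.table({'images': [['a.jpg', 'b.jpg'], ['c.jpg']]})
columns = batch.to_pydict()  # pa.Table -> {column name: list of row values}
columns['images'] = [[p.upper() for p in row] for row in columns['images']]
print(pa.Table.from_pydict(columns).to_pydict())
# {'images': [['A.JPG', 'B.JPG'], ['C.JPG']]}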
@@ -77,6 +80,11 @@ def get_num_gpus(op, op_proc):
     return 1.0 / proc_per_gpu
 
 
+def filter_batch(batch, filter_func):
+    mask = pa.array(filter_func(batch.to_pydict()))
+    return batch.filter(mask)
+
+
 class RayDataset(DJDataset):
 
     def __init__(self,
@@ -122,7 +130,15 @@ def _run_single_op(self, op):
                 if op.stats_export_path is not None:
                     self.data.write_json(op.stats_export_path,
                                          force_ascii=False)
-                self.data = self.data.filter(op.process)
+                if op.is_batched_op():
+                    self.data = self.data.map_batches(partial(
+                        filter_batch, filter_func=op.process),
+                                                      batch_format='pyarrow',
+                                                      batch_size=batch_size,
+                                                      num_gpus=num_gpus,
+                                                      zero_copy_batch=True)
+                else:
+                    self.data = self.data.filter(op.process)
             else:
                 logger.error(
                     'Ray executor only support Filter and Mapper OPs for now')
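
For reference, a small self-contained sketch (not from the repository; keep_if_positive and the 'score' column are hypothetical) of how the batched filter path behaves: a batched process function returns one boolean per row, and filter_batch turns that list into a pyarrow mask.

import pyarrow as pa


def filter_batch(batch, filter_func):
    # same shape as the helper added above: list of bools -> pyarrow mask
    mask = pa.array(filter_func(batch.to_pydict()))
    return batch.filter(mask)


def keep_if_positive(samples):
    # stand-in for a batched Filter.process: one bool per sample
    return [value > 0 for value in samples['score']]


table = pa.table({'score': [1, -2, 3]})
print(filter_batch(table, keep_if_positive).to_pydict())
# {'score': [1, 3]}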
104 changes: 56 additions & 48 deletions data_juicer/ops/filter/flagged_words_filter.py
@@ -24,6 +24,8 @@ class FlaggedWordFilter(Filter):
     """Filter to keep samples with flagged-word ratio less than a specific max
     value."""
 
+    _batched_op = True
+
     def __init__(self,
                  lang: str = 'en',
                  tokenization: bool = False,
@@ -72,53 +74,59 @@ def __init__(self,
             self.model_key = prepare_model(model_type='sentencepiece',
                                            lang=lang)
 
-    def compute_stats_single(self, sample, context=False):
+    def compute_stats_batched(self, samples, context=False):
         # check if it's computed already
-        if StatsKeys.flagged_words_ratio in sample[Fields.stats]:
-            return sample
-
-        # try to get words from context
+        samples_list = samples[self.text_key]
+        samples_stats = samples[Fields.stats]
         words_key = f'{InterVars.words}-{self.model_key}'
-        if context and words_key in sample[Fields.context]:
-            words = sample[Fields.context][words_key]
-        else:
-            tokenizer = get_model(self.model_key)
-            words = get_words_from_document(
-                sample[self.text_key],
-                token_func=tokenizer.encode_as_pieces if tokenizer else None)
-            if context:
-                sample[Fields.context][words_key] = words
-
-        # try to get refined words from context
-        refined_words_key = f'{InterVars.refined_words}-True-SPECIAL_CHARS-' \
-                            f'{self.use_words_aug}-' \
-                            f'{self.words_aug_group_sizes}-' \
-                            f'{self.words_aug_join_char}'
-        if context and refined_words_key in sample[Fields.context]:
-            words = sample[Fields.context][refined_words_key]
-        else:
-            words = words_refinement(
-                words,
-                lower_case=True,
-                strip_chars=SPECIAL_CHARACTERS,
-                use_words_aug=self.use_words_aug,
-                words_aug_group_sizes=self.words_aug_group_sizes,
-                words_aug_join_char=self.words_aug_join_char)
-            if context:
-                sample[Fields.context][refined_words_key] = words
-
-        flagged_words_ratio = (len(
-            [word
-             for word in words if word in self.FLAGGED_WORDS[self.lang]]) /
-                               len(words)) if len(words) != 0 else 0.0
-
-        if flagged_words_ratio > 1.0:
-            flagged_words_ratio = 1.0
-
-        sample[Fields.stats][
-            StatsKeys.flagged_words_ratio] = flagged_words_ratio
-        return sample
-
-    def process_single(self, sample):
-        return sample[Fields.stats][
-            StatsKeys.flagged_words_ratio] <= self.max_ratio
+        tokenizer = get_model(self.model_key)
+        for idx, stat in enumerate(samples_stats):
+            if StatsKeys.flagged_words_ratio in stat:
+                continue
+            if context and words_key in samples[Fields.context][idx]:
+                words = samples[Fields.context][idx][words_key]
+            else:
+                words = get_words_from_document(
+                    samples_list[idx],
+                    token_func=tokenizer.encode_as_pieces
+                    if tokenizer else None)
+                if context:
+                    samples[Fields.context][idx][words_key] = words
+            # try to get refined words from context
+            refined_words_key = f'{InterVars.refined_words}' \
+                                '-True-SPECIAL_CHARS-' \
+                                f'{self.use_words_aug}-' \
+                                f'{self.words_aug_group_sizes}-' \
+                                f'{self.words_aug_join_char}'
+            if context and refined_words_key in samples[Fields.context][idx]:
+                words = samples[Fields.context][idx][refined_words_key]
+            else:
+                words = words_refinement(
+                    words,
+                    lower_case=True,
+                    strip_chars=SPECIAL_CHARACTERS,
+                    use_words_aug=self.use_words_aug,
+                    words_aug_group_sizes=self.words_aug_group_sizes,
+                    words_aug_join_char=self.words_aug_join_char)
+                if context:
+                    samples[Fields.context][idx][refined_words_key] = words
+
+            flagged_words_ratio = (len([
+                word for word in words if word in self.FLAGGED_WORDS[self.lang]
+            ]) / len(words)) if len(words) != 0 else 0.0
+
+            if flagged_words_ratio > 1.0:
+                flagged_words_ratio = 1.0
+
+            samples_stats[idx][
+                StatsKeys.flagged_words_ratio] = flagged_words_ratio
+
+        return samples
+
+    def process_batched(self, samples):
+        return list(
+            map(
+                lambda stat: stat[StatsKeys.flagged_words_ratio] <= self.
+                max_ratio,
+                samples[Fields.stats],
+            ))
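
For reference, a toy computation (made-up word list) of the flagged-word ratio that the loop above assigns to each sample's stats.

flagged_words = {'badword'}
words = ['a', 'badword', 'c', 'd']
flagged_words_ratio = (len([w for w in words if w in flagged_words]) /
                       len(words)) if len(words) != 0 else 0.0
assert flagged_words_ratio == 0.25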
90 changes: 50 additions & 40 deletions data_juicer/ops/filter/image_aspect_ratio_filter.py
@@ -42,43 +42,53 @@ def __init__(self,
                              f'Can only be one of ["any", "all"].')
         self.any = (any_or_all == 'any')
 
-    def compute_stats_single(self, sample, context=False):
-        # check if it's computed already
-        if StatsKeys.aspect_ratios in sample[Fields.stats]:
-            return sample
-
-        # there is no image in this sample
-        if self.image_key not in sample or not sample[self.image_key]:
-            sample[Fields.stats][StatsKeys.aspect_ratios] = np.array(
-                [], dtype=np.float64)
-            return sample
-
-        # load images
-        loaded_image_keys = sample[self.image_key]
-        sample, images = load_data_with_context(sample, context,
-                                                loaded_image_keys, load_image)
-
-        # compute aspect ratios for each image with W/H
-        aspect_ratios = {
-            key: (images[key].width / images[key].height)
-            for key in images
-        }
-        sample[Fields.stats][StatsKeys.aspect_ratios] = [
-            aspect_ratios[key] for key in loaded_image_keys
-        ]
-        return sample
-
-    def process_single(self, sample):
-        aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios]
-        keep_bools = np.array([
-            self.min_ratio <= aspect_ratio <= self.max_ratio
-            for aspect_ratio in aspect_ratios
-        ])
-        if len(keep_bools) <= 0:
-            return True
-
-        # different strategies
-        if self.any:
-            return keep_bools.any()
-        else:
-            return keep_bools.all()
+    def compute_stats_batched(self, samples, context=False):
+        image_list = samples[self.image_key]
+        samples_stats = samples[Fields.stats]
+
+        for i, stat in enumerate(samples_stats):
+            # check if it's computed already
+            if StatsKeys.aspect_ratios in stat:
+                continue
+
+            # there is no image in this sample
+            loaded_image_keys = image_list[i]
+            if not loaded_image_keys:
+                stat[StatsKeys.aspect_ratios] = np.array([], dtype=np.float64)
+                continue
+
+            # load images
+            samples, images = load_data_with_context(samples, context,
+                                                     loaded_image_keys,
+                                                     load_image)
+
+            # compute aspect ratios for each image with W/H
+            aspect_ratios = {
+                key: (images[key].width / images[key].height)
+                for key in images
+            }
+            stat[StatsKeys.aspect_ratios] = [
+                aspect_ratios[key] for key in loaded_image_keys
+            ]
+
+        return samples
+
+    def process_batched(self, samples):
+
+        def process_single(values):
+            keep_bools = np.array([
+                self.min_ratio <= value <= self.max_ratio for value in values
+            ])
+            if len(keep_bools) <= 0:
+                return True
+
+            # different strategies
+            if self.any:
+                return keep_bools.any()
+            else:
+                return keep_bools.all()
+
+        return map(
+            lambda stat: process_single(stat[StatsKeys.aspect_ratios]),
+            samples[Fields.stats],
+        )
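
For reference, a short sketch (toy ratios and thresholds, plain numpy) of the any/all keep strategy that process_batched applies per sample.

import numpy as np

min_ratio, max_ratio = 0.333, 3.0
aspect_ratios = [0.2, 1.5]  # W/H of the images in one sample
keep_bools = np.array([min_ratio <= r <= max_ratio for r in aspect_ratios])
print(keep_bools.any())  # True: keep under the 'any' strategy
print(keep_bools.all())  # False: drop under the 'all' strategy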
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/perplexity_filter.py
@@ -45,6 +45,7 @@ def compute_stats_batched(self, samples, context=False):
         samples_list = samples[self.text_key]
         samples_stats = samples[Fields.stats]
         words_key = f'{InterVars.words}-{self.sp_model_key}'
+        tokenizer = get_model(self.sp_model_key)
 
         for idx, stat in enumerate(samples_stats):
             # check if it's computed already
@@ -54,7 +55,6 @@
             if context and words_key in samples[Fields.context][idx]:
                 words = samples[Fields.context][idx][words_key]
            else:
-                tokenizer = get_model(self.sp_model_key)
                 words = get_words_from_document(
                     samples_list[idx],
                     token_func=tokenizer.encode_as_pieces
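
For reference, a runnable toy (the Tokenizer stub is hypothetical) of the hoisting applied here: the tokenizer lookup now runs once per batch instead of once per sample inside the loop.

class Tokenizer:

    def encode_as_pieces(self, text):
        return text.split()


texts = ['one two', 'three four']
tokenizer = Tokenizer()  # resolved once for the whole batch
words_per_sample = [tokenizer.encode_as_pieces(t) for t in texts]
print(words_per_sample)  # [['one', 'two'], ['three', 'four']]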
