diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 561507585..a3f5c17e4 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -133,7 +133,7 @@ def __init__(self, *args, **kwargs): self.image_key = kwargs.get('image_key', 'images') self.audio_key = kwargs.get('audio_key', 'audios') self.video_key = kwargs.get('video_key', 'videos') - self.batch_size = kwargs.get('batch_size', 1) + self.batch_size = kwargs.get('batch_size', 1000) # whether the model can be accelerated using cuda _accelerator = kwargs.get('accelerator', None) @@ -204,6 +204,12 @@ def add_parameters(self, init_parameter_dict, **extra_param_dict): related_parameters.update(extra_param_dict) return related_parameters + def run(self, dataset): + from data_juicer.core.data import NestedDataset + if not isinstance(dataset, NestedDataset): + dataset = NestedDataset(dataset) + return dataset + class Mapper(OP): @@ -238,6 +244,7 @@ def process(self, sample): raise NotImplementedError def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Mapper, self).run(dataset) new_dataset = dataset.map( self.process, num_proc=self.runtime_np(), @@ -298,6 +305,7 @@ def process(self, sample): raise NotImplementedError def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Filter, self).run(dataset) if Fields.stats not in dataset.features: from data_juicer.core.data import add_same_content_to_new_column dataset = dataset.map(add_same_content_to_new_column, @@ -368,6 +376,7 @@ def process(self, dataset, show_num=0): raise NotImplementedError def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Deduplicator, self).run(dataset) dataset = dataset.map(self.compute_hash, num_proc=self.runtime_np(), with_rank=self.use_cuda(), @@ -406,6 +415,7 @@ def process(self, dataset): raise NotImplementedError def run(self, dataset, *, exporter=None, tracer=None): + dataset = super(Selector, self).run(dataset) new_dataset = self.process(dataset) if tracer: tracer.trace_filter(self._name, dataset, new_dataset) diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py index 17361b29c..4e4112453 100644 --- a/data_juicer/ops/filter/alphanumeric_filter.py +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -86,10 +86,9 @@ def process(self, samples): ratio_key = StatsKeys.alpha_token_ratio if self.tokenization \ else StatsKeys.alnum_ratio if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_ratio <= stat[ratio_key] <= self. - max_ratio, samples[Fields.stats])) + return map( + lambda stat: self.min_ratio <= stat[ratio_key] <= self. + max_ratio, samples[Fields.stats]) else: # single sample for ray filter if self.min_ratio <= samples[ diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py index 74d624a82..d2867b774 100644 --- a/data_juicer/ops/filter/average_line_length_filter.py +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -60,11 +60,9 @@ def compute_stats(self, samples, context=False): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_len <= stat[StatsKeys.avg_line_length - ] <= self.max_len, - samples[Fields.stats])) + return map( + lambda stat: self.min_len <= stat[StatsKeys.avg_line_length] <= + self.max_len, samples[Fields.stats]) else: # single sample for ray filter if self.min_len <= samples[Fields.stats][ diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py index a0441334a..965b368d6 100644 --- a/data_juicer/ops/filter/character_repetition_filter.py +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -80,11 +80,9 @@ def compute_stats(self, samples): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_ratio <= stat[ - StatsKeys.char_rep_ratio] <= self.max_ratio, - samples[Fields.stats])) + return map( + lambda stat: self.min_ratio <= stat[StatsKeys.char_rep_ratio] + <= self.max_ratio, samples[Fields.stats]) else: # single sample for ray filter if self.min_ratio <= samples[Fields.stats][ diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py index 146cfb0a2..16c919406 100644 --- a/data_juicer/ops/filter/maximum_line_length_filter.py +++ b/data_juicer/ops/filter/maximum_line_length_filter.py @@ -61,11 +61,9 @@ def compute_stats(self, samples, context=False): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_len <= stat[StatsKeys.max_line_length - ] <= self.max_len, - samples[Fields.stats])) + return map( + lambda stat: self.min_len <= stat[StatsKeys.max_line_length] <= + self.max_len, samples[Fields.stats]) else: # single sample for ray filter if self.min_len <= samples[Fields.stats][ diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index 287d15a11..9b532d7c6 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -80,8 +80,7 @@ def compute_stats(self, samples, context=False): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl, - samples[Fields.stats])) + return map(lambda stat: stat[StatsKeys.perplexity] <= self.max_ppl, + samples[Fields.stats]) else: return samples[Fields.stats][StatsKeys.perplexity] <= self.max_ppl diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py index 0b56f390e..59fa61f52 100644 --- a/data_juicer/ops/filter/special_characters_filter.py +++ b/data_juicer/ops/filter/special_characters_filter.py @@ -54,11 +54,10 @@ def compute_stats(self, samples): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_ratio <= stat[ - StatsKeys.special_char_ratio] <= self.max_ratio, - samples[Fields.stats])) + return map( + lambda stat: self.min_ratio <= stat[ + StatsKeys.special_char_ratio] <= self.max_ratio, + samples[Fields.stats]) else: # single sample for ray filter if self.min_ratio <= \ diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py index ec61f8304..51e0bd68d 100644 --- a/data_juicer/ops/filter/text_length_filter.py +++ b/data_juicer/ops/filter/text_length_filter.py @@ -47,10 +47,9 @@ def compute_stats(self, samples): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_len <= stat[StatsKeys.text_len] <= - self.max_len, samples[Fields.stats])) + return map( + lambda stat: self.min_len <= stat[StatsKeys.text_len] <= self. + max_len, samples[Fields.stats]) else: # single sample for ray filter if self.min_len <= samples[Fields.stats][ diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index 71f806e25..3e9cad251 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -116,11 +116,9 @@ def compute_stats(self, samples, context=False): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_ratio <= stat[ - StatsKeys.word_rep_ratio] <= self.max_ratio, - samples[Fields.stats])) + return map( + lambda stat: self.min_ratio <= stat[StatsKeys.word_rep_ratio] + <= self.max_ratio, samples[Fields.stats]) else: # single sample for ray filter if self.min_ratio <= samples[Fields.stats][ diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index 07eb8e2b7..978c252ad 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -80,10 +80,9 @@ def compute_stats(self, samples, context=False): def process(self, samples): if isinstance(samples[Fields.stats], list): - return list( - map( - lambda stat: self.min_num <= stat[StatsKeys.num_words] <= - self.max_num, samples[Fields.stats])) + return map( + lambda stat: self.min_num <= stat[StatsKeys.num_words] <= self. + max_num, samples[Fields.stats]) else: # single sample for ray filter if self.min_num <= samples[Fields.stats][ diff --git a/data_juicer/ops/mapper/chinese_convert_mapper.py b/data_juicer/ops/mapper/chinese_convert_mapper.py index 9236ddaa2..8e6bb9dc1 100644 --- a/data_juicer/ops/mapper/chinese_convert_mapper.py +++ b/data_juicer/ops/mapper/chinese_convert_mapper.py @@ -87,7 +87,7 @@ def __init__(self, mode: str = 's2t', *args, **kwargs): def process(self, samples): prepare_converter(self.mode) - samples[self.text_key] = list( - map(lambda text: OPENCC_CONVERTER.convert(text), - samples[self.text_key])) + samples[self.text_key] = [ + OPENCC_CONVERTER.convert(text) for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/clean_copyright_mapper.py b/data_juicer/ops/mapper/clean_copyright_mapper.py index 3bf6fcbdf..8908d33e9 100644 --- a/data_juicer/ops/mapper/clean_copyright_mapper.py +++ b/data_juicer/ops/mapper/clean_copyright_mapper.py @@ -55,7 +55,8 @@ def _process_single_sample(self, sample): return sample def process(self, samples): - samples[self.text_key] = list( - map(lambda text: self._process_single_sample(text), - samples[self.text_key])) + samples[self.text_key] = [ + self._process_single_sample(text) + for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/clean_html_mapper.py b/data_juicer/ops/mapper/clean_html_mapper.py index d959cc85f..09e847dd0 100644 --- a/data_juicer/ops/mapper/clean_html_mapper.py +++ b/data_juicer/ops/mapper/clean_html_mapper.py @@ -37,6 +37,7 @@ def _clean_html(raw_html): parser = HTMLParser(raw_html) return parser.text() - samples[self.text_key] = list( - map(lambda text: _clean_html(text), samples[self.text_key])) + samples[self.text_key] = [ + _clean_html(text) for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/fix_unicode_mapper.py b/data_juicer/ops/mapper/fix_unicode_mapper.py index 4ca71c30a..b44005076 100644 --- a/data_juicer/ops/mapper/fix_unicode_mapper.py +++ b/data_juicer/ops/mapper/fix_unicode_mapper.py @@ -36,9 +36,8 @@ def __init__(self, normalization: str = None, *args, **kwargs): '["NFC", "NFKC", "NFD", "NFKD"]') def process(self, samples): - samples[self.text_key] = list( - map( - lambda text: ftfy.fix_text(text, - normalization=self.normalization), - samples[self.text_key])) + samples[self.text_key] = [ + ftfy.fix_text(text, normalization=self.normalization) + for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/punctuation_normalization_mapper.py b/data_juicer/ops/mapper/punctuation_normalization_mapper.py index 6531833a3..18aa12c56 100644 --- a/data_juicer/ops/mapper/punctuation_normalization_mapper.py +++ b/data_juicer/ops/mapper/punctuation_normalization_mapper.py @@ -58,9 +58,8 @@ def __init__(self, *args, **kwargs): } def process(self, samples): - samples[self.text_key] = list( - map( - lambda text: ''.join( - [self.punctuation_unicode.get(c, c) for c in text]), - samples[self.text_key])) + samples[self.text_key] = [ + ''.join([self.punctuation_unicode.get(c, c) for c in text]) + for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/remove_bibliography_mapper.py b/data_juicer/ops/mapper/remove_bibliography_mapper.py index d2a2bf342..1eecd66d2 100644 --- a/data_juicer/ops/mapper/remove_bibliography_mapper.py +++ b/data_juicer/ops/mapper/remove_bibliography_mapper.py @@ -30,11 +30,11 @@ def __init__(self, *args, **kwargs): self.pattern += r').*$' def process(self, samples): - samples[self.text_key] = list( - map( - lambda text: re.sub(pattern=self.pattern, - repl=r'', - string=text, - flags=re.DOTALL), samples[self.text_key])) + samples[self.text_key] = [ + re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL) for text in samples[self.text_key] + ] return samples diff --git a/data_juicer/ops/mapper/remove_specific_chars_mapper.py b/data_juicer/ops/mapper/remove_specific_chars_mapper.py index d487efa2f..78ca55e62 100644 --- a/data_juicer/ops/mapper/remove_specific_chars_mapper.py +++ b/data_juicer/ops/mapper/remove_specific_chars_mapper.py @@ -34,10 +34,10 @@ def process(self, samples): if self.pattern is None: return samples - samples[self.text_key] = list( - map( - lambda text: re.sub(pattern=self.pattern, - repl=r'', - string=text, - flags=re.DOTALL), samples[self.text_key])) + samples[self.text_key] = [ + re.sub(pattern=self.pattern, + repl=r'', + string=text, + flags=re.DOTALL) for text in samples[self.text_key] + ] return samples diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index e7e568f48..c024ceb0f 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -50,7 +50,7 @@ def test_yaml_cfg_file(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( @@ -68,7 +68,7 @@ def test_yaml_cfg_file(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }, 'nested dict load fail, un-expected internal value') @@ -134,7 +134,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }) self.assertDictEqual( @@ -152,7 +152,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }) self.assertDictEqual( @@ -170,7 +170,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }) self.assertDictEqual( @@ -188,7 +188,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } }) self.assertDictEqual( @@ -206,7 +206,7 @@ def test_mixture_cfg(self): 'cpu_required': 1, 'mem_required': 0, 'turbo': False, - 'batch_size': 1, + 'batch_size': 1000, } })