From da793458c70b4c034f0ba4bf03018b27febe12d1 Mon Sep 17 00:00:00 2001 From: seanzhang-zhichen <74812416+seanzhang-zhichen@users.noreply.github.com> Date: Fri, 24 May 2024 14:35:24 +0800 Subject: [PATCH] support-mul-pattern (#319) * support-mul-pattern * support mul pattern * add error log --------- Co-authored-by: zhangzc <2608882093@qq.com> --- .../ops/mapper/replace_content_mapper.py | 44 ++++++++++++++----- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/data_juicer/ops/mapper/replace_content_mapper.py b/data_juicer/ops/mapper/replace_content_mapper.py index 703405001..d73669c3e 100644 --- a/data_juicer/ops/mapper/replace_content_mapper.py +++ b/data_juicer/ops/mapper/replace_content_mapper.py @@ -1,3 +1,5 @@ +from typing import List, Union + import regex as re from ..base_op import OPERATORS, Mapper @@ -9,30 +11,52 @@ class ReplaceContentMapper(Mapper): a specific regular expression pattern with a designated replacement string.""" - def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): + def __init__(self, + pattern: Union[str, List[str]] = None, + repl: Union[str, List[str]] = '', + *args, + **kwargs): """ Initialization method. - :param pattern: regular expression pattern to search for within text. - :param repl: replacement string, default is empty string. + :param pattern: regular expression pattern(s) to search for within text + :param repl: replacement string(s), default is empty string :param args: extra args :param kwargs: extra args """ super().__init__(*args, **kwargs) self.pattern = pattern + self.repl = repl + self.compiled_patterns = [] + if isinstance(pattern, str): + self.compiled_patterns.append(self._prepare_pattern(pattern)) + elif isinstance(pattern, list): + for p in pattern: + self.compiled_patterns.append(self._prepare_pattern(p)) + + def _prepare_pattern(self, pattern: str) -> re.Pattern: + """Prepare the regular expression pattern.""" if ((pattern is not None and len(pattern) > 2) and (pattern.startswith("r'") and pattern.endswith("'") or pattern.startswith('r"') and pattern.endswith('"'))): - self.pattern = pattern[2:-1] - self.repl = repl + pattern = pattern[2:-1] + return re.compile(pattern, flags=re.DOTALL) def process(self, sample): - if self.pattern is None: return sample - sample[self.text_key] = re.sub(pattern=self.pattern, - repl=self.repl, - string=sample[self.text_key], - flags=re.DOTALL) + for i, pattern in enumerate(self.compiled_patterns): + if isinstance(self.repl, list) and i < len(self.repl): + replacement = self.repl[i] + elif isinstance(self.repl, list) and i >= len(self.repl): + raise ValueError(f"pattern length: {len(self.pattern)} '" + f'must be equal to ' + f'repl length: {len(self.repl)}') + else: + replacement = self.repl + + sample[self.text_key] = pattern.sub(replacement, + sample[self.text_key]) + return sample