Skip to content

Commit

Permalink
support-mul-pattern (#319)
Browse files Browse the repository at this point in the history
* support-mul-pattern

* support mul pattern

* add error log

---------

Co-authored-by: zhangzc <[email protected]>
  • Loading branch information
seanzhang-zhichen and 15797939668 authored May 24, 2024
1 parent b5bd283 commit da79345
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions data_juicer/ops/mapper/replace_content_mapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Union

import regex as re

from ..base_op import OPERATORS, Mapper
Expand All @@ -9,30 +11,52 @@ class ReplaceContentMapper(Mapper):
a specific regular expression pattern with a designated
replacement string."""

def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs):
def __init__(self,
pattern: Union[str, List[str]] = None,
repl: Union[str, List[str]] = '',
*args,
**kwargs):
"""
Initialization method.
:param pattern: regular expression pattern to search for within text.
:param repl: replacement string, default is empty string.
:param pattern: regular expression pattern(s) to search for within text
:param repl: replacement string(s), default is empty string
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.pattern = pattern
self.repl = repl
self.compiled_patterns = []
if isinstance(pattern, str):
self.compiled_patterns.append(self._prepare_pattern(pattern))
elif isinstance(pattern, list):
for p in pattern:
self.compiled_patterns.append(self._prepare_pattern(p))

def _prepare_pattern(self, pattern: str) -> re.Pattern:
"""Prepare the regular expression pattern."""
if ((pattern is not None and len(pattern) > 2)
and (pattern.startswith("r'") and pattern.endswith("'")
or pattern.startswith('r"') and pattern.endswith('"'))):
self.pattern = pattern[2:-1]
self.repl = repl
pattern = pattern[2:-1]
return re.compile(pattern, flags=re.DOTALL)

def process(self, sample):

if self.pattern is None:
return sample

sample[self.text_key] = re.sub(pattern=self.pattern,
repl=self.repl,
string=sample[self.text_key],
flags=re.DOTALL)
for i, pattern in enumerate(self.compiled_patterns):
if isinstance(self.repl, list) and i < len(self.repl):
replacement = self.repl[i]
elif isinstance(self.repl, list) and i >= len(self.repl):
raise ValueError(f"pattern length: {len(self.pattern)} '"
f'must be equal to '
f'repl length: {len(self.repl)}')
else:
replacement = self.repl

sample[self.text_key] = pattern.sub(replacement,
sample[self.text_key])

return sample

0 comments on commit da79345

Please sign in to comment.