Skip to content

Commit

Permalink
Allow multiple response patterns in the insert_mask_before_placeholder transform (OpenNMT#2567)
Browse files Browse the repository at this point in the history

* updated onmt/transforms/insert_mask_before_placeholder.py
* updated onmt/tests/test_transform.py
  • Loading branch information
l-k-11235 authored Mar 13, 2024
1 parent 211aeec commit c9c9ee3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion onmt/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ class TestInsertMaskBeforePlaceholder(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_opts = {
"response_pattern": "Response : ⦅newline⦆",
"response_patterns": ["Response : ⦅newline⦆"],
}

def test_insert_mask_before_placeholder(self):
Expand Down
26 changes: 15 additions & 11 deletions onmt/transforms/insert_mask_before_placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,28 @@ def add_options(cls, parser):
"Transform/InsertMaskBeforePlaceholdersTransform"
)
group.add(
"--response_pattern",
"-response_pattern",
type=str,
"--response_patterns",
"-response_patterns",
help="Response patten to locate the end of the prompt",
default="Response : ⦅newline⦆",
default=["Response : ⦅newline⦆"],
nargs="+",
)

def _parse_opts(self):
self.response_pattern = self.opts.response_pattern
self.response_patterns = self.opts.response_patterns

def apply(self, example, is_train=False, stats=None, **kwargs):
_src = " ".join(example["src"])
if len(_src.split(self.response_pattern)) != 2:
response = None
for _pattern in self.response_patterns:
if len(_src.split(_pattern)) == 2:
prompt, response = _src.split(_pattern)
response = DefaultTokens.MASK_BEFORE.join([_pattern, response])
if response is not None:
_src = "".join([prompt, response])
example["src"] = _src.split(" ")
example["tgt"] = _src.split(" ")
else:
logger.info("The mask_before could not be inserted")
return example
prompt, response = _src.split(self.response_pattern)
response = DefaultTokens.MASK_BEFORE.join([self.response_pattern, response])
_src = "".join([prompt, response])
example["src"] = _src.split(" ")
example["tgt"] = _src.split(" ")
return example

0 comments on commit c9c9ee3

Please sign in to comment.