Skip to content

Commit

Permalink
[natural_translit] Change Alignment class method arguments from expre…
Browse files Browse the repository at this point in the history
…ssion to alignment to allow use of operators.

PiperOrigin-RevId: 660902116
  • Loading branch information
isingoo authored and copybara-github committed Aug 8, 2024
1 parent a0704f5 commit 8946805
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 78 deletions.
97 changes: 41 additions & 56 deletions nisaba/scripts/natural_translit/utils/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,8 @@ def tail_matches(self, other: Expression.OR_SYMBOL) -> bool:
def is_suffix(self, other: Expression.OR_SYMBOL) -> bool:
return self._compare(other, 'is_suffix')

_BASE_ANY = _BaseAlignment(Expression.ANY, Expression.ANY)


class Alignment(_BaseAlignment):
"""Alignment class for defining a relation between two expressions.
Expand Down Expand Up @@ -768,13 +770,10 @@ def simple(
@classmethod
def rule(
cls,
alias: str = '',
left: Expression.OR_SYMBOL = Expression.ANY,
right: Expression.OR_SYMBOL = Expression.ANY,
preceding_left: Expression.OR_SYMBOL = Expression.ANY,
preceding_right: Expression.OR_SYMBOL = Expression.ANY,
following_left: Expression.OR_SYMBOL = Expression.ANY,
following_right: Expression.OR_SYMBOL = Expression.ANY,
alias: str,
alignment: 'Alignment',
preceding: 'Alignment' = _BASE_ANY,
following: 'Alignment' = _BASE_ANY,
from_bos: bool = False,
to_eos: bool = False,
operation: op.Operation = op.Operation.COMMON.alignable,
Expand All @@ -784,12 +783,12 @@ def rule(
) -> 'Alignment':
rule = cls(
alias,
left,
right,
preceding_left,
preceding_right,
following_left,
following_right,
alignment.left,
alignment.right,
preceding.left,
preceding.right,
following.left,
following.right,
from_bos,
to_eos,
operation,
Expand All @@ -806,12 +805,10 @@ def rule(
@classmethod
def deletion(
cls,
alias: str = '',
left: Expression.OR_SYMBOL = Expression.ANY,
preceding_left: Expression.OR_SYMBOL = Expression.ANY,
preceding_right: Expression.OR_SYMBOL = Expression.ANY,
following_left: Expression.OR_SYMBOL = Expression.ANY,
following_right: Expression.OR_SYMBOL = Expression.ANY,
alias: str,
left: Expression.OR_SYMBOL,
preceding: 'Alignment' = _BASE_ANY,
following: 'Alignment' = _BASE_ANY,
from_bos: bool = False,
to_eos: bool = False,
operation: op.Operation = op.Operation.COMMON.deletion,
Expand All @@ -821,12 +818,9 @@ def deletion(
) -> 'Alignment':
return cls.rule(
alias,
left,
Atomic.CTRL.eps,
preceding_left,
preceding_right,
following_left,
following_right,
left >> Atomic.CTRL.eps,
preceding,
following,
from_bos,
to_eos,
operation,
Expand All @@ -838,12 +832,10 @@ def deletion(
@classmethod
def insertion(
cls,
alias: str = '',
right: Expression.OR_SYMBOL = Expression.ANY,
preceding_left: Expression.OR_SYMBOL = Expression.ANY,
preceding_right: Expression.OR_SYMBOL = Expression.ANY,
following_left: Expression.OR_SYMBOL = Expression.ANY,
following_right: Expression.OR_SYMBOL = Expression.ANY,
alias: str,
right: Expression.OR_SYMBOL,
preceding: 'Alignment' = _BASE_ANY,
following: 'Alignment' = _BASE_ANY,
from_bos: bool = False,
to_eos: bool = False,
operation: op.Operation = op.Operation.COMMON.insertion,
Expand All @@ -853,12 +845,9 @@ def insertion(
) -> 'Alignment':
return cls.rule(
alias,
Atomic.CTRL.eps,
right,
preceding_left,
preceding_right,
following_left,
following_right,
Atomic.CTRL.eps >> right,
preceding,
following,
from_bos,
to_eos,
operation,
Expand All @@ -870,13 +859,10 @@ def insertion(
@classmethod
def interchangeable(
cls,
alias: str = '',
left: Expression.OR_SYMBOL = Expression.ANY,
right: Expression.OR_SYMBOL = Expression.ANY,
preceding_left: Expression.OR_SYMBOL = Expression.ANY,
preceding_right: Expression.OR_SYMBOL = Expression.ANY,
following_left: Expression.OR_SYMBOL = Expression.ANY,
following_right: Expression.OR_SYMBOL = Expression.ANY,
alias: str,
alignment: 'Alignment',
preceding: 'Alignment' = _BASE_ANY,
following: 'Alignment' = _BASE_ANY,
from_bos: bool = False,
to_eos: bool = False,
operation: op.Operation = op.Operation.COMMON.interchangeable,
Expand All @@ -885,32 +871,31 @@ def interchangeable(
source: str = 'Alignment.NATIVE',
) -> tuple['Alignment', 'Alignment']:
common = (
preceding_left,
preceding_right,
following_left,
following_right,
preceding,
following,
from_bos,
to_eos,
operation,
priority,
applied_cost,
source,
)
left_to_right = cls.rule(alias + '_l2r', left, right, *common)
right_to_left = cls.rule(alias + '_r2l', right, left, *common)
left_to_right = cls.rule(
alias + '_l2r', alignment.left >> alignment.right, *common
)
right_to_left = cls.rule(
alias + '_r2l', alignment.right >> alignment.left, *common
)
return left_to_right, right_to_left

def copy(self) -> 'Alignment':
if self.source == Alignment.CONSTANT:
return self
return Alignment.rule(
self.alias,
self.left.copy(),
self.right.copy(),
self.preceding.left.copy(),
self.preceding.right.copy(),
self.following.left.copy(),
self.following.right.copy(),
self.left.copy() >> self.right.copy(),
self.preceding.left.copy() >> self.preceding.right.copy(),
self.following.left.copy() >> self.following.right.copy(),
self.from_bos,
self.to_eos,
self.operation,
Expand Down
44 changes: 22 additions & 22 deletions nisaba/scripts/natural_translit/utils/expression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,20 @@ def test_alignment_bools(self):
left=_ATM.a, right=exp.Atomic.CTRL.eps
).is_assigned()
)
self.assertTrue(exp.Alignment.deletion(left=_ATM.a).is_assigned())
self.assertTrue(exp.Alignment.deletion('a', left=_ATM.a).is_assigned())

def test_alignment_simple(self):
simple = exp.Alignment.simple(_SYM.a, _ATM.b + _ATM.c)
self.assertIsInstance(simple.left, exp.Atomic)
self.assertEqual(simple.string(), '(a∶(b c))')

def test_rule(self):
exp_any = exp.Expression.ANY
rule = exp.Alignment.rule(
'test',
_ATM.a,
_ATM.b,
preceding_left=_ATM.c,
following_right=_ATM.d,
_ATM.a >> _ATM.b,
preceding=_ATM.c >> exp_any,
following=exp_any >> _ATM.d,
applied_cost=0.1,
)
self.AssertEquivalent(rule.left, _ATM.a)
Expand All @@ -368,7 +368,7 @@ def test_deletion(self):
rule = exp.Alignment.deletion(
'a_deletion',
_ATM.a,
preceding_right=_ATM.b,
preceding=exp.Expression.ANY >> _ATM.b,
from_bos=True,
)
self.assertEqual(rule.string(), '(⌈​⊳​​🝓⋆​∶b⌋ a∶​ℰ​, deletion (1.000))')
Expand All @@ -377,22 +377,21 @@ def test_insertion(self):
rule = exp.Alignment.insertion(
'a_insertion',
_ATM.a,
following_right=_ATM.b,
following=exp.Expression.ANY >> _ATM.b,
to_eos=True,
)
self.assertEqual(rule.string(), '(​ℰ​∶a ⌈​🝓⋆​∶b​⊲​⌋, insertion (1.000))')

def test_interchangeable(self):
rule1, rule2 = exp.Alignment.interchangeable('a_b', _ATM.a, _ATM.b)
rule1, rule2 = exp.Alignment.interchangeable('a_b', _ATM.a >> _ATM.b)
self.assertEqual(rule1.string(), '(a∶b, interchangeable (0.100))')
self.assertEqual(rule2.string(), '(b∶a, interchangeable (0.100))')

def test_alignment_copy(self):
rule1 = exp.Alignment.rule(
'test',
_ATM.a,
_ATM.b,
preceding_left=_ATM.c,
_ATM.a >> _ATM.b,
preceding=_ATM.c >> exp.Expression.ANY,
applied_cost=0.1,
)
rule2 = rule1.copy()
Expand Down Expand Up @@ -439,21 +438,22 @@ def test_alignment_compare(self):
self.assertFalse(alg_a_b.is_suffix(alg_a_bc))

def test_context_matches(self):
ctx_a_any = _ATM.a >> exp.Expression.ANY
ctx_b_any = _ATM.b >> exp.Expression.ANY
ctx_ba_any = (_ATM.b + _ATM.a) >> exp.Expression.ANY
exp_any = exp.Expression.ANY
ctx_a_any = _ATM.a >> exp_any
ctx_b_any = _ATM.b >> exp_any
ctx_ba_any = (_ATM.b + _ATM.a) >> exp_any
rule1 = exp.Alignment.rule(
left=_ATM.c,
right=_ATM.d,
preceding_left=_ATM.a,
following_left=_ATM.b,
'rule1',
_ATM.c >> _ATM.d,
preceding=_ATM.a >> exp_any,
following=_ATM.b >> exp_any,
from_bos=True,
)
rule2 = exp.Alignment.rule(
left=_ATM.c,
right=_ATM.d,
preceding_left=_ATM.a,
following_left=_ATM.b,
'rule2',
_ATM.c >> _ATM.d,
preceding=_ATM.a >> exp_any,
following=_ATM.b >> exp_any,
to_eos=True,
)
self.assertTrue(
Expand Down

0 comments on commit 8946805

Please sign in to comment.