Skip to content

Commit

Permalink
Move make_thing into class. Replace Valid with Any.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577229535
  • Loading branch information
isingoo authored and copybara-github committed Oct 27, 2023
1 parent 3239075 commit 8e6366f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 79 deletions.
132 changes: 61 additions & 71 deletions nisaba/scripts/natural_translit/utils/type_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""

import logging
from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, Type
from typing import Any, Dict, Iterable, List, Tuple, Union, Type
import pynini as pyn

# Custom types
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(self, name: str):
MISSING = Nothing('Missing')


class Thing(object):
class Thing:
"""Parent class for various custom classes.
Attributes:
Expand All @@ -92,22 +92,22 @@ class Thing(object):
values point to the same phoneme in the inventory.
"""

def __init__(self, alias: str = '', text: str = ''):
def __init__(
self,
alias: str = '', text: str = '',
value: ... = UNSPECIFIED, allow_none: bool = False
):
self.alias = alias
self.text = text
self.value = self

self.value = value_of(value) if exists(value, allow_none) else self

# Union types

# FstLike from pynini doesn't work in isinstance()
FstLike = Union[str, pyn.Fst]
# Restricted generic class to avoid using Any type.
Valid = Union[
bool, Dict, float, FstLike, int, List, NamedTuple,
Nothing, range, set, Thing, Tuple
]
SetOrNothing = Union[set, Nothing]
ThingOrNothing = Union[Thing, Nothing]
TypeOrNothing = Union[Type, Nothing]

# Log functions
Expand All @@ -117,7 +117,7 @@ def debug_message(function_name: str, message: str = '') -> None:
return logging.debug('%s: %s', function_name, message)


def debug_result(function_name: str, result: Valid, detail: str = '') -> None:
def debug_result(function_name: str, result: ..., detail: str = '') -> None:
message = 'returns %s' % text_of(result)
if detail: message += ', ' + detail
return debug_message(function_name, message)
Expand All @@ -134,11 +134,11 @@ def debug_false(function_name: str, detail: str = '') -> None:
# Handle common attributes for objects of unknown types.


def class_of(a: Valid) -> str:
def class_of(a: ...) -> str:
return a.__class__.__name__


def text_of(a: Valid) -> str:
def text_of(a: ...) -> str:
"""Returns str() for objects with no text attribute."""
if hasattr(a, 'text'):
return ('Textless %s' % class_of(a)) if is_empty(a.text) else a.text
Expand All @@ -149,20 +149,20 @@ def texts_of(*args) -> str:
return ' ,'.join([text_of(a) for a in args])


def alias_of(a: Valid) -> str:
def alias_of(a: ...) -> str:
"""Returns text_of() for logging objects with no alias."""
return a.alias if hasattr(a, 'alias') and not_empty(a.alias) else text_of(a)


def value_of(a: Valid) -> Valid:
def value_of(a: ...) -> ...:
"""If a has no value attribute, returns a."""
return a.value if isinstance(a, Thing) else a


# Type check.


def is_none(a: Valid) -> bool:
def is_none(a: ...) -> bool:
"""Checks None for logging purposes."""
if a is None:
debug_true('is_none')
Expand All @@ -171,11 +171,11 @@ def is_none(a: Valid) -> bool:
return False


def not_none(a: Valid) -> bool:
def not_none(a: ...) -> bool:
return not is_none(a)


def is_assigned(a: Valid) -> bool:
def is_assigned(a: ...) -> bool:
"""Checks UNASSIGNED for logging purposes."""
if a is UNASSIGNED:
debug_false('is_assigned')
Expand All @@ -184,11 +184,11 @@ def is_assigned(a: Valid) -> bool:
return True


def not_assigned(a: Valid) -> bool:
def not_assigned(a: ...) -> bool:
return not is_assigned(a)


def is_specified(a: Valid) -> bool:
def is_specified(a: ...) -> bool:
"""Checks UNSPECIFIED for logging purposes."""
if a is UNSPECIFIED:
debug_false('is_specified')
Expand All @@ -197,11 +197,11 @@ def is_specified(a: Valid) -> bool:
return True


def not_specified(a: Valid) -> bool:
def not_specified(a: ...) -> bool:
return not is_specified(a)


def is_found(a: Valid) -> bool:
def is_found(a: ...) -> bool:
"""Checks MISSING for logging purposes."""
if a is MISSING:
debug_false('is_found')
Expand All @@ -210,29 +210,29 @@ def is_found(a: Valid) -> bool:
return True


def not_found(a: Valid) -> bool:
def not_found(a: ...) -> bool:
return not is_found(a)


def is_nothing(a: Valid) -> bool:
def is_nothing(a: ...) -> bool:
"""Checks default Thing constants."""
return not_assigned(a) or not_specified(a) or not_found(a)


def not_nothing(a: Valid) -> bool:
def not_nothing(a: ...) -> bool:
return not is_nothing(a)


def exists(a: Valid, allow_none: bool = False) -> bool:
def exists(a: ..., allow_none: bool = False) -> bool:
"""Combines checking for None and undefined Things."""
return (not_none(a) or allow_none) and not_nothing(a)


def not_exists(a: Valid, allow_none: bool = False) -> bool:
def not_exists(a: ..., allow_none: bool = False) -> bool:
return not exists(a, allow_none)


def is_instance_dbg(a: Valid, want: TypeOrNothing = UNSPECIFIED) -> bool:
def is_instance_dbg(a: ..., want: TypeOrNothing = UNSPECIFIED) -> bool:
"""Checks instance for logging purposes.
Args:
Expand Down Expand Up @@ -260,33 +260,23 @@ def is_instance_dbg(a: Valid, want: TypeOrNothing = UNSPECIFIED) -> bool:
return False


def not_instance(a: Valid, want: TypeOrNothing = UNSPECIFIED) -> bool:
def not_instance(a: ..., want: TypeOrNothing = UNSPECIFIED) -> bool:
return not is_instance_dbg(a, want)


def make_thing(
alias: str = '', text: str = '',
value: Valid = UNSPECIFIED, allow_none: bool = False
) -> Thing:
thing = Thing(alias, text)
if not_exists(value, allow_none): return thing
thing.value = value_of(value)
return thing


def enforce_thing(t: Valid) -> Thing:
def enforce_thing(t: ...) -> Thing:
"""Enforces thing type. If t is not Thing, puts t in value of a new Thing."""
if isinstance(t, Thing): return t
debug_message(
'enforce_thing', 'Thing from %s: %s' % (class_of(t), text_of(t))
)
return make_thing(text=text_of(t), value=t)
return Thing(text=text_of(t), value=t)

# Attribute functions with type check.


def has_attribute(
a: Valid, attr: str, want: TypeOrNothing = UNSPECIFIED
a: ..., attr: str, want: TypeOrNothing = UNSPECIFIED
) -> bool:
"""Adds log and optional type check to hasattr()."""
if not_exists(a): return False
Expand All @@ -299,17 +289,17 @@ def has_attribute(


def get_attribute(
a: Valid, attr: str, default: Valid = MISSING,
a: ..., attr: str, default: ... = MISSING,
want: TypeOrNothing = UNSPECIFIED
) -> Valid:
) -> ...:
"""Adds log and type check to getattr()."""
return getattr(a, attr) if has_attribute(a, attr, want) else default

# Equivalence functions.


def is_equal(
a: Valid, b: Valid,
a: ..., b: ...,
empty: bool = False, epsilon: bool = False, zero: bool = True
) -> bool:
"""Checks equivalence.
Expand Down Expand Up @@ -355,25 +345,25 @@ def is_equal(


def not_equal(
a: Valid, b: Valid,
a: ..., b: ...,
empty: bool = False, epsilon: bool = False, zero: bool = True
) -> bool:
return not is_equal(a, b, empty, epsilon, zero)

# Iterable functions


def is_empty(a: Valid, allow_none: bool = False) -> bool:
def is_empty(a: ..., allow_none: bool = False) -> bool:
return not_exists(a, allow_none) or (isinstance(a, Iterable) and not a)


def not_empty(a: Valid, allow_none: bool = False) -> bool:
def not_empty(a: ..., allow_none: bool = False) -> bool:
return not is_empty(a, allow_none)


def get_element(
search_in: Valid, index: int, default: Valid = MISSING
) -> Valid:
search_in: ..., index: int, default: ... = MISSING
) -> ...:
"""Returns a[index] if possible, default value if not."""
if not_exists(search_in): return default
if not_instance(search_in, Iterable): return default
Expand All @@ -385,8 +375,8 @@ def get_element(


def enforce_range(
arg1: Valid, arg2: Valid = UNSPECIFIED, arg3: Valid = UNSPECIFIED,
def_start: Valid = UNSPECIFIED, def_stop: Valid = UNSPECIFIED,
arg1: ..., arg2: ... = UNSPECIFIED, arg3: ... = UNSPECIFIED,
def_start: ... = UNSPECIFIED, def_stop: ... = UNSPECIFIED,
) -> range:
"""Ensures range type for tuple arguments.
Expand Down Expand Up @@ -448,8 +438,8 @@ def enforce_range(


def in_range(
look_for: Valid, arg1: Valid,
arg2: Valid = UNSPECIFIED, arg3: Valid = UNSPECIFIED,
look_for: ..., arg1: ...,
arg2: ... = UNSPECIFIED, arg3: ... = UNSPECIFIED,
) -> bool:
"""Checks if look_for is in an enforced range."""
if not_exists(look_for): return False
Expand All @@ -458,15 +448,15 @@ def in_range(


def enforce_list(
l: Valid, enf_dict: bool = True, allow_none: bool = False
) -> List[Valid]:
l: ..., enf_dict: bool = True, allow_none: bool = False
) -> List[Any]:
"""Enforces list type.
When l is a list, returns l. If l is an iterable returns `list(l)`, except
for str and dict. For other types returns `[l]`.
Args:
l: Valid variable.
l: ... variable.
enf_dict: When true, if l is a dict returns the list of values.
allow_none: When false, if l is None returns an empty list. When true,
returns `[None]`.
Expand All @@ -493,20 +483,20 @@ def enforce_list(


def in_list(
look_for: Valid, look_in: Valid,
look_for: ..., look_in: ...,
enf_dict: bool = True, allow_none: bool = False
) -> bool:
"""Checks if look_for is an element of a list enforced from look_in."""
return look_for in enforce_list(look_in, enf_dict, allow_none)


def enforce_dict(
d: Valid, add_key: Valid = 'default', allow_none: bool = False
) -> Dict[Valid, Valid]:
d: ..., add_key: ... = 'default', allow_none: bool = False
) -> Dict[Any, Any]:
"""Enforces dict type.
Args:
d: Valid variable.
d: ... variable.
When d is a dict, returns d. Otherwise returns `{add_key: d}`.
add_key: optional key for adding d as a value to a new list.
allow_none: When false, if d is a nonexistant type returns an empty dict.
Expand All @@ -531,17 +521,17 @@ def enforce_dict(


def dict_get(
d: Valid, key: Valid = 'default', default: Valid = MISSING,
add_key: Valid = 'default', allow_none: bool = False
) -> Valid:
d: ..., key: ... = 'default', default: ... = MISSING,
add_key: ... = 'default', allow_none: bool = False
) -> ...:
try:
return enforce_dict(d, add_key, allow_none).get(key, default)
except TypeError:
debug_result('dict_get', default, 'invalid key type')
return default


def in_dict(look_for: Valid, look_in: Valid, keys: Valid = UNSPECIFIED) -> bool:
def in_dict(look_for: ..., look_in: ..., keys: ... = UNSPECIFIED) -> bool:
"""Checks if look_for is a value of a dict enforced from look_in.
If the value of a key is a list, searches the elements of the list.
Expand All @@ -563,8 +553,8 @@ def in_dict(look_for: Valid, look_in: Valid, keys: Valid = UNSPECIFIED) -> bool:


def enforce_set(
s: Valid, enf_dict: bool = True, allow_none: bool = False
) -> set[Valid]:
s: ..., enf_dict: bool = True, allow_none: bool = False
) -> set[Any]:
"""Enforces set type.
If s is a set, returns s. If s is iterable returns set(s). Else returns {s}.
Expand Down Expand Up @@ -596,24 +586,24 @@ def enforce_set(
try:
result.add(element)
except TypeError:
result.add(make_thing(value=element))
result.add(Thing(value=element))
else:
try:
result.add(s)
except TypeError:
result.add(make_thing(value=s))
result.add(Thing(value=s))
return result


def in_set(
look_for: Valid, look_in: Valid,
look_for: ..., look_in: ...,
enf_dict: bool = True, allow_none: bool = False
) -> bool:
return look_for in enforce_set(look_in, enf_dict, allow_none)


def in_enforced(
look_for: Valid, look_in: Valid, keys: Valid = UNSPECIFIED,
look_for: ..., look_in: ..., keys: ... = UNSPECIFIED,
enf_list: bool = True, enf_dict: bool = True, enf_set: bool = True,
enf_range: bool = False, allow_none: bool = False
) -> bool:
Expand All @@ -634,7 +624,7 @@ def in_enforced(


def in_attribute(
look_for: Valid, thing: Valid, attr: str, keys: Valid = UNSPECIFIED,
look_for: ..., thing: ..., attr: str, keys: ... = UNSPECIFIED,
enf_list: bool = True, enf_dict: bool = True, enf_range: bool = False,
allow_none: bool = False
) -> bool:
Expand Down
Loading

0 comments on commit 8e6366f

Please sign in to comment.