Skip to content

Commit

Permalink
fix(pipeline): request now fails without default
Browse files Browse the repository at this point in the history
Previously when serialized to a new process, object ids would no longer
match, causing the `has_default` check to succeed, even if not set.
  • Loading branch information
eddiebergman committed May 14, 2024
1 parent d15b94d commit 8695d29
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/amltk/pipeline/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,13 @@ def __call__(
P = ParamSpec("P")


_NotSet = object()
class _NotSetType:
@override
def __repr__(self) -> str:
return "<NotSet>"


_NotSet = _NotSetType()


class RichOptions(NamedTuple):
Expand All @@ -197,6 +203,8 @@ class RichOptions(NamedTuple):
class ParamRequest(Generic[T]):
"""A parameter request for a node. This is most useful for things like seeds."""

_has_default: bool

key: str
"""The key to request under."""

Expand All @@ -211,7 +219,10 @@ class ParamRequest(Generic[T]):
@property
def has_default(self) -> bool:
"""Whether this request has a default value."""
return self.default is not _NotSet
# NOTE(eddiebergman): We decide to calculate this on
# initialization as when sent to new processes, these object
# ids may not match
return self._has_default


def request(key: str, default: T | object = _NotSet) -> ParamRequest[T]:
Expand All @@ -224,7 +235,7 @@ def request(key: str, default: T | object = _NotSet) -> ParamRequest[T]:
config once [`configure`][amltk.pipeline.Node.configure] is called and
nothing has been provided.
"""
return ParamRequest(key=key, default=default)
return ParamRequest(key=key, default=default, _has_default=default is not _NotSet)


@dataclass(init=False, frozen=True, eq=True)
Expand Down

0 comments on commit 8695d29

Please sign in to comment.