Skip to content

Commit

Permalink
refactor StateWriter with StateProvider as superclass
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Jul 22, 2024
1 parent 97c82bd commit 3a2389f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
30 changes: 20 additions & 10 deletions airbyte/_future_cdk/state_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING

from airbyte_protocol.models import (
AirbyteStateMessage,
AirbyteStateType,
)
from airbyte_protocol.models.airbyte_protocol import AirbyteStreamState
Expand All @@ -16,21 +17,25 @@


if TYPE_CHECKING:
from collections.abc import Iterable

from airbyte_protocol.models import (
AirbyteStateMessage,
AirbyteStreamState,
)


class StateProviderBase(abc.ABC): # noqa: B024
class StateProviderBase(abc.ABC):
"""A class to provide state artifacts."""

def __init__(self) -> None:
"""Initialize the state manager with a static catalog state.
@property
@abc.abstractmethod
def _state_message_artifacts(self) -> Iterable[AirbyteStateMessage]:
"""Generic internal interface to return all state artifacts.
This constructor may be overridden by subclasses to initialize the state artifacts.
Subclasses should implement this property.
"""
self._state_message_artifacts: list[AirbyteStateMessage] | None = None
...

@property
def stream_state_artifacts(
Expand All @@ -53,15 +58,16 @@ def stream_state_artifacts(
@property
def state_message_artifacts(
self,
) -> list[AirbyteStreamState]:
) -> Iterable[AirbyteStreamState]:
"""Return all state artifacts.
This is just a type guard around the private variable `_state_message_artifacts`.
"""
if self._state_message_artifacts is None:
result = self._state_message_artifacts
if result is None:
raise exc.PyAirbyteInternalError(message="No state artifacts were declared.")

return self._state_message_artifacts
return result

@property
def known_stream_names(
Expand Down Expand Up @@ -113,7 +119,11 @@ class StaticInputState(StateProviderBase):

def __init__(
self,
from_state_messages: list[AirbyteStateMessage] | None = None,
from_state_messages: list[AirbyteStateMessage],
) -> None:
"""Initialize the state manager with a static catalog state."""
self._state_message_artifacts: list[AirbyteStateMessage] | None = from_state_messages
self._state_messages: list[AirbyteStateMessage] = from_state_messages

@property
def _state_message_artifacts(self) -> Iterable[AirbyteStateMessage]:
return self._state_messages
51 changes: 45 additions & 6 deletions airbyte/_future_cdk/state_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,59 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NoReturn, final

from airbyte._future_cdk.state_providers import StateProviderBase


if TYPE_CHECKING:
from airbyte_protocol.models.airbyte_protocol import AirbyteStateMessage


class StateWriterBase(abc.ABC):
"""A class to write state artifacts."""
class StateWriterBase(StateProviderBase, abc.ABC):
"""A class to write state artifacts.
@abc.abstractmethod
This class is used to write state artifacts to a state store. It also serves as a provider
of cached state artifacts.
"""

def __init__(self) -> None:
"""Initialize the state writer."""
self._latest_stream_state_messages: dict[str, AirbyteStateMessage] = {}
"""The latest state message seen for each stream."""

@property
def _state_message_artifacts(
self,
) -> list[AirbyteStateMessage]:
"""Return all state artifacts."""
return list(self._latest_stream_state_messages.values())

@_state_message_artifacts.setter
def _state_message_artifacts(self, value: list[AirbyteStateMessage]) -> NoReturn:
"""Override as no-op / not-implemented."""
_ = value
raise NotImplementedError("The `_state_message_artifacts` property cannot be set")

@final
def write_state(
self,
state_message: AirbyteStateMessage,
) -> None:
"""Save or 'write' a state artifact.
This method is final and should not be overridden. Subclasses should instead overwrite
the `_write_state` method.
"""
if state_message.stream:
self._latest_stream_state_messages[state_message.stream.name] = state_message

self._write_state(state_message)

@abc.abstractmethod
def _write_state(
self,
state_message: AirbyteStateMessage,
) -> None:
"""Save or 'write' a state artifact."""
...
Expand All @@ -31,7 +70,7 @@ class StdOutStateWriter(StateWriterBase):
an orchestrator is responsible for saving those state artifacts.
"""

def write_state(
def _write_state(
self,
state_message: AirbyteStateMessage,
) -> None:
Expand All @@ -46,7 +85,7 @@ class NoOpStateWriter(StateWriterBase):
an orchestrator is responsible for saving those state artifacts.
"""

def write_state(
def _write_state(
self,
state_message: AirbyteStateMessage,
) -> None:
Expand Down

0 comments on commit 3a2389f

Please sign in to comment.