Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add progress bar #54

Merged
merged 14 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Table().source
All the rows can be pulled like so:

```python
Table().source.pull()
Table().source.pull() # Hint: Pass display_progress=True to get a progress bar
```

That said usually we only want to pull rows that match a certain criteria:
Expand Down
64 changes: 64 additions & 0 deletions link/adapters/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Contains DataJoint-specific code for relaying progress information to the user."""
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable

from link.domain.custom_types import Identifier
from link.domain.state import Processes
from link.service.progress import ProgessDisplay

from .identification import IdentificationTranslator


class ProgressView(ABC):
"""Progress display."""

@abstractmethod
def open(self, description: str, total: int, unit: str) -> None:
"""Open the progress display showing information to the user."""

@abstractmethod
def update_current(self, new: str) -> None:
"""Update the display with new information regarding the current iteration."""

@abstractmethod
def update_iteration(self) -> None:
"""Update the display to reflect that the current iteration finished."""

@abstractmethod
def close(self) -> None:
"""Close the progress display."""

@abstractmethod
def enable(self) -> None:
"""Enable the view."""

@abstractmethod
def disable(self) -> None:
"""Disable the view."""


class DJProgressDisplayAdapter(ProgessDisplay):
"""DataJoint-specific adapter for the progress display."""

def __init__(self, translator: IdentificationTranslator, display: ProgressView) -> None:
"""Initialize the display."""
self._translator = translator
self._display = display

def start(self, process: Processes, to_be_processed: Iterable[Identifier]) -> None:
"""Start showing progress information to the user."""
self._display.open(process.name, len(list(to_be_processed)), "row")

def update_current(self, new: Identifier) -> None:
"""Update the display to reflect a new entity being currently processed."""
self._display.update_current(repr(self._translator.to_primary_key(new)))

def finish_current(self) -> None:
"""Update the display to reflect that the current entity finished processing."""
self._display.update_iteration()

def stop(self) -> None:
"""Stop showing progress information to the user."""
self._display.close()
14 changes: 14 additions & 0 deletions link/domain/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ class Command:
"""Base class for all commands."""


@dataclass(frozen=True)
class PullEntity(Command):
"""Pull the requested entity."""

requested: Identifier


@dataclass(frozen=True)
class DeleteEntity(Command):
"""Delete the requested entity."""

requested: Identifier


@dataclass(frozen=True)
class PullEntities(Command):
"""Pull the requested entities."""
Expand Down
34 changes: 33 additions & 1 deletion link/domain/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .custom_types import Identifier

if TYPE_CHECKING:
from .state import Commands, Operations, State, Transition
from .state import Commands, Operations, Processes, State, Transition


@dataclass(frozen=True)
Expand Down Expand Up @@ -43,3 +43,35 @@ class IdleEntitiesListed(Event):
"""Idle entities in a link have been listed."""

identifiers: frozenset[Identifier]


@dataclass(frozen=True)
class ProcessStarted(Event):
"""A process for an entity was started."""

process: Processes
identifier: Identifier


@dataclass(frozen=True)
class ProcessFinished(Event):
"""A process for an entity was finished."""

process: Processes
identifier: Identifier


@dataclass(frozen=True)
class BatchProcessingStarted(Event):
"""The processing of a batch of entities started."""

process: Processes
identifiers: frozenset[Identifier]


@dataclass(frozen=True)
class BatchProcessingFinished(Event):
"""The processing of a batch of entities finished."""

process: Processes
identifiers: frozenset[Identifier]
22 changes: 6 additions & 16 deletions link/domain/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,17 @@ def identifiers(self) -> frozenset[Identifier]:
"""Return the identifiers of all entities in the link."""
return frozenset(entity.identifier for entity in self)

def pull(self, requested: Iterable[Identifier]) -> None:
"""Pull the requested entities."""
requested = set(requested)
self._validate_requested(requested)
for entity in (entity for entity in self if entity.identifier in requested):
entity.pull()

def delete(self, requested: Iterable[Identifier]) -> None:
"""Delete the requested entities."""
requested = set(requested)
self._validate_requested(requested)
for entity in (entity for entity in self if entity.identifier in requested):
entity.delete()
def __getitem__(self, identifier: Identifier) -> Entity:
"""Return the entity with the given identifier."""
try:
return next(entity for entity in self if entity.identifier == identifier)
except StopIteration as error:
raise KeyError("Requested entity not present in link") from error

def list_idle_entities(self) -> frozenset[Identifier]:
"""List the identifiers of all idle entities in the link."""
return frozenset(entity.identifier for entity in self if entity.state is Idle)

def _validate_requested(self, requested: Iterable[Identifier]) -> None:
assert set(requested) <= self.identifiers, "Requested identifiers not present in link."

def __contains__(self, entity: object) -> bool:
"""Check if the link contains the given entity."""
return entity in self._entities
Expand Down
35 changes: 29 additions & 6 deletions link/infrastructure/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,27 @@
from link.adapters.gateway import DJLinkGateway
from link.adapters.identification import IdentificationTranslator
from link.adapters.present import create_idle_entities_updater, create_state_change_logger
from link.adapters.progress import DJProgressDisplayAdapter
from link.domain import commands, events
from link.service.handlers import delete, list_idle_entities, log_state_change, pull
from link.service.handlers import (
delete,
delete_entity,
inform_batch_processing_finished,
inform_batch_processing_started,
inform_current_process_finished,
inform_next_process_started,
list_idle_entities,
log_state_change,
pull,
pull_entity,
)
from link.service.messagebus import CommandHandlers, EventHandlers, MessageBus
from link.service.uow import UnitOfWork

from . import DJConfiguration, create_tables
from .facade import DJLinkFacade
from .mixin import create_local_endpoint
from .progress import TQDMProgressView
from .sequence import IterationCallbackList, create_content_replacer


Expand Down Expand Up @@ -48,21 +61,31 @@ def inner(obj: type) -> Any:
source_restriction: IterationCallbackList[PrimaryKey] = IterationCallbackList()
idle_entities_updater = create_idle_entities_updater(translator, create_content_replacer(source_restriction))
logger = logging.getLogger(obj.__name__)

command_handlers: CommandHandlers = {}
command_handlers[commands.PullEntities] = partial(pull, uow=uow)
command_handlers[commands.DeleteEntities] = partial(delete, uow=uow)
event_handlers: EventHandlers = {}
bus = MessageBus(uow, command_handlers, event_handlers)
command_handlers[commands.PullEntity] = partial(pull_entity, uow=uow, message_bus=bus)
command_handlers[commands.DeleteEntity] = partial(delete_entity, uow=uow, message_bus=bus)
command_handlers[commands.PullEntities] = partial(pull, message_bus=bus)
command_handlers[commands.DeleteEntities] = partial(delete, message_bus=bus)
command_handlers[commands.ListIdleEntities] = partial(
list_idle_entities, uow=uow, output_port=idle_entities_updater
)
event_handlers: EventHandlers = {}
progress_view = TQDMProgressView()
display = DJProgressDisplayAdapter(translator, progress_view)
event_handlers[events.ProcessStarted] = [partial(inform_next_process_started, display=display)]
event_handlers[events.ProcessFinished] = [partial(inform_current_process_finished, display=display)]
event_handlers[events.BatchProcessingStarted] = [partial(inform_batch_processing_started, display=display)]
event_handlers[events.BatchProcessingFinished] = [partial(inform_batch_processing_finished, display=display)]
event_handlers[events.StateChanged] = [
partial(log_state_change, log=create_state_change_logger(translator, logger.info))
]
event_handlers[events.InvalidOperationRequested] = [lambda event: None]
bus = MessageBus(uow, command_handlers, event_handlers)

controller = DJController(bus, translator)
source_restriction.callback = controller.list_idle_entities

return create_local_endpoint(controller, tables, source_restriction)
return create_local_endpoint(controller, tables, source_restriction, progress_view)

return inner
22 changes: 18 additions & 4 deletions link/infrastructure/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from link.adapters.controller import DJController
from link.adapters.custom_types import PrimaryKey
from link.adapters.progress import ProgressView

from . import DJTables

Expand All @@ -17,11 +18,15 @@ class SourceEndpoint(Table):

_controller: DJController
_outbound_table: Callable[[], Table]
_progress_view: ProgressView

def pull(self) -> None:
def pull(self, *, display_progress: bool = False) -> None:
"""Pull idle entities from the source table into the local table."""
if display_progress:
self._progress_view.enable()
primary_keys = self.proj().fetch(as_dict=True)
self._controller.pull(primary_keys)
self._progress_view.disable()

@property
def flagged(self) -> Sequence[PrimaryKey]:
Expand All @@ -34,6 +39,7 @@ def create_source_endpoint_factory(
source_table: Callable[[], Table],
outbound_table: Callable[[], Table],
restriction: Iterable[PrimaryKey],
progress_view: ProgressView,
) -> Callable[[], SourceEndpoint]:
"""Create a callable that returns the source endpoint when called."""

Expand All @@ -47,6 +53,7 @@ def create_source_endpoint() -> SourceEndpoint:
{
"_controller": controller,
"_outbound_table": staticmethod(outbound_table),
"_progress_view": progress_view,
},
)()
& restriction,
Expand All @@ -60,11 +67,15 @@ class LocalEndpoint(Table):

_controller: DJController
_source: Callable[[], SourceEndpoint]
_progress_view: ProgressView

def delete(self) -> None:
def delete(self, *, display_progress: bool = False) -> None:
"""Delete pulled entities from the local table."""
if display_progress:
self._progress_view.enable()
primary_keys = self.proj().fetch(as_dict=True)
self._controller.delete(primary_keys)
self._progress_view.disable()

@property
def source(self) -> SourceEndpoint:
Expand All @@ -73,7 +84,7 @@ def source(self) -> SourceEndpoint:


def create_local_endpoint(
controller: DJController, tables: DJTables, source_restriction: Iterable[PrimaryKey]
controller: DJController, tables: DJTables, source_restriction: Iterable[PrimaryKey], progress_view: ProgressView
) -> type[LocalEndpoint]:
"""Create the local endpoint."""
return cast(
Expand All @@ -87,8 +98,11 @@ def create_local_endpoint(
{
"_controller": controller,
"_source": staticmethod(
create_source_endpoint_factory(controller, tables.source, tables.outbound, source_restriction)
create_source_endpoint_factory(
controller, tables.source, tables.outbound, source_restriction, progress_view
),
),
"_progress_view": progress_view,
},
),
)
50 changes: 50 additions & 0 deletions link/infrastructure/progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Contains views for showing progress information to the user."""
from __future__ import annotations

import logging
from typing import NoReturn

from tqdm.auto import tqdm

from link.adapters.progress import ProgressView

logger = logging.getLogger(__name__)


class TQDMProgressView(ProgressView):
"""A view that uses tqdm to show a progress bar."""

def __init__(self) -> None:
"""Initialize the view."""
self.__progress_bar: tqdm[NoReturn] | None = None
self._is_disabled: bool = False

@property
def _progress_bar(self) -> tqdm[NoReturn]:
assert self.__progress_bar
return self.__progress_bar

def open(self, description: str, total: int, unit: str) -> None:
"""Start showing the progress bar."""
self.__progress_bar = tqdm(total=total, desc=description, unit=unit, disable=self._is_disabled)

def update_current(self, new: str) -> None:
"""Update information about the current iteration shown at the end of the bar."""
self._progress_bar.set_postfix(current=new)

def update_iteration(self) -> None:
"""Update the bar to show an iteration finished."""
self._progress_bar.update()

def close(self) -> None:
"""Stop showing the progress bar."""
self._progress_bar.close()
self.__progress_bar = None

def enable(self) -> None:
"""Enable the progress bar."""
self._is_disabled = False

def disable(self) -> None:
"""Disable the progress bar."""
self._is_disabled = True
Loading