diff --git a/atpbar/main.py b/atpbar/main.py index 51877a6..8fe5a88 100644 --- a/atpbar/main.py +++ b/atpbar/main.py @@ -67,10 +67,13 @@ def __init__(self, iterable: Iterable[T], name: str, len_: int): def __iter__(self) -> Iterator[T]: with fetch_reporter() as reporter: + if reporter is None: + yield from self.iterable + return self.reporter = reporter self.loop_complete = False self._report_start() - with report_last(pbar=self): + with self._report_last(): for i, e in enumerate(self.iterable): yield e self._report_progress(i) @@ -78,42 +81,32 @@ def __iter__(self) -> Iterator[T]: self.loop_complete = True def _report_start(self) -> None: - if self.reporter is None: - return - try: - report = Report(task_id=self.id_, name=self.name, done=0, total=self.len_) - self.reporter.report(report) - except BaseException: - pass + report = Report(task_id=self.id_, name=self.name, done=0, total=self.len_) + self._submit(report) def _report_progress(self, i: int) -> None: - if self.reporter is None: - return - try: - report = Report(task_id=self.id_, done=(i + 1)) - self.reporter.report(report) - except BaseException: - pass - + report = Report(task_id=self.id_, done=(i + 1)) + self._submit(report) -@contextlib.contextmanager -def report_last(pbar: Atpbar[T]) -> Iterator[None]: - '''send a last report + @contextlib.contextmanager + def _report_last(self) -> Iterator[None]: + '''send a last report - This function sends the last report of the task when the loop ends - with `break` or an exception so that the progress bar will be - updated with the last complete iteration. + This function sends the last report of the task when the loop ends with + `break` or an exception so that the progress bar will be updated with + the last complete iteration. - ''' - try: - yield - finally: - if pbar.loop_complete: - return - if pbar.reporter is None: - return + ''' + try: + yield + finally: + if self.loop_complete: + return + report = Report(task_id=self.id_, first=False, last=True) + self._submit(report) + + def _submit(self, report: Report) -> None: try: - report = Report(task_id=pbar.id_, first=False, last=True) - pbar.reporter.report(report) + self.reporter.report(report) except BaseException: pass