diff --git a/ignite/contrib/handlers/__init__.py b/ignite/contrib/handlers/__init__.py
index 4aaf16264eb..0a6fe3edd5c 100644
--- a/ignite/contrib/handlers/__init__.py
+++ b/ignite/contrib/handlers/__init__.py
@@ -1,6 +1,5 @@
 from ignite.contrib.handlers.clearml_logger import ClearMLLogger
 from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
-from ignite.contrib.handlers.lr_finder import FastaiLRFinder
 from ignite.contrib.handlers.mlflow_logger import MLflowLogger
 from ignite.contrib.handlers.neptune_logger import NeptuneLogger
 from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger
diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index 6b2d36997ab..89d17b863ee 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -724,7 +724,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
 
     @staticmethod
     def _is_done(state: State) -> bool:
-        return state.iteration == state.epoch_length * state.max_epochs  # type: ignore[operator]
+        is_done_count = (
+            state.epoch_length is not None
+            and state.max_epochs is not None
+            and state.iteration >= state.epoch_length * state.max_epochs
+        )
+        is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
+        return is_done_count or is_done_epochs
 
     def set_data(self, data: Union[Iterable, DataLoader]) -> None:
         """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -956,7 +962,6 @@ def _internal_run_as_gen(self) -> Generator:
 
             self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
 
             handlers_start_time = time.time()
-            self._fire_event(Events.EPOCH_COMPLETED)
             epoch_time_taken += time.time() - handlers_start_time
             # update time wrt handlers
@@ -1039,13 +1044,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
 
                     # Should exit while loop if we can not iterate
                     if should_exit:
-                        if not self._is_done(self.state):
-                            total_iters = (
-                                self.state.epoch_length * self.state.max_epochs
-                                if self.state.max_epochs is not None
-                                else self.state.max_iters
-                            )
-
+                        if not self._is_done(self.state) and self.state.max_epochs is not None:
+                            total_iters = self.state.epoch_length * self.state.max_epochs
                             warnings.warn(
                                 "Data iterator can not provide data anymore but required total number of "
                                 "iterations to run is not reached. "
@@ -1072,10 +1072,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                 if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                     break
 
-                if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
-                    self.should_terminate = True
-                    raise _EngineTerminateException()
-
         except _EngineTerminateSingleEpochException:
             self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
             self.should_terminate_single_epoch = False
@@ -1191,19 +1187,12 @@ def _run_once_on_dataset_legacy(self) -> float:
                     if self.state.epoch_length is None:
                         # Define epoch length and stop the epoch
                         self.state.epoch_length = iter_counter
-                        if self.state.max_iters is not None:
-                            self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                         break
 
                     # Should exit while loop if we can not iterate
                     if should_exit:
-                        if not self._is_done(self.state):
-                            total_iters = (
-                                self.state.epoch_length * self.state.max_epochs
-                                if self.state.max_epochs is not None
-                                else self.state.max_iters
-                            )
-
+                        if not self._is_done(self.state) and self.state.max_epochs is not None:
+                            total_iters = self.state.epoch_length * self.state.max_epochs
                             warnings.warn(
                                 "Data iterator can not provide data anymore but required total number of "
                                 "iterations to run is not reached. "
@@ -1230,10 +1219,6 @@ def _run_once_on_dataset_legacy(self) -> float:
                 if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                     break
 
-                if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
-                    self.should_terminate = True
-                    raise _EngineTerminateException()
-
         except _EngineTerminateSingleEpochException:
             self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
             self.should_terminate_single_epoch = False
diff --git a/ignite/engine/events.py b/ignite/engine/events.py
index b539c73b4c4..8d0f0ae5ac6 100644
--- a/ignite/engine/events.py
+++ b/ignite/engine/events.py
@@ -214,7 +214,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
     )
 
 
-class EventEnum(CallableEventWithFilter, Enum):  # type: ignore[misc]
+class EventEnum(CallableEventWithFilter, Enum):
     """Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
     this class.
 
diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index f8a188740cd..2220a031ffd 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -962,7 +962,7 @@ def __init__(
         self,
         dirname: Union[str, Path],
         filename_prefix: str = "",
-        save_interval: Optional[Callable] = None,
+        save_interval: Optional[int] = None,
         score_function: Optional[Callable] = None,
         score_name: Optional[str] = None,
         n_saved: Union[int, None] = 1,
diff --git a/ignite/handlers/lr_finder.py b/ignite/handlers/lr_finder.py
index 69c176e93da..98bfeff0afb 100644
--- a/ignite/handlers/lr_finder.py
+++ b/ignite/handlers/lr_finder.py
@@ -106,7 +106,6 @@ def _run(
         max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
         if max_iter < num_iter:
             max_iter = num_iter
-            trainer.state.max_iters = num_iter
             trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)  # type: ignore[operator]
 
         if not trainer.has_event_handler(self._reached_num_iterations):
diff --git a/mypy.ini b/mypy.ini
index 489b3a3fd28..bf91c578773 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -77,3 +77,6 @@ ignore_missing_imports = True
 
 [mypy-torchvision.*]
 ignore_missing_imports = True
+
+[mypy-ignite.contrib.handlers.custom_events]
+ignore_errors = True
diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py
index 91f761af3ca..c37aa95ada6 100644
--- a/tests/ignite/engine/test_engine.py
+++ b/tests/ignite/engine/test_engine.py
@@ -1029,47 +1029,6 @@ def switch_dataloader():
 
         trainer.run(data1, max_epochs=10)
 
-    def test_run_with_max_iters(self):
-        max_iters = 8
-        engine = Engine(lambda e, b: 1)
-        engine.run([0] * 20, max_iters=max_iters)
-        assert engine.state.iteration == max_iters
-        assert engine.state.max_iters == max_iters
-
-    def test_run_with_max_iters_greater_than_epoch_length(self):
-        max_iters = 73
-        engine = Engine(lambda e, b: 1)
-        engine.run([0] * 20, max_iters=max_iters)
-        assert engine.state.iteration == max_iters
-
-    def test_run_with_invalid_max_iters_and_max_epoch(self):
-        max_iters = 12
-        max_epochs = 2
-        engine = Engine(lambda e, b: 1)
-        with pytest.raises(
-            ValueError,
-            match=r"Arguments max_iters and max_epochs are mutually exclusive."
-            "Please provide only max_epochs or max_iters.",
-        ):
-            engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
-
-    def test_epoch_events_fired_max_iters(self):
-        max_iters = 32
-        engine = Engine(lambda e, b: 1)
-
-        @engine.on(Events.EPOCH_COMPLETED)
-        def fired_event(engine):
-            assert engine.state.iteration % engine.state.epoch_length == 0
-
-        engine.run([0] * 10, max_iters=max_iters)
-
-    def test_is_done_with_max_iters(self):
-        state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-        assert not Engine._is_done(state)
-
-        state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-        assert Engine._is_done(state)
-
     @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
     def test_batch_is_released_before_new_one_is_loaded_on_cuda(self):
         torch.cuda.empty_cache()
diff --git a/tests/ignite/handlers/test_lr_finder.py b/tests/ignite/handlers/test_lr_finder.py
index c966c8c3f1d..f2f488acfe8 100644
--- a/tests/ignite/handlers/test_lr_finder.py
+++ b/tests/ignite/handlers/test_lr_finder.py
@@ -348,7 +348,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
     trainer_with_finder.run(dataloader)
     assert_output_sizes(lr_finder, dummy_engine)
     assert dummy_engine.state.iteration != len(dataloader)
-    assert dummy_engine.state.iteration == 150
+    assert dummy_engine.state.iteration == 150 + 1
 
 
 def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):